/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc
#define EIGEN_USE_THREADS

#include <math.h>
#include <algorithm>
#include <array>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace {

static const int64 kTableSize = (1 << 10);

const float* InitCoeffsTable(const double a) {
  // Allocate and initialize coefficients table using Bicubic
  // convolution algorithm.
  // https://en.wikipedia.org/wiki/Bicubic_interpolation
  float* coeffs_table = new float[(kTableSize + 1) * 2];
  for (int i = 0; i <= kTableSize; ++i) {
    float x = i * 1.0 / kTableSize;
    coeffs_table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1;
    x += 1.0;
    coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
  }

  return coeffs_table;
}
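
// Illustrative check (not part of the kernel): slot i of the table holds the
// convolution weights W(x) and W(x + 1) for x = i / kTableSize. With the Keys
// parameter a = -0.5 and x = 0.25:
//   ((a + 2) * x - (a + 3)) * x * x + 1 = (1.5 * 0.25 - 2.5) * 0.0625 + 1
//                                       ≈ 0.8672   (W(0.25), inner taps)
//   ((a * 1.25 - 5 * a) * 1.25 + 8 * a) * 1.25 - 4 * a ≈ -0.0703
//                                                       (W(1.25), outer taps)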

const float* GetCoeffsTable(const bool use_keys_cubic) {
  // Static so that we initialize it on first use
  if (use_keys_cubic) {
    // http://ieeexplore.ieee.org/document/1163711/
    // R. G. Keys. Cubic convolution interpolation for digital image
    // processing. IEEE Transactions on Acoustics, Speech, and Signal
    // Processing, 29(6):1153-1160, 1981.
    static const float* coeffs_table = InitCoeffsTable(-0.5f);
    return coeffs_table;
  } else {
    static const float* coeffs_table = InitCoeffsTable(-0.75f);
    return coeffs_table;
  }
}
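
// Note (informational): a = -0.5 gives the Keys kernel, which coincides with
// the Catmull-Rom spline and achieves third-order accuracy; a = -0.75 is the
// parameter this op has historically used (and which some other imaging
// libraries, e.g. OpenCV, also use).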

inline int64 Bound(int64 val, int64 limit) {
  return std::min(limit - 1ll, std::max(int64{0}, val));
}

struct WeightsAndIndices {
  float weight_0;
  float weight_1;
  float weight_2;
  float weight_3;
  int64 index_0;
  int64 index_1;
  int64 index_2;
  int64 index_3;

  int advance;  // Number of cached values to copy over from the previous
                // point; see CachedInterpolationCalculator::Advance.
};

template <typename Scaler, bool use_keys_cubic>
inline void GetWeightsAndIndices(const float scale, const int64 out_loc,
                                 const int64 limit, WeightsAndIndices* out) {
  const Scaler scaler;
  const float in_loc_f = scaler(out_loc, scale);
  const int64 in_loc = std::floor(in_loc_f);
  const float delta = in_loc_f - in_loc;
  const int64 offset = lrintf(delta * kTableSize);
  const float* coeffs_table = GetCoeffsTable(use_keys_cubic);
  if (use_keys_cubic) {
    // The legacy code placed more weight on the edge pixels, since bounding
    // the set of inputs to sample could cause an edge pixel to be repeated.
    // Here we change the behavior at borders to match that used by the
    // scale_and_translate_op, where sampling locations outside the image have
    // their weight set to 0, and the weights are renormalized so that their
    // sum is 1.0.
    out->index_0 = Bound(in_loc - 1, limit);
    out->weight_0 =
        (out->index_0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f);
    out->index_1 = Bound(in_loc, limit);
    out->weight_1 = (out->index_1 == in_loc ? coeffs_table[offset * 2] : 0.0f);
    out->index_2 = Bound(in_loc + 1, limit);
    out->weight_2 =
        (out->index_2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2]
                                    : 0.0f);
    out->index_3 = Bound(in_loc + 2, limit);
    out->weight_3 = (out->index_3 == in_loc + 2
                         ? coeffs_table[(kTableSize - offset) * 2 + 1]
                         : 0.0f);

    const float weight_sum =
        out->weight_0 + out->weight_1 + out->weight_2 + out->weight_3;
    if (std::abs(weight_sum) >= 1000.0f * std::numeric_limits<float>::min()) {
      const float one_over_weight_sum = 1.0f / weight_sum;
      out->weight_0 *= one_over_weight_sum;
      out->weight_1 *= one_over_weight_sum;
      out->weight_2 *= one_over_weight_sum;
      out->weight_3 *= one_over_weight_sum;
    }
  } else {
    out->weight_0 = coeffs_table[offset * 2 + 1];
    out->weight_1 = coeffs_table[offset * 2];
    out->weight_2 = coeffs_table[(kTableSize - offset) * 2];
    out->weight_3 = coeffs_table[(kTableSize - offset) * 2 + 1];
    out->index_0 = Bound(in_loc - 1, limit);
    out->index_1 = Bound(in_loc, limit);
    out->index_2 = Bound(in_loc + 1, limit);
    out->index_3 = Bound(in_loc + 2, limit);
  }
}
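
// Worked example (values assumed for illustration): suppose use_keys_cubic
// and an output location that maps to in_loc_f = 0.5, so in_loc = 0 and
// delta = 0.5, at the left image border. Then index_0 = Bound(-1, limit)
// clamps to 0, which differs from in_loc - 1, so weight_0 is zeroed; the
// three surviving weights are then multiplied by
// 1 / (weight_1 + weight_2 + weight_3) so the kernel still sums to 1.0.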

template <typename T>
inline float Interpolate1D(const float weight_0, const float weight_1,
                           const float weight_2, const float weight_3,
                           const T value_0, const T value_1, const T value_2,
                           const T value_3) {
  return static_cast<float>(value_0) * weight_0 +
         static_cast<float>(value_1) * weight_1 +
         static_cast<float>(value_2) * weight_2 +
         static_cast<float>(value_3) * weight_3;
}

// Compute the 1D interpolation along x for a given output location, applying
// the x weights to four values that have already been interpolated along y.
static float Compute(float values_[4], const float xw_0, const float xw_1,
                     const float xw_2, const float xw_3) {
  return Interpolate1D(xw_0, xw_1, xw_2, xw_3, values_[0], values_[1],
                       values_[2], values_[3]);
}

// In order to compute a single output value, we look at a 4x4 patch in the
// source image. As we iterate increasing X across the image, the new 4x4 patch
// often overlaps with the previous 4x4 patch we just looked at.
//
// This class helps compute the number of values to copy from the previous
// point's values.
class CachedInterpolationCalculator {
 public:
  CachedInterpolationCalculator() : indexes_{-1, -1, -1, -1} {}

  // Advances iteration. Returns the number of values that should be copied
  // from the current point to the next point. The copying should always be
  // done by copying the last <retval> values from the old point to the first
  // <retval> values of the new point.
  inline int Advance(const int64 x_0, const int64 x_1, const int64 x_2,
                     const int64 x_3) {
    // We use 2 hands and walk through, copying from one to another where
    // we already have values.
    // Invariant: new_indices_hand <= cached_values_hand
    const std::array<int64, 4> new_x_indices{{x_0, x_1, x_2, x_3}};
    int cached_values_hand = 0;
    int new_indices_hand = 0;
    while (cached_values_hand < 4) {
      if (indexes_[cached_values_hand] == new_x_indices[new_indices_hand]) {
        if (new_indices_hand < cached_values_hand) {
          indexes_[new_indices_hand] = indexes_[cached_values_hand];
        }
        cached_values_hand++;
        new_indices_hand++;
      } else {
        cached_values_hand++;
      }
    }
    switch (new_indices_hand) {
      case 0:
        indexes_[0] = x_0;
        TF_FALLTHROUGH_INTENDED;
      case 1:
        indexes_[1] = x_1;
        TF_FALLTHROUGH_INTENDED;
      case 2:
        indexes_[2] = x_2;
        TF_FALLTHROUGH_INTENDED;
      case 3:
        indexes_[3] = x_3;
        break;
    }
    return new_indices_hand;
  }

 private:
  int64 indexes_[4];
};
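
// Usage sketch (illustrative): starting from a fresh calculator,
//   CachedInterpolationCalculator calc;
//   calc.Advance(10, 11, 12, 13);  // returns 0: nothing is cached yet.
//   calc.Advance(11, 12, 13, 14);  // returns 3: the values for 11, 12 and 13
//                                  // carry over from the previous point.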

static void ComputeXWeightsAndIndices(const ImageResizerState& resizer_state,
                                      const bool half_pixel_centers,
                                      std::vector<WeightsAndIndices>* x_wais) {
  CachedInterpolationCalculator calc;
  if (half_pixel_centers) {
    for (int64 x = 0; x < resizer_state.out_width; ++x) {
      GetWeightsAndIndices<HalfPixelScaler, true>(
          resizer_state.width_scale, x, resizer_state.in_width, &(*x_wais)[x]);
      auto& x_wai = (*x_wais)[x];
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
                                   x_wai.index_3);
    }
  } else {
    for (int64 x = 0; x < resizer_state.out_width; ++x) {
      GetWeightsAndIndices<LegacyScaler, false>(
          resizer_state.width_scale, x, resizer_state.in_width, &(*x_wais)[x]);
      auto& x_wai = (*x_wais)[x];
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
                                   x_wai.index_3);
    }
  }
  // Scale the values so they can be used as offsets into buffers.
  for (int x = 0; x < resizer_state.out_width; ++x) {
    (*x_wais)[x].index_0 *= resizer_state.channels;
    (*x_wais)[x].index_1 *= resizer_state.channels;
    (*x_wais)[x].index_2 *= resizer_state.channels;
    (*x_wais)[x].index_3 *= resizer_state.channels;
  }
}
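
// For example (illustrative): with channels == 3, a clamped x index of 5
// becomes the buffer offset 15, so index_* can be added directly to the
// per-row pointers set up in interpolate_with_caching below.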

static void ComputeGradientXWeightsAndIndices(
    const ImageResizerGradientState& resizer_state,
    const bool half_pixel_centers, std::vector<WeightsAndIndices>* x_wais) {
  CachedInterpolationCalculator calc;
  if (half_pixel_centers) {
    for (int64 x = 0; x < resizer_state.resized_width; ++x) {
      GetWeightsAndIndices<HalfPixelScaler, true>(resizer_state.width_scale, x,
                                                  resizer_state.original_width,
                                                  &(*x_wais)[x]);
      auto& x_wai = (*x_wais)[x];
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
                                   x_wai.index_3);
    }
  } else {
    for (int64 x = 0; x < resizer_state.resized_width; ++x) {
      GetWeightsAndIndices<LegacyScaler, false>(resizer_state.width_scale, x,
                                                resizer_state.original_width,
                                                &(*x_wais)[x]);
      auto& x_wai = (*x_wais)[x];
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
                                   x_wai.index_3);
    }
  }
  // Do not scale, as we will be using these directly as tensor indices on the
  // gradient pass.
}

template <typename T>
static EIGEN_ALWAYS_INLINE float ComputeYInterpolation(
    int which, int channel_num, const WeightsAndIndices& y_wai,
    const T* y_ptr_0, const T* y_ptr_1, const T* y_ptr_2, const T* y_ptr_3,
    const WeightsAndIndices& x_wai) {
  int x_index;
  switch (which) {
    case 0:
      x_index = x_wai.index_0;
      break;
    case 1:
      x_index = x_wai.index_1;
      break;
    case 2:
      x_index = x_wai.index_2;
      break;
    default:
      x_index = x_wai.index_3;
      break;
  }
  const int64 pt_index = x_index + channel_num;
  return Interpolate1D<T>(y_wai.weight_0, y_wai.weight_1, y_wai.weight_2,
                          y_wai.weight_3, y_ptr_0[pt_index], y_ptr_1[pt_index],
                          y_ptr_2[pt_index], y_ptr_3[pt_index]);
}

template <typename T>
inline void interpolate_with_caching(
    const typename TTypes<T, 4>::ConstTensor& input_data,
    const ImageResizerState& resizer_state, const bool half_pixel_centers,
    typename TTypes<float, 4>::Tensor output_data) {
  std::vector<WeightsAndIndices> x_wais(resizer_state.out_width);
  ComputeXWeightsAndIndices(resizer_state, half_pixel_centers, &x_wais);

  const auto num_channels = resizer_state.channels;
  const int64 in_row_width = resizer_state.in_width * num_channels;
  const int64 in_batch_width = resizer_state.in_height * in_row_width;

  const T* input_b_ptr = input_data.data();
  float* output_y_ptr = output_data.data();
  std::vector<float> cached_value(num_channels == 3 ? 0 : 4 * num_channels, 0);

  for (int64 b = 0; b < resizer_state.batch_size;
       ++b, input_b_ptr += in_batch_width) {
    for (int64 y = 0; y < resizer_state.out_height;
         ++y, output_y_ptr += resizer_state.out_width * num_channels) {
      WeightsAndIndices y_wai;
      if (half_pixel_centers) {
        GetWeightsAndIndices<HalfPixelScaler, true>(
            resizer_state.height_scale, y, resizer_state.in_height, &y_wai);
      } else {
        GetWeightsAndIndices<LegacyScaler, false>(
            resizer_state.height_scale, y, resizer_state.in_height, &y_wai);
      }
      // Set up pointers to the four input rows used for interpolation in y.
      const T* y_ptr_0 = input_b_ptr + y_wai.index_0 * in_row_width;
      const T* y_ptr_1 = input_b_ptr + y_wai.index_1 * in_row_width;
      const T* y_ptr_2 = input_b_ptr + y_wai.index_2 * in_row_width;
      const T* y_ptr_3 = input_b_ptr + y_wai.index_3 * in_row_width;

      if (num_channels == 3) {
        // Manually unroll case of 3 channels.
        float cached_value_0[4] = {0};
        float cached_value_1[4] = {0};
        float cached_value_2[4] = {0};
        for (int64 x = 0; x < resizer_state.out_width; ++x) {
          const WeightsAndIndices& x_wai = x_wais[x];
          // Shift values in cached_value_* to fill first 'advance' values.
          switch (x_wai.advance) {
            case 3:
              cached_value_0[0] = cached_value_0[1];
              cached_value_0[1] = cached_value_0[2];
              cached_value_0[2] = cached_value_0[3];
              cached_value_1[0] = cached_value_1[1];
              cached_value_1[1] = cached_value_1[2];
              cached_value_1[2] = cached_value_1[3];
              cached_value_2[0] = cached_value_2[1];
              cached_value_2[1] = cached_value_2[2];
              cached_value_2[2] = cached_value_2[3];
              break;
            case 2:
              cached_value_0[0] = cached_value_0[2];
              cached_value_0[1] = cached_value_0[3];
              cached_value_1[0] = cached_value_1[2];
              cached_value_1[1] = cached_value_1[3];
              cached_value_2[0] = cached_value_2[2];
              cached_value_2[1] = cached_value_2[3];
              break;
            case 1: {
              cached_value_0[0] = cached_value_0[3];
              cached_value_1[0] = cached_value_1[3];
              cached_value_2[0] = cached_value_2[3];
              break;
            }
          }

          // Compute the remaining '4 - advance' values.
          switch (x_wai.advance) {
            case 0:
              cached_value_0[0] = ComputeYInterpolation(
                  0, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[0] = ComputeYInterpolation(
                  0, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[0] = ComputeYInterpolation(
                  0, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              TF_FALLTHROUGH_INTENDED;
            case 1:
              cached_value_0[1] = ComputeYInterpolation(
                  1, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[1] = ComputeYInterpolation(
                  1, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[1] = ComputeYInterpolation(
                  1, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              TF_FALLTHROUGH_INTENDED;
            case 2:
              cached_value_0[2] = ComputeYInterpolation(
                  2, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[2] = ComputeYInterpolation(
                  2, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[2] = ComputeYInterpolation(
                  2, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              TF_FALLTHROUGH_INTENDED;
            case 3:
              cached_value_0[3] = ComputeYInterpolation(
                  3, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[3] = ComputeYInterpolation(
                  3, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[3] = ComputeYInterpolation(
                  3, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              break;
          }
          output_y_ptr[x * num_channels + 0] =
              Compute(cached_value_0, x_wai.weight_0, x_wai.weight_1,
                      x_wai.weight_2, x_wai.weight_3);
          output_y_ptr[x * num_channels + 1] =
              Compute(cached_value_1, x_wai.weight_0, x_wai.weight_1,
                      x_wai.weight_2, x_wai.weight_3);
          output_y_ptr[x * num_channels + 2] =
              Compute(cached_value_2, x_wai.weight_0, x_wai.weight_1,
                      x_wai.weight_2, x_wai.weight_3);
        }
      } else {
        for (int64 x = 0; x < resizer_state.out_width; ++x) {
          const WeightsAndIndices& x_wai = x_wais[x];
          // Shift values in cached_value to fill first 'advance' values.
          switch (x_wai.advance) {
            case 3:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = cached_value[4 * c + 1];
                cached_value[4 * c + 1] = cached_value[4 * c + 2];
                cached_value[4 * c + 2] = cached_value[4 * c + 3];
              }
              break;
            case 2:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = cached_value[4 * c + 2];
                cached_value[4 * c + 1] = cached_value[4 * c + 3];
              }
              break;
            case 1: {
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = cached_value[4 * c + 3];
              }
              break;
            }
          }

          // Compute the remaining '4 - advance' values.
          switch (x_wai.advance) {
            case 0:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = ComputeYInterpolation(
                    0, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              TF_FALLTHROUGH_INTENDED;
            case 1:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 1] = ComputeYInterpolation(
                    1, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              TF_FALLTHROUGH_INTENDED;
            case 2:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 2] = ComputeYInterpolation(
                    2, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              TF_FALLTHROUGH_INTENDED;
            case 3:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 3] = ComputeYInterpolation(
                    3, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              break;
          }
          for (int64 c = 0; c < num_channels; ++c) {
            output_y_ptr[x * num_channels + c] =
                Compute(&cached_value[4 * c], x_wai.weight_0, x_wai.weight_1,
                        x_wai.weight_2, x_wai.weight_3);
          }
        }
      }
    }
  }
}

template <typename T>
inline void ResizeBicubicGrad(typename TTypes<float, 4>::ConstTensor input_grad,
                              const ImageResizerGradientState& resizer_state,
                              const bool half_pixel_centers,
                              typename TTypes<T, 4>::Tensor output_grad) {
  // This function computes gradients for the ResizeBicubic op by iterating
  // over the input_grad Tensor and using WeightsAndIndices to appropriately
  // update the output gradient.
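  // In effect this is the transpose of the forward interpolation: each input
  // gradient value g = input_grad(b, y, x, c) is scattered back over its 4x4
  // source neighborhood as
  //   output_grad(b, y_index_i, x_index_j, c) += g * y_weight_i * x_weight_j
  // for i, j in {0, 1, 2, 3}, reusing the weights from the forward pass.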
  const float height_scale = resizer_state.height_scale;
  const int64 original_height = resizer_state.original_height;
  const int channels = resizer_state.channels;
  const int64 resized_width = resizer_state.resized_width;
  const int64 resized_height = resizer_state.resized_height;

  output_grad.setZero();

  std::vector<WeightsAndIndices> x_wais(resizer_state.resized_width);
  ComputeGradientXWeightsAndIndices(resizer_state, half_pixel_centers, &x_wais);
  for (int64 b = 0; b < resizer_state.batch_size; ++b) {
    for (int64 y = 0; y < resized_height; ++y) {
      WeightsAndIndices y_wai;
      if (half_pixel_centers) {
        GetWeightsAndIndices<HalfPixelScaler, true>(height_scale, y,
                                                    original_height, &y_wai);
      } else {
        GetWeightsAndIndices<LegacyScaler, false>(height_scale, y,
                                                  original_height, &y_wai);
      }
      for (int64 x = 0; x < resized_width; ++x) {
        const WeightsAndIndices& x_wai = x_wais[x];
        for (int64 c = 0; c < channels; ++c) {
          T curr_input_grad = input_grad(b, y, x, c);
          // row 0 of 0, 1, 2, 3
          output_grad(b, y_wai.index_0, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
          output_grad(b, y_wai.index_0, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
          output_grad(b, y_wai.index_0, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
          output_grad(b, y_wai.index_0, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);
          // row 1 of 0, 1, 2, 3
          output_grad(b, y_wai.index_1, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
          output_grad(b, y_wai.index_1, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
          output_grad(b, y_wai.index_1, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
          output_grad(b, y_wai.index_1, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);
          // row 2 of 0, 1, 2, 3
          output_grad(b, y_wai.index_2, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
          output_grad(b, y_wai.index_2, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
          output_grad(b, y_wai.index_2, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
          output_grad(b, y_wai.index_2, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);
          // row 3 of 0, 1, 2, 3
          output_grad(b, y_wai.index_3, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
          output_grad(b, y_wai.index_3, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
          output_grad(b, y_wai.index_3, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
          output_grad(b, y_wai.index_3, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
        }
      }
    }
  }
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class ResizeBicubicOp : public OpKernel {
 public:
  explicit ResizeBicubicOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
    OP_REQUIRES_OK(
        context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    ImageResizerState st(align_corners_, half_pixel_centers_);
    st.ValidateAndCreateOutput(context, input);

    if (!context->status().ok()) return;

    typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
    TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();

    interpolate_with_caching<T>(input_data, st, half_pixel_centers_,
                                output_data);
  }

 private:
  bool align_corners_;
  bool half_pixel_centers_;
};

template <typename Device, typename T>
class ResizeBicubicOpGrad : public OpKernel {
 public:
  explicit ResizeBicubicOpGrad(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
    OP_REQUIRES_OK(
        context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
  }

  void Compute(OpKernelContext* context) override {
    // Validate input.
    // First argument is gradient with respect to resized image.
    const Tensor& input = context->input(0);
    const Tensor& original_image = context->input(1);

    ImageResizerGradientState st(align_corners_, half_pixel_centers_);
    st.ValidateAndCreateOutput(context, input, original_image);

    if (!context->status().ok()) return;

    TTypes<float, 4>::ConstTensor input_grad = input.tensor<float, 4>();
    typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());

    ResizeBicubicGrad<T>(input_grad, st, half_pixel_centers_, output_grad);
  }

 private:
  bool align_corners_;
  bool half_pixel_centers_;
};

#define REGISTER_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("ResizeBicubic")       \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("size"),    \
                          ResizeBicubicOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#define REGISTER_GRAD_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("ResizeBicubicGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ResizeBicubicOpGrad<CPUDevice, T>);

TF_CALL_float(REGISTER_GRAD_KERNEL);
TF_CALL_double(REGISTER_GRAD_KERNEL);

#undef REGISTER_GRAD_KERNEL

}  // namespace tensorflow