/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc
#define EIGEN_USE_THREADS

#include <math.h>
#include <algorithm>
#include <array>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace {

static const int64 kTableSize = (1 << 10);

const float* InitCoeffsTable() {
  // Allocate and initialize the coefficients table using the bicubic
  // convolution algorithm:
  // https://en.wikipedia.org/wiki/Bicubic_interpolation
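  //
  // With A = -0.75, the tabulated kernel is the standard cubic convolution
  // kernel:
  //   W(x) = (A + 2)|x|^3 - (A + 3)|x|^2 + 1        for |x| <= 1
  //   W(x) = A|x|^3 - 5A|x|^2 + 8A|x| - 4A          for 1 < |x| < 2
  // Entry [i * 2] stores W(i / kTableSize) and entry [i * 2 + 1] stores
  // W(i / kTableSize + 1), i.e. the weights for the near and far taps on one
  // side of the sample point, at each of kTableSize + 1 sub-pixel offsets.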
  float* coeffs_table = new float[(kTableSize + 1) * 2];
  static const double A = -0.75;
  for (int i = 0; i <= kTableSize; ++i) {
    float x = i * 1.0 / kTableSize;
    coeffs_table[i * 2] = ((A + 2) * x - (A + 3)) * x * x + 1;
    x += 1.0;
    coeffs_table[i * 2 + 1] = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
  }
  return coeffs_table;
}

const float* GetCoeffsTable() {
  // Static so that we initialize it on first use
  static const float* coeffs_table = InitCoeffsTable();
  return coeffs_table;
}

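// Clamps val to the valid index range [0, limit - 1], so taps that fall
// outside the image replicate the nearest edge pixel.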
inline int64 Bound(int64 val, int64 limit) {
  return std::min(limit - 1ll, std::max(0ll, val));
}

struct WeightsAndIndices {
  float weight_0;
  float weight_1;
  float weight_2;
  float weight_3;
  int64 index_0;
  int64 index_1;
  int64 index_2;
  int64 index_3;

  int advance;  // Number of cached values to copy from the previous point.
};

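// Computes, for one output coordinate out_loc, the four (clamped) source
// indices in_loc - 1 .. in_loc + 2 and their bicubic weights, read from the
// precomputed table. The weights for the two taps to the right of in_loc are
// looked up at offset (kTableSize - offset) because the kernel is symmetric.
// E.g., with scale = 1.5, out_loc = 3, and limit = 8: in_loc = 4,
// delta = 0.5, and the taps land on indices 3, 4, 5, and 6.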
inline void GetWeightsAndIndices(const float scale, const int64 out_loc,
                                 const int64 limit, WeightsAndIndices* out) {
  const int64 in_loc = scale * out_loc;
  const float delta = scale * out_loc - in_loc;
  const int64 offset = lrintf(delta * kTableSize);
  const float* coeffs_table = GetCoeffsTable();
  out->weight_0 = coeffs_table[offset * 2 + 1];
  out->weight_1 = coeffs_table[offset * 2];
  out->weight_2 = coeffs_table[(kTableSize - offset) * 2];
  out->weight_3 = coeffs_table[(kTableSize - offset) * 2 + 1];
  out->index_0 = Bound(in_loc - 1, limit);
  out->index_1 = Bound(in_loc, limit);
  out->index_2 = Bound(in_loc + 1, limit);
  out->index_3 = Bound(in_loc + 2, limit);
}

template <typename T>
inline float Interpolate1D(const float weight_0, const float weight_1,
                           const float weight_2, const float weight_3,
                           const T value_0, const T value_1, const T value_2,
                           const T value_3) {
  return static_cast<float>(value_0) * weight_0 +
         static_cast<float>(value_1) * weight_1 +
         static_cast<float>(value_2) * weight_2 +
         static_cast<float>(value_3) * weight_3;
}

// Compute the final 1D interpolation for a given x position from four
// y-interpolated values, using the x weights.
static float Compute(float values[4], const float xw_0, const float xw_1,
                     const float xw_2, const float xw_3) {
  return Interpolate1D(xw_0, xw_1, xw_2, xw_3, values[0], values[1],
                       values[2], values[3]);
}

// In order to compute a single output value, we look at a 4x4 patch in the
// source image. As we iterate increasing X across the image, the new 4x4 patch
// often overlaps with the previous 4x4 patch we just looked at.
//
// This class helps compute the number of values to copy from the previous
// point's values.
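//
// For example, if the cached indices are {0, 1, 2, 3} and the next point
// needs {2, 3, 4, 5}, Advance() returns 2: the last two cached values are
// still valid and should be copied into the first two slots of the new point
// before the remaining two are recomputed.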
class CachedInterpolationCalculator {
 public:
  CachedInterpolationCalculator() : indexes_{-1, -1, -1, -1} {}

  // Advances iteration. Returns the number of values that should be copied
  // from the current point to the next point. The copying should always be
  // done by copying the last <retval> values from the old point to the first
  // <retval> values of the new point.
  inline int Advance(const int64 x_0, const int64 x_1, const int64 x_2,
                     const int64 x_3) {
    // We use 2 hands and walk through, copying from one to another where
    // we already have values.
    // Invariant: new_indices_hand <= cached_values_hand
    const std::array<int64, 4> new_x_indices{{x_0, x_1, x_2, x_3}};
    int cached_values_hand = 0;
    int new_indices_hand = 0;
    while (cached_values_hand < 4) {
      if (indexes_[cached_values_hand] == new_x_indices[new_indices_hand]) {
        if (new_indices_hand < cached_values_hand) {
          indexes_[new_indices_hand] = indexes_[cached_values_hand];
        }
        cached_values_hand++;
        new_indices_hand++;
      } else {
        cached_values_hand++;
      }
    }
    switch (new_indices_hand) {
      case 0:
        indexes_[0] = x_0;
        TF_FALLTHROUGH_INTENDED;
      case 1:
        indexes_[1] = x_1;
        TF_FALLTHROUGH_INTENDED;
      case 2:
        indexes_[2] = x_2;
        TF_FALLTHROUGH_INTENDED;
      case 3:
        indexes_[3] = x_3;
        break;
    }
    return new_indices_hand;
  }

 private:
  int64 indexes_[4];
};

static void ComputeXWeightsAndIndices(const ImageResizerState& resizer_state,
                                      std::vector<WeightsAndIndices>* x_wais) {
  CachedInterpolationCalculator calc;
  for (int64 x = 0; x < resizer_state.out_width; ++x) {
    GetWeightsAndIndices(resizer_state.width_scale, x, resizer_state.in_width,
                         &(*x_wais)[x]);
    auto& x_wai = (*x_wais)[x];
    x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
                                 x_wai.index_3);
  }
  // Scale the values so they can be used as offsets into buffers.
  for (int x = 0; x < resizer_state.out_width; ++x) {
    (*x_wais)[x].index_0 *= resizer_state.channels;
    (*x_wais)[x].index_1 *= resizer_state.channels;
    (*x_wais)[x].index_2 *= resizer_state.channels;
    (*x_wais)[x].index_3 *= resizer_state.channels;
  }
}

static void ComputeGradientXWeightsAndIndices(
    const ImageResizerGradientState& resizer_state,
    std::vector<WeightsAndIndices>* x_wais) {
  CachedInterpolationCalculator calc;
  for (int64 x = 0; x < resizer_state.resized_width; ++x) {
    GetWeightsAndIndices(resizer_state.width_scale, x,
                         resizer_state.original_width, &(*x_wais)[x]);
    auto& x_wai = (*x_wais)[x];
    x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
                                 x_wai.index_3);
  }
  // Do not scale, as we will be using these directly as tensor indices in the
  // gradient pass.
}

template <typename T>
static EIGEN_ALWAYS_INLINE float ComputeYInterpolation(
    int which, int channel_num, const WeightsAndIndices& y_wai,
    const T* y_ptr_0, const T* y_ptr_1, const T* y_ptr_2, const T* y_ptr_3,
    const WeightsAndIndices& x_wai) {
  int64 x_index;
  switch (which) {
    case 0:
      x_index = x_wai.index_0;
      break;
    case 1:
      x_index = x_wai.index_1;
      break;
    case 2:
      x_index = x_wai.index_2;
      break;
    default:
      x_index = x_wai.index_3;
      break;
  }
  const int64 pt_index = x_index + channel_num;
  return Interpolate1D<T>(y_wai.weight_0, y_wai.weight_1, y_wai.weight_2,
                          y_wai.weight_3, y_ptr_0[pt_index], y_ptr_1[pt_index],
                          y_ptr_2[pt_index], y_ptr_3[pt_index]);
}

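// Resizes one batch of images by two separable 1-D passes: for each output
// row, ComputeYInterpolation collapses the four contributing input rows into
// y-interpolated values at the four x taps of each output pixel, and Compute
// then applies the x weights. The per-x 'advance' counts let overlapping 4x4
// patches reuse already-computed column values instead of recomputing them.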
template <typename T>
inline void interpolate_with_caching(
    const typename TTypes<T, 4>::ConstTensor& input_data,
    const ImageResizerState& resizer_state,
    typename TTypes<float, 4>::Tensor output_data) {
  std::vector<WeightsAndIndices> x_wais(resizer_state.out_width);
  ComputeXWeightsAndIndices(resizer_state, &x_wais);

  const auto num_channels = resizer_state.channels;
  const int64 in_row_width = resizer_state.in_width * num_channels;
  const int64 in_batch_width = resizer_state.in_height * in_row_width;

  const T* input_b_ptr = input_data.data();
  float* output_y_ptr = output_data.data();
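  // Scratch buffer of 4 cached values per channel for the general path; the
  // 3-channel case below is unrolled onto fixed-size stack arrays instead,
  // so the vector is deliberately left empty there.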
  std::vector<float> cached_value(num_channels == 3 ? 0 : 4 * num_channels, 0);

  for (int64 b = 0; b < resizer_state.batch_size;
       ++b, input_b_ptr += in_batch_width) {
    for (int64 y = 0; y < resizer_state.out_height;
         ++y, output_y_ptr += resizer_state.out_width * num_channels) {
      WeightsAndIndices y_wai;
      GetWeightsAndIndices(resizer_state.height_scale, y,
                           resizer_state.in_height, &y_wai);
      // Pointers to the four input rows that contribute to this output row.
      const T* y_ptr_0 = input_b_ptr + y_wai.index_0 * in_row_width;
      const T* y_ptr_1 = input_b_ptr + y_wai.index_1 * in_row_width;
      const T* y_ptr_2 = input_b_ptr + y_wai.index_2 * in_row_width;
      const T* y_ptr_3 = input_b_ptr + y_wai.index_3 * in_row_width;

      if (num_channels == 3) {
        // Manually unroll the case of 3 channels.
        float cached_value_0[4] = {0};
        float cached_value_1[4] = {0};
        float cached_value_2[4] = {0};
        for (int64 x = 0; x < resizer_state.out_width; ++x) {
          const WeightsAndIndices& x_wai = x_wais[x];
          // Shift values in cached_value_* to fill first 'advance' values.
          switch (x_wai.advance) {
            case 3:
              cached_value_0[0] = cached_value_0[1];
              cached_value_0[1] = cached_value_0[2];
              cached_value_0[2] = cached_value_0[3];
              cached_value_1[0] = cached_value_1[1];
              cached_value_1[1] = cached_value_1[2];
              cached_value_1[2] = cached_value_1[3];
              cached_value_2[0] = cached_value_2[1];
              cached_value_2[1] = cached_value_2[2];
              cached_value_2[2] = cached_value_2[3];
              break;
            case 2:
              cached_value_0[0] = cached_value_0[2];
              cached_value_0[1] = cached_value_0[3];
              cached_value_1[0] = cached_value_1[2];
              cached_value_1[1] = cached_value_1[3];
              cached_value_2[0] = cached_value_2[2];
              cached_value_2[1] = cached_value_2[3];
              break;
            case 1: {
              cached_value_0[0] = cached_value_0[3];
              cached_value_1[0] = cached_value_1[3];
              cached_value_2[0] = cached_value_2[3];
              break;
            }
          }

          // Set the remaining '4 - advance' values by computing them.
          switch (x_wai.advance) {
            case 0:
              cached_value_0[0] = ComputeYInterpolation(
                  0, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[0] = ComputeYInterpolation(
                  0, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[0] = ComputeYInterpolation(
                  0, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              TF_FALLTHROUGH_INTENDED;
            case 1:
              cached_value_0[1] = ComputeYInterpolation(
                  1, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[1] = ComputeYInterpolation(
                  1, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[1] = ComputeYInterpolation(
                  1, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              TF_FALLTHROUGH_INTENDED;
            case 2:
              cached_value_0[2] = ComputeYInterpolation(
                  2, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[2] = ComputeYInterpolation(
                  2, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[2] = ComputeYInterpolation(
                  2, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              TF_FALLTHROUGH_INTENDED;
            case 3:
              cached_value_0[3] = ComputeYInterpolation(
                  3, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[3] = ComputeYInterpolation(
                  3, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[3] = ComputeYInterpolation(
                  3, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              break;
          }
          output_y_ptr[x * num_channels + 0] =
              Compute(cached_value_0, x_wai.weight_0, x_wai.weight_1,
                      x_wai.weight_2, x_wai.weight_3);
          output_y_ptr[x * num_channels + 1] =
              Compute(cached_value_1, x_wai.weight_0, x_wai.weight_1,
                      x_wai.weight_2, x_wai.weight_3);
          output_y_ptr[x * num_channels + 2] =
              Compute(cached_value_2, x_wai.weight_0, x_wai.weight_1,
                      x_wai.weight_2, x_wai.weight_3);
        }
      } else {
        for (int64 x = 0; x < resizer_state.out_width; ++x) {
          const WeightsAndIndices& x_wai = x_wais[x];
          // Shift values in cached_value to fill first 'advance' values.
          switch (x_wai.advance) {
            case 3:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = cached_value[4 * c + 1];
                cached_value[4 * c + 1] = cached_value[4 * c + 2];
                cached_value[4 * c + 2] = cached_value[4 * c + 3];
              }
              break;
            case 2:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = cached_value[4 * c + 2];
                cached_value[4 * c + 1] = cached_value[4 * c + 3];
              }
              break;
            case 1: {
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = cached_value[4 * c + 3];
              }
              break;
            }
          }

          // Set the remaining '4 - advance' values by computing them.
          switch (x_wai.advance) {
            case 0:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = ComputeYInterpolation(
                    0, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              TF_FALLTHROUGH_INTENDED;
            case 1:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 1] = ComputeYInterpolation(
                    1, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              TF_FALLTHROUGH_INTENDED;
            case 2:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 2] = ComputeYInterpolation(
                    2, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              TF_FALLTHROUGH_INTENDED;
            case 3:
              for (int64 c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 3] = ComputeYInterpolation(
                    3, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              break;
          }
          for (int64 c = 0; c < num_channels; ++c) {
            output_y_ptr[x * num_channels + c] =
                Compute(&cached_value[4 * c], x_wai.weight_0, x_wai.weight_1,
                        x_wai.weight_2, x_wai.weight_3);
          }
        }
      }
    }
  }
}

template <typename T>
inline void ResizeBicubicGrad(typename TTypes<float, 4>::ConstTensor input_grad,
                              const ImageResizerGradientState& resizer_state,
                              typename TTypes<T, 4>::Tensor output_grad) {
  // This function computes gradients for the ResizeBicubic op by iterating
  // over the input_grad Tensor and using WeightsAndIndices to appropriately
  // update the output gradient.
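  //
  // Each incoming gradient value g = input_grad(b, y, x, c) is scattered
  // back onto the 4x4 source patch the forward pass read from:
  //   output_grad(b, y_index_i, x_index_j, c) += g * y_weight_i * x_weight_j
  // for i, j in {0, 1, 2, 3}.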
  const float height_scale = resizer_state.height_scale;
  const int64 original_height = resizer_state.original_height;
  const int channels = resizer_state.channels;
  const int64 resized_width = resizer_state.resized_width;
  const int64 resized_height = resizer_state.resized_height;

  output_grad.setZero();

  std::vector<WeightsAndIndices> x_wais(resizer_state.resized_width);
  ComputeGradientXWeightsAndIndices(resizer_state, &x_wais);
  for (int64 b = 0; b < resizer_state.batch_size; ++b) {
    for (int64 y = 0; y < resized_height; ++y) {
      WeightsAndIndices y_wai;
      GetWeightsAndIndices(height_scale, y, original_height, &y_wai);
      for (int64 x = 0; x < resized_width; ++x) {
        const WeightsAndIndices& x_wai = x_wais[x];
        for (int64 c = 0; c < channels; ++c) {
          T curr_input_grad = input_grad(b, y, x, c);
          // row 0 of 0, 1, 2, 3
          output_grad(b, y_wai.index_0, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
          output_grad(b, y_wai.index_0, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
          output_grad(b, y_wai.index_0, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
          output_grad(b, y_wai.index_0, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);
          // row 1 of 0, 1, 2, 3
          output_grad(b, y_wai.index_1, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
          output_grad(b, y_wai.index_1, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
          output_grad(b, y_wai.index_1, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
          output_grad(b, y_wai.index_1, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);
          // row 2 of 0, 1, 2, 3
          output_grad(b, y_wai.index_2, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
          output_grad(b, y_wai.index_2, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
          output_grad(b, y_wai.index_2, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
          output_grad(b, y_wai.index_2, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);
          // row 3 of 0, 1, 2, 3
          output_grad(b, y_wai.index_3, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
          output_grad(b, y_wai.index_3, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
          output_grad(b, y_wai.index_3, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
          output_grad(b, y_wai.index_3, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
        }
      }
    }
  }
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class ResizeBicubicOp : public OpKernel {
 public:
  explicit ResizeBicubicOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    ImageResizerState st(align_corners_);
    st.ValidateAndCreateOutput(context, input);

    if (!context->status().ok()) return;

    typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
    TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();

    interpolate_with_caching<T>(input_data, st, output_data);
  }

 private:
  bool align_corners_;
};

template <typename Device, typename T>
class ResizeBicubicOpGrad : public OpKernel {
 public:
  explicit ResizeBicubicOpGrad(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
  }

  void Compute(OpKernelContext* context) override {
    // Validate input.
    // First argument is gradient with respect to resized image.
    const Tensor& input = context->input(0);
    const Tensor& original_image = context->input(1);

    ImageResizerGradientState st(align_corners_);
    st.ValidateAndCreateOutput(context, input, original_image);

    if (!context->status().ok()) return;

    TTypes<float, 4>::ConstTensor input_grad = input.tensor<float, 4>();
    typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());

    ResizeBicubicGrad<T>(input_grad, st, output_grad);
  }

 private:
  bool align_corners_;
};

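// Kernel registration. TF_CALL_REAL_NUMBER_TYPES instantiates
// REGISTER_KERNEL once per real numeric type T; the 'size' input is pinned
// to host memory because its values are read to compute the output shape.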
#define REGISTER_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("ResizeBicubic")       \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("size"),    \
                          ResizeBicubicOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#define REGISTER_GRAD_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("ResizeBicubicGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ResizeBicubicOpGrad<CPUDevice, T>);

TF_CALL_float(REGISTER_GRAD_KERNEL);
TF_CALL_double(REGISTER_GRAD_KERNEL);

#undef REGISTER_GRAD_KERNEL

}  // namespace tensorflow