/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/resize_bilinear_op.h"

#include <algorithm>
#include <cmath>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

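// Resizes a batch of NHWC images to the requested output size using bilinear
// interpolation. The output is always float, regardless of the input type T.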
template <typename Device, typename T>
class ResizeBilinearOp : public OpKernel {
 public:
  explicit ResizeBilinearOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    ImageResizerState st(align_corners_);
    st.ValidateAndCreateOutput(context, input);

    if (!context->status().ok()) return;

    // Return if the output is empty.
    if (st.output->NumElements() == 0) return;

    typename TTypes<T, 4>::ConstTensor image_data(input.tensor<T, 4>());
    TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();

    functor::ResizeBilinear<Device, T>()(context->eigen_device<Device>(),
                                         image_data, st.height_scale,
                                         st.width_scale, output_data);
  }

 private:
  bool align_corners_;
};

namespace {
// Compute the interpolation indices only once.
struct CachedInterpolation {
  int64 lower;  // Lower source index used in the interpolation
  int64 upper;  // Upper source index used in the interpolation
  // 1-D linear interpolation scale (see:
  // https://en.wikipedia.org/wiki/Bilinear_interpolation)
  float lerp;
};

inline void compute_interpolation_weights(const int64 out_size,
                                          const int64 in_size,
                                          const float scale,
                                          CachedInterpolation* interpolation) {
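  // The caller allocates out_size + 1 entries; the extra entry is zeroed so
  // that a read one past the last output index stays well defined.
  // Example: out_size = 4, in_size = 2, scale = 0.5 yields
  //   i=0: lower=0, upper=1, lerp=0.0    i=1: lower=0, upper=1, lerp=0.5
  //   i=2: lower=1, upper=1, lerp=0.0    i=3: lower=1, upper=1, lerp=0.5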
  interpolation[out_size].lower = 0;
  interpolation[out_size].upper = 0;
  for (int64 i = out_size - 1; i >= 0; --i) {
    const float in = i * scale;
    interpolation[i].lower = static_cast<int64>(in);
    interpolation[i].upper = std::min(interpolation[i].lower + 1, in_size - 1);
    interpolation[i].lerp = in - interpolation[i].lower;
  }
}

/**
 * Computes the bilinear interpolation from the appropriate 4 float points
 * and the linear interpolation weights.
 */
inline float compute_lerp(const float top_left, const float top_right,
                          const float bottom_left, const float bottom_right,
                          const float x_lerp, const float y_lerp) {
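  // Example: corners (0, 1, 2, 3) with x_lerp = y_lerp = 0.5 give
  // top = 0.5, bottom = 2.5, and a result of 1.5.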
  const float top = top_left + (top_right - top_left) * x_lerp;
  const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
  return top + (bottom - top) * y_lerp;
}

template <typename T>
void resize_image(
    typename TTypes<T, 4>::ConstTensor images, const int batch_size,
    const int64 in_height, const int64 in_width, const int64 out_height,
    const int64 out_width, const int channels,
    const std::vector<CachedInterpolation>& xs,
    const std::vector<CachedInterpolation>& ys,
    typename TTypes<float, 4>::Tensor output) TF_ATTRIBUTE_NOINLINE;
template <typename T>
void resize_image(typename TTypes<T, 4>::ConstTensor images,
                  const int batch_size, const int64 in_height,
                  const int64 in_width, const int64 out_height,
                  const int64 out_width, const int channels,
                  const std::vector<CachedInterpolation>& xs_vec,
                  const std::vector<CachedInterpolation>& ys,
                  typename TTypes<float, 4>::Tensor output) {
  const int64 in_row_size = in_width * channels;
  const int64 in_batch_num_values = in_height * in_row_size;
  const int64 out_row_size = out_width * channels;

  const T* input_b_ptr = images.data();
  const CachedInterpolation* xs = xs_vec.data();

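  // Fast path for the common 3-channel (RGB) case: the channel loop is
  // unrolled by hand so the twelve corner reads and three lerps per output
  // pixel are straight-line code.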
  if (channels == 3) {
    float* output_y_ptr = output.data();
    for (int b = 0; b < batch_size; ++b) {
      for (int64 y = 0; y < out_height; ++y) {
        const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size;
        const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size;
        const float ys_lerp = ys[y].lerp;
        for (int64 x = 0; x < out_width; ++x) {
          const int64 xs_lower = xs[x].lower;
          const int64 xs_upper = xs[x].upper;
          const float xs_lerp = xs[x].lerp;

          // Read channel 0.
          const float top_left0(ys_input_lower_ptr[xs_lower + 0]);
          const float top_right0(ys_input_lower_ptr[xs_upper + 0]);
          const float bottom_left0(ys_input_upper_ptr[xs_lower + 0]);
          const float bottom_right0(ys_input_upper_ptr[xs_upper + 0]);

          // Read channel 1.
          const float top_left1(ys_input_lower_ptr[xs_lower + 1]);
          const float top_right1(ys_input_lower_ptr[xs_upper + 1]);
          const float bottom_left1(ys_input_upper_ptr[xs_lower + 1]);
          const float bottom_right1(ys_input_upper_ptr[xs_upper + 1]);

          // Read channel 2.
          const float top_left2(ys_input_lower_ptr[xs_lower + 2]);
          const float top_right2(ys_input_lower_ptr[xs_upper + 2]);
          const float bottom_left2(ys_input_upper_ptr[xs_lower + 2]);
          const float bottom_right2(ys_input_upper_ptr[xs_upper + 2]);

          // Compute output.
          output_y_ptr[x * channels + 0] =
              compute_lerp(top_left0, top_right0, bottom_left0, bottom_right0,
                           xs_lerp, ys_lerp);
          output_y_ptr[x * channels + 1] =
              compute_lerp(top_left1, top_right1, bottom_left1, bottom_right1,
                           xs_lerp, ys_lerp);
          output_y_ptr[x * channels + 2] =
              compute_lerp(top_left2, top_right2, bottom_left2, bottom_right2,
                           xs_lerp, ys_lerp);
        }
        output_y_ptr += out_row_size;
      }
      input_b_ptr += in_batch_num_values;
    }
  } else {
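    // Generic path for an arbitrary channel count.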
    float* output_y_ptr = output.data();
    for (int b = 0; b < batch_size; ++b) {
      for (int64 y = 0; y < out_height; ++y) {
        const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size;
        const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size;
        const float ys_lerp = ys[y].lerp;
        for (int64 x = 0; x < out_width; ++x) {
          auto xs_lower = xs[x].lower;
          auto xs_upper = xs[x].upper;
          auto xs_lerp = xs[x].lerp;
          for (int c = 0; c < channels; ++c) {
            const float top_left(ys_input_lower_ptr[xs_lower + c]);
            const float top_right(ys_input_lower_ptr[xs_upper + c]);
            const float bottom_left(ys_input_upper_ptr[xs_lower + c]);
            const float bottom_right(ys_input_upper_ptr[xs_upper + c]);
            output_y_ptr[x * channels + c] =
                compute_lerp(top_left, top_right, bottom_left, bottom_right,
                             xs_lerp, ys_lerp);
          }
        }
        output_y_ptr += out_row_size;
      }
      input_b_ptr += in_batch_num_values;
    }
  }
}

}  // namespace

// Partial specialization of ResizeBilinear functor for a CPUDevice.
namespace functor {
template <typename T>
struct ResizeBilinear<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor images,
                  const float height_scale, const float width_scale,
                  typename TTypes<float, 4>::Tensor output) {
    const int batch_size = images.dimension(0);
    const int64 in_height = images.dimension(1);
    const int64 in_width = images.dimension(2);
    const int channels = images.dimension(3);

    const int64 out_height = output.dimension(1);
    const int64 out_width = output.dimension(2);

    // Handle no-op resizes efficiently.
    if (out_height == in_height && out_width == in_width) {
      output = images.template cast<float>();
      return;
    }

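    // One extra entry per axis holds the zeroed sentinel written by
    // compute_interpolation_weights.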
    std::vector<CachedInterpolation> ys(out_height + 1);
    std::vector<CachedInterpolation> xs(out_width + 1);

    // Compute the cached interpolation weights on the x and y dimensions.
    compute_interpolation_weights(out_height, in_height, height_scale,
                                  ys.data());
    compute_interpolation_weights(out_width, in_width, width_scale, xs.data());

    // Scale the x interpolation indices by the channel count so the inner
    // resize loop can address interleaved channels without an extra multiply.
    for (size_t i = 0; i < xs.size(); ++i) {
      xs[i].lower *= channels;
      xs[i].upper *= channels;
    }

    resize_image<T>(images, batch_size, in_height, in_width, out_height,
                    out_width, channels, xs, ys, output);
  }
};
}  // namespace functor

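// Computes the gradient of ResizeBilinear: the first input is the gradient
// with respect to the resized image, the second is the original image, and
// the output is the gradient with respect to the original image.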
template <typename Device, typename T>
class ResizeBilinearOpGrad : public OpKernel {
 public:
  explicit ResizeBilinearOpGrad(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
  }

  void Compute(OpKernelContext* context) override {
    // Validate input.
    // First argument is gradient with respect to resized image.
    const Tensor& input = context->input(0);
    const Tensor& original_image = context->input(1);

    ImageResizerGradientState st(align_corners_);
    st.ValidateAndCreateOutput(context, input, original_image);

    if (!context->status().ok()) return;

    TTypes<float, 4>::ConstTensor input_grad = input.tensor<float, 4>();
    typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());

    functor::ResizeBilinearGrad<Device, T>()(context->eigen_device<Device>(),
                                             input_grad, st.height_scale,
                                             st.width_scale, output_grad);
  }

 private:
  bool align_corners_;
};

// Partial specialization of ResizeBilinearGrad functor for a CPUDevice.
namespace functor {
template <typename T>
struct ResizeBilinearGrad<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor input_grad,
                  const float height_scale, const float width_scale,
                  typename TTypes<T, 4>::Tensor output_grad) {
    const int batch = output_grad.dimension(0);
    const int64 original_height = output_grad.dimension(1);
    const int64 original_width = output_grad.dimension(2);
    const int channels = output_grad.dimension(3);

    const int64 resized_height = input_grad.dimension(1);
    const int64 resized_width = input_grad.dimension(2);

    output_grad.setZero();

    // Each resized output pixel was computed as a weighted average of four
    // input pixels:
    //   resized(b, y, x, c) = top_left     * (1 - y_lerp) * (1 - x_lerp)
    //                       + top_right    * (1 - y_lerp) * x_lerp
    //                       + bottom_left  * y_lerp       * (1 - x_lerp)
    //                       + bottom_right * y_lerp       * x_lerp
    // Here we scatter each incoming gradient value back to those four input
    // pixels, weighted by the same coefficients; output_grad was zeroed above
    // because the loop below accumulates with +=.
    for (int64 b = 0; b < batch; ++b) {
      for (int64 y = 0; y < resized_height; ++y) {
        const float in_y = y * height_scale;
        const int64 top_y_index = static_cast<int64>(floorf(in_y));
        const int64 bottom_y_index =
            std::min(static_cast<int64>(ceilf(in_y)), original_height - 1);
        const float y_lerp = in_y - top_y_index;
        const float inverse_y_lerp = (1.0f - y_lerp);
        for (int64 x = 0; x < resized_width; ++x) {
          const float in_x = x * width_scale;
          const int64 left_x_index = static_cast<int64>(floorf(in_x));
          const int64 right_x_index =
              std::min(static_cast<int64>(ceilf(in_x)), original_width - 1);
          const float x_lerp = in_x - left_x_index;
          const float inverse_x_lerp = (1.0f - x_lerp);
          for (int64 c = 0; c < channels; ++c) {
            output_grad(b, top_y_index, left_x_index, c) +=
                T(input_grad(b, y, x, c) * inverse_y_lerp * inverse_x_lerp);
            output_grad(b, top_y_index, right_x_index, c) +=
                T(input_grad(b, y, x, c) * inverse_y_lerp * x_lerp);
            output_grad(b, bottom_y_index, left_x_index, c) +=
                T(input_grad(b, y, x, c) * y_lerp * inverse_x_lerp);
            output_grad(b, bottom_y_index, right_x_index, c) +=
                T(input_grad(b, y, x, c) * y_lerp * x_lerp);
          }
        }
      }
    }
  }
};
}  // namespace functor

#define REGISTER_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("ResizeBilinear")      \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("size"),    \
                          ResizeBilinearOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#define REGISTER_GRAD_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("ResizeBilinearGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ResizeBilinearOpGrad<CPUDevice, T>);

TF_CALL_half(REGISTER_GRAD_KERNEL);
TF_CALL_float(REGISTER_GRAD_KERNEL);
TF_CALL_double(REGISTER_GRAD_KERNEL);

#undef REGISTER_GRAD_KERNEL

#if GOOGLE_CUDA

#define REGISTER_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("ResizeBilinear")      \
                              .Device(DEVICE_GPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("size"),    \
                          ResizeBilinearOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#define REGISTER_GRAD_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("ResizeBilinearGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      ResizeBilinearOpGrad<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GRAD_KERNEL);

#undef REGISTER_GRAD_KERNEL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow