// Code-browser navigation header (Home | History | Annotate | Download | only in kernels) — not part of the original source; commented out so the file stays compilable.
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/image_ops.cc.
     17 
     18 #if GOOGLE_CUDA
     19 
     20 #define EIGEN_USE_GPU
     21 
     22 #include "tensorflow/core/kernels/resize_bilinear_op.h"
     23 
     24 #include "tensorflow/core/framework/register_types.h"
     25 #include "tensorflow/core/framework/tensor_types.h"
     26 #include "tensorflow/core/platform/types.h"
     27 #include "tensorflow/core/util/cuda_kernel_helper.h"
     28 
     29 namespace tensorflow {
     30 
     31 typedef Eigen::GpuDevice GPUDevice;
     32 
     33 namespace {
     34 
     35 template <typename T>
     36 __global__ void ResizeBilinearKernel(const int32 nthreads, const T* images,
     37                                      float height_scale, float width_scale,
     38                                      int batch, int in_height, int in_width,
     39                                      int channels, int out_height,
     40                                      int out_width, float* output) {
     41   CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
     42     // out_idx = c + channels * (x + out_width * (y + out_height * b))
     43     int idx = out_idx;
     44     const int c = idx % channels;
     45     idx /= channels;
     46     const int x = idx % out_width;
     47     idx /= out_width;
     48     const int y = idx % out_height;
     49     const int b = idx / out_height;
     50 
     51     const float in_y = y * height_scale;
     52     const int top_y_index = floorf(in_y);
     53     const int bottom_y_index =
     54         (in_y < in_height - 1) ? ceilf(in_y) : in_height - 1;
     55     const float y_lerp = in_y - top_y_index;
     56 
     57     const float in_x = x * width_scale;
     58     const int left_x_index = floorf(in_x);
     59     const int right_x_index =
     60         (in_x < in_width - 1) ? ceilf(in_x) : in_width - 1;
     61     const float x_lerp = in_x - left_x_index;
     62 
     63     const float top_left(
     64         images[((b * in_height + top_y_index) * in_width + left_x_index) *
     65                    channels +
     66                c]);
     67     const float top_right(
     68         images[((b * in_height + top_y_index) * in_width + right_x_index) *
     69                    channels +
     70                c]);
     71     const float bottom_left(
     72         images[((b * in_height + bottom_y_index) * in_width + left_x_index) *
     73                    channels +
     74                c]);
     75     const float bottom_right(
     76         images[((b * in_height + bottom_y_index) * in_width + right_x_index) *
     77                    channels +
     78                c]);
     79 
     80     const float top = top_left + (top_right - top_left) * x_lerp;
     81     const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
     82     output[out_idx] = top + (bottom - top) * y_lerp;
     83   }
     84 }
     85 
// Backward bilinear-resize kernel: one thread per *resized* (input-grad)
// element, flattened NHWC as
//   in_idx = c + channels * (x + resized_width * (y + resized_height * b)).
// Each incoming gradient value is scattered to the four original-image
// pixels that produced it in the forward pass, weighted by the same
// bilinear lerp factors. CudaAtomicAdd is required because many resized
// pixels can map to the same original pixel.
//
// Precondition: output_grad must be zero-filled before this kernel runs
// (the launching functor in this file does a SetZero pass first).
template <typename T>
__global__ void ResizeBilinearGradKernel(
    const int32 nthreads, const float* input_grad, float height_scale,
    float width_scale, int batch, int original_height, int original_width,
    int channels, int resized_height, int resized_width, T* output_grad) {
  CUDA_1D_KERNEL_LOOP(in_idx, nthreads) {
    // in_idx = c + channels * (x + resized_width * (y + resized_height * b))
    int idx = in_idx;
    const int c = idx % channels;
    idx /= channels;
    const int x = idx % resized_width;
    idx /= resized_width;
    const int y = idx % resized_height;
    const int b = idx / resized_height;

    // Row pair in the original image; bottom index clamps to the last row,
    // mirroring the forward kernel's sampling.
    const float original_y = y * height_scale;
    const int top_y_index = floorf(original_y);
    const int bottom_y_index = (original_y < original_height - 1)
                                   ? ceilf(original_y)
                                   : original_height - 1;
    const float y_lerp = original_y - top_y_index;

    // Column pair, clamped the same way.
    const float original_x = x * width_scale;
    const int left_x_index = floorf(original_x);
    const int right_x_index = (original_x < original_width - 1)
                                  ? ceilf(original_x)
                                  : original_width - 1;
    const float x_lerp = original_x - left_x_index;

    // Top row receives (1 - y_lerp) of the gradient, split left/right
    // by (1 - x_lerp) / x_lerp.
    const float dtop = (1 - y_lerp) * input_grad[in_idx];
    CudaAtomicAdd(output_grad +
                      ((b * original_height + top_y_index) * original_width +
                       left_x_index) *
                          channels +
                      c,
                  static_cast<T>((1 - x_lerp) * dtop));
    CudaAtomicAdd(output_grad +
                      ((b * original_height + top_y_index) * original_width +
                       right_x_index) *
                          channels +
                      c,
                  static_cast<T>(x_lerp * dtop));

    // Bottom row receives the remaining y_lerp fraction, split the same way.
    const float dbottom = y_lerp * input_grad[in_idx];
    CudaAtomicAdd(output_grad +
                      ((b * original_height + bottom_y_index) * original_width +
                       left_x_index) *
                          channels +
                      c,
                  static_cast<T>((1 - x_lerp) * dbottom));
    CudaAtomicAdd(output_grad +
                      ((b * original_height + bottom_y_index) * original_width +
                       right_x_index) *
                          channels +
                      c,
                  static_cast<T>(x_lerp * dbottom));
  }
}
    144 
    145 }  // namespace
    146 
    147 namespace functor {
    148 
    149 // Partial specialization of ResizeBilinear functor for a GPUDevice.
    150 template <typename T>
    151 struct ResizeBilinear<GPUDevice, T> {
    152   void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor images,
    153                   const float height_scale, const float width_scale,
    154                   typename TTypes<float, 4>::Tensor output) {
    155     const int batch = images.dimension(0);
    156     const int in_height = images.dimension(1);
    157     const int in_width = images.dimension(2);
    158     const int channels = images.dimension(3);
    159 
    160     const int out_height = output.dimension(1);
    161     const int out_width = output.dimension(2);
    162 
    163     const int total_count = batch * out_height * out_width * channels;
    164     if (total_count == 0) return;
    165 
    166     CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
    167     ResizeBilinearKernel<T>
    168         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
    169             config.virtual_thread_count, images.data(), height_scale,
    170             width_scale, batch, in_height, in_width, channels, out_height,
    171             out_width, output.data());
    172   }
    173 };
    174 
    175 // Partial specialization of ResizeBilinearGrad functor for a GPUDevice.
    176 template <typename T>
    177 struct ResizeBilinearGrad<GPUDevice, T> {
    178   void operator()(const GPUDevice& d,
    179                   typename TTypes<float, 4>::ConstTensor input_grad,
    180                   const float height_scale, const float width_scale,
    181                   typename TTypes<T, 4>::Tensor output_grad) {
    182     const int batch = output_grad.dimension(0);
    183     const int original_height = output_grad.dimension(1);
    184     const int original_width = output_grad.dimension(2);
    185     const int channels = output_grad.dimension(3);
    186 
    187     const int resized_height = input_grad.dimension(1);
    188     const int resized_width = input_grad.dimension(2);
    189 
    190     int total_count;
    191     CudaLaunchConfig config;
    192 
    193     // Initialize output_grad with all zeros.
    194     total_count = batch * original_height * original_width * channels;
    195     if (total_count == 0) return;
    196     config = GetCudaLaunchConfig(total_count, d);
    197     SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
    198         config.virtual_thread_count, output_grad.data());
    199 
    200     // Accumulate.
    201     total_count = batch * resized_height * resized_width * channels;
    202     config = GetCudaLaunchConfig(total_count, d);
    203     ResizeBilinearGradKernel<T>
    204         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
    205             config.virtual_thread_count, input_grad.data(), height_scale,
    206             width_scale, batch, original_height, original_width, channels,
    207             resized_height, resized_width, output_grad.data());
    208   }
    209 };
    210 
// Explicitly instantiate both GPU functors once per supported element type.
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF expands DEFINE_GPU_SPECS for each GPU
// number type except half (per the macro's NO_HALF variant).
#define DEFINE_GPU_SPECS(T)                     \
  template struct ResizeBilinear<GPUDevice, T>; \
  template struct ResizeBilinearGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GPU_SPECS);

#undef DEFINE_GPU_SPECS
    218 
    219 }  // namespace functor
    220 }  // namespace tensorflow
    221 
    222 #endif  // GOOGLE_CUDA
    223