Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/image_ops.cc
     17 #define EIGEN_USE_THREADS
     18 
     19 #include "tensorflow/core/kernels/resize_nearest_neighbor_op.h"
     20 
     21 #include <memory>
     22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/register_types.h"
     25 #include "tensorflow/core/framework/tensor.h"
     26 #include "tensorflow/core/framework/tensor_shape.h"
     27 #include "tensorflow/core/framework/types.h"
     28 #include "tensorflow/core/kernels/image_resizer_state.h"
     29 #include "tensorflow/core/lib/core/status.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 
     32 namespace tensorflow {
     33 
     34 typedef Eigen::ThreadPoolDevice CPUDevice;
     35 typedef Eigen::GpuDevice GPUDevice;
     36 
     37 template <typename Device, typename T>
     38 class ResizeNearestNeighborOp : public OpKernel {
     39  public:
     40   explicit ResizeNearestNeighborOp(OpKernelConstruction* context)
     41       : OpKernel(context) {
     42     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
     43   }
     44 
     45   void Compute(OpKernelContext* context) override {
     46     const Tensor& input = context->input(0);
     47     ImageResizerState st(align_corners_);
     48     st.ValidateAndCreateOutput(context, input);
     49 
     50     if (!context->status().ok()) return;
     51 
     52     OP_REQUIRES(context, st.in_height < (1 << 24) && st.in_width < (1 << 24),
     53                 errors::InvalidArgument("nearest neighbor requires max height "
     54                                         "& width of 2^24"));
     55 
     56     // Return if the output is empty.
     57     if (st.output->NumElements() == 0) return;
     58 
     59     typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
     60     typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());
     61 
     62     bool status;
     63     if (align_corners_) {
     64       status =
     65           functor::ResizeNearestNeighbor<Device, T, /*align_corners=*/true>()(
     66               context->eigen_device<Device>(), input_data, st.height_scale,
     67               st.width_scale, output_data);
     68     } else {
     69       status =
     70           functor::ResizeNearestNeighbor<Device, T, /*align_corners=*/false>()(
     71               context->eigen_device<Device>(), input_data, st.height_scale,
     72               st.width_scale, output_data);
     73     }
     74     if (!status) {
     75       context->SetStatus(
     76           errors::Internal("Failed launching ResizeNearestNeighbor"));
     77     }
     78   }
     79 
     80  private:
     81   bool align_corners_;
     82 };
     83 
     84 // Partial specialization of ResizeNearestNeighbor functor for a CPUDevice.
     85 namespace functor {
     86 template <typename T, bool align_corners>
     87 struct ResizeNearestNeighbor<CPUDevice, T, align_corners> {
     88   bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
     89                   const float height_scale, const float width_scale,
     90                   typename TTypes<T, 4>::Tensor output) {
     91     const int batch_size = input.dimension(0);
     92     const int64 in_height = input.dimension(1);
     93     const int64 in_width = input.dimension(2);
     94     const int channels = input.dimension(3);
     95 
     96     const int64 out_height = output.dimension(1);
     97     const int64 out_width = output.dimension(2);
     98 
     99     for (int b = 0; b < batch_size; ++b) {
    100       for (int y = 0; y < out_height; ++y) {
    101         const int64 in_y = std::min(
    102             (align_corners) ? static_cast<int64>(roundf(y * height_scale))
    103                             : static_cast<int64>(floorf(y * height_scale)),
    104             in_height - 1);
    105         for (int x = 0; x < out_width; ++x) {
    106           const int64 in_x = std::min(
    107               (align_corners) ? static_cast<int64>(roundf(x * width_scale))
    108                               : static_cast<int64>(floorf(x * width_scale)),
    109               in_width - 1);
    110           std::copy_n(&input(b, in_y, in_x, 0), channels, &output(b, y, x, 0));
    111         }
    112       }
    113     }
    114     return true;
    115   }
    116 };
    117 }  // namespace functor
    118 
    119 template <typename Device, typename T>
    120 class ResizeNearestNeighborOpGrad : public OpKernel {
    121  public:
    122   explicit ResizeNearestNeighborOpGrad(OpKernelConstruction* context)
    123       : OpKernel(context) {
    124     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
    125   }
    126 
    127   void Compute(OpKernelContext* context) override {
    128     // Grab and validate the input:
    129     const Tensor& input = context->input(0);
    130     OP_REQUIRES(context, input.dims() == 4,
    131                 errors::InvalidArgument("input must be 4-dimensional",
    132                                         input.shape().DebugString()));
    133 
    134     // Grab and validate the output shape:
    135     const Tensor& shape_t = context->input(1);
    136     OP_REQUIRES(context, shape_t.dims() == 1,
    137                 errors::InvalidArgument("shape_t must be 1-dimensional",
    138                                         shape_t.shape().DebugString()));
    139     OP_REQUIRES(context, shape_t.NumElements() == 2,
    140                 errors::InvalidArgument("shape_t must have two elements",
    141                                         shape_t.shape().DebugString()));
    142 
    143     auto sizes = shape_t.vec<int32>();
    144     OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0,
    145                 errors::InvalidArgument("shape_t's elements must be positive"));
    146 
    147     const int64 batch_size = input.dim_size(0);
    148     const int64 in_height = input.dim_size(1);
    149     const int64 in_width = input.dim_size(2);
    150     const int64 channels = input.dim_size(3);
    151 
    152     const int64 out_height = sizes(0);
    153     const int64 out_width = sizes(1);
    154 
    155     Tensor* output = nullptr;
    156     OP_REQUIRES_OK(
    157         context,
    158         context->allocate_output(
    159             0, TensorShape({batch_size, out_height, out_width, channels}),
    160             &output));
    161 
    162     // Return if the output is empty.
    163     if (output->NumElements() == 0) return;
    164 
    165     typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
    166     typename TTypes<T, 4>::Tensor output_data(output->tensor<T, 4>());
    167 
    168     const float height_scale =
    169         CalculateResizeScale(out_height, in_height, align_corners_);
    170     const float width_scale =
    171         CalculateResizeScale(out_width, in_width, align_corners_);
    172 
    173     bool status;
    174     if (align_corners_) {
    175       status = functor::ResizeNearestNeighborGrad<Device, T,
    176                                                   /*align_corners=*/true>()(
    177           context->eigen_device<Device>(), input_data, height_scale,
    178           width_scale, output_data);
    179     } else {
    180       status = functor::ResizeNearestNeighborGrad<Device, T,
    181                                                   /*align_corners=*/false>()(
    182           context->eigen_device<Device>(), input_data, height_scale,
    183           width_scale, output_data);
    184     }
    185     if (!status) {
    186       context->SetStatus(
    187           errors::Internal("Failed launching ResizeNearestNeighborGrad"));
    188     }
    189   }
    190 
    191  private:
    192   bool align_corners_;
    193 };
    194 
    195 // Partial specialization of ResizeNearestNeighborGrad functor for a CPUDevice.
    196 namespace functor {
    197 template <typename T, bool align_corners>
    198 struct ResizeNearestNeighborGrad<CPUDevice, T, align_corners> {
    199   bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
    200                   const float height_scale, const float width_scale,
    201                   typename TTypes<T, 4>::Tensor output) {
    202     const int batch_size = input.dimension(0);
    203     const int64 in_height = input.dimension(1);
    204     const int64 in_width = input.dimension(2);
    205     const int channels = input.dimension(3);
    206 
    207     const int64 out_height = output.dimension(1);
    208     const int64 out_width = output.dimension(2);
    209 
    210     output.setZero();
    211 
    212     for (int y = 0; y < in_height; ++y) {
    213       const int64 out_y = std::min(
    214           (align_corners) ? static_cast<int64>(roundf(y * height_scale))
    215                           : static_cast<int64>(floorf(y * height_scale)),
    216           out_height - 1);
    217       for (int x = 0; x < in_width; ++x) {
    218         const int64 out_x = std::min(
    219             (align_corners) ? static_cast<int64>(roundf(x * width_scale))
    220                             : static_cast<int64>(floorf(x * width_scale)),
    221             out_width - 1);
    222         for (int b = 0; b < batch_size; ++b) {
    223           for (int c = 0; c < channels; ++c) {
    224             output(b, out_y, out_x, c) += input(b, y, x, c);
    225           }
    226         }
    227       }
    228     }
    229     return true;
    230   }
    231 };
    232 }  // namespace functor
    233 
    234 #define REGISTER_KERNEL(T)                                        \
    235   REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor")           \
    236                               .Device(DEVICE_CPU)                 \
    237                               .TypeConstraint<T>("T")             \
    238                               .HostMemory("size"),                \
    239                           ResizeNearestNeighborOp<CPUDevice, T>); \
    240   REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighborGrad")       \
    241                               .Device(DEVICE_CPU)                 \
    242                               .TypeConstraint<T>("T")             \
    243                               .HostMemory("size"),                \
    244                           ResizeNearestNeighborOpGrad<CPUDevice, T>);
    245 
    246 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
    247 
    248 #undef REGISTER_KERNEL
    249 
    250 #if GOOGLE_CUDA
    251 
    252 #define REGISTER_KERNEL(T)                                        \
    253   REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor")           \
    254                               .Device(DEVICE_GPU)                 \
    255                               .TypeConstraint<T>("T")             \
    256                               .HostMemory("size"),                \
    257                           ResizeNearestNeighborOp<GPUDevice, T>); \
    258   REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighborGrad")       \
    259                               .Device(DEVICE_GPU)                 \
    260                               .TypeConstraint<T>("T")             \
    261                               .HostMemory("size"),                \
    262                           ResizeNearestNeighborOpGrad<GPUDevice, T>);
    263 
    264 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
    265 
    266 #undef REGISTER_KERNEL
    267 
    268 #endif  // GOOGLE_CUDA
    269 
    270 }  // namespace tensorflow
    271