/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include <cfloat>
#include <vector>

#include "tensorflow/core/kernels/dilation_ops.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/padding.h"
namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

void ParseAttributes(OpKernelConstruction* context, std::vector<int32>* strides,
                     std::vector<int32>* rates, Padding* padding) {
  OP_REQUIRES_OK(context, context->GetAttr("strides", strides));
  OP_REQUIRES(context, strides->size() == 4,
              errors::InvalidArgument("Sliding window stride field must "
                                      "specify 4 dimensions"));
  OP_REQUIRES(context, (*strides)[0] == 1 && (*strides)[3] == 1,
              errors::Unimplemented(
                  "Stride is only supported across spatial dimensions."));

  OP_REQUIRES_OK(context, context->GetAttr("rates", rates));
  OP_REQUIRES(context, rates->size() == 4,
              errors::InvalidArgument("Input stride (atrous rate) field "
                                      "must specify 4 dimensions"));
  OP_REQUIRES(context, (*rates)[0] == 1 && (*rates)[3] == 1,
              errors::Unimplemented(
                  "Rate is only supported across spatial dimensions."));

  OP_REQUIRES_OK(context, context->GetAttr("padding", padding));
}
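// For example, the attrs strides = {1, 2, 2, 1} and rates = {1, 1, 1, 1}
// slide the window two pixels at a time along each spatial dimension with an
// undilated filter; the first and last entries must be 1 because striding
// (and dilating) along the batch and depth dimensions is unimplemented.
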
void ParseSizes(OpKernelContext* context, const std::vector<int32>& strides,
                const std::vector<int32>& rates, const Padding& padding,
                int* stride_rows, int* stride_cols, int* rate_rows,
                int* rate_cols, int64* pad_top, int64* pad_left,
                int64* out_rows, int64* out_cols) {
  // Input tensor is of the following dimensions:
  // [ batch, input_rows, input_cols, depth ]
  const Tensor& input = context->input(0);
  OP_REQUIRES(context, input.dims() == 4,
              errors::InvalidArgument("input must be 4-dimensional: ",
                                      input.shape().DebugString()));
  const int input_rows = input.dim_size(1);
  const int input_cols = input.dim_size(2);
  const int depth = input.dim_size(3);

  // For now we take the stride and rate from the second and third dimensions
  // only (we do not support striding on the batch or depth dimension).
  *stride_rows = strides[1];
  *stride_cols = strides[2];
  *rate_rows = rates[1];
  *rate_cols = rates[2];

  // Input filter is of the following dimensions:
  // [ filter_rows, filter_cols, depth ]
  const Tensor& filter = context->input(1);
  OP_REQUIRES(context, filter.dims() == 3,
              errors::InvalidArgument("filter must be 3-dimensional: ",
                                      filter.shape().DebugString()));
  const int filter_rows = filter.dim_size(0);
  const int filter_cols = filter.dim_size(1);
  OP_REQUIRES(context, depth == filter.dim_size(2),
              errors::InvalidArgument(
                  "input and filter must have the same depth: ", depth, " vs ",
                  filter.dim_size(2)));

  // Effective filter size, after introducing rate - 1 zeros between each
  // non-zero filter element.
  const int filter_rows_eff =
      filter_rows + (filter_rows - 1) * (*rate_rows - 1);
  const int filter_cols_eff =
      filter_cols + (filter_cols - 1) * (*rate_cols - 1);

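  // For example, a 3x3 filter with rate 2 covers an effective extent of
  // 3 + (3 - 1) * (2 - 1) = 5 input pixels along each spatial dimension.
  // GetWindowedOutputSize then derives the output extent and the top/left
  // padding from the effective filter size, the stride, and the padding
  // algorithm (SAME or VALID).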
  OP_REQUIRES_OK(
      context, GetWindowedOutputSize(input_rows, filter_rows_eff, *stride_rows,
                                     padding, out_rows, pad_top));
  OP_REQUIRES_OK(
      context, GetWindowedOutputSize(input_cols, filter_cols_eff, *stride_cols,
                                     padding, out_cols, pad_left));
}

template <typename Device, typename T>
class DilationOp : public OpKernel {
 public:
  explicit DilationOp(OpKernelConstruction* context) : OpKernel(context) {
    ParseAttributes(context, &strides_, &rates_, &padding_);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter = context->input(1);

    // Determine relevant sizes from input and filters.
    int stride_rows = 0, stride_cols = 0;
    int rate_rows = 0, rate_cols = 0;
    int64 pad_top = 0, pad_left = 0;
    int64 out_rows = 0, out_cols = 0;
    ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols,
               &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows,
               &out_cols);

    // Output tensor is of the following dimensions:
    // [ batch, out_rows, out_cols, depth ]
    const int batch = input.dim_size(0);
    const int depth = input.dim_size(3);
    const std::vector<int64> out_sizes = {batch, out_rows, out_cols, depth};
    TensorShape out_shape(out_sizes);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    functor::Dilation<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(),
        filter.tensor<T, 3>(), stride_rows, stride_cols, rate_rows, rate_cols,
        pad_top, pad_left, output->tensor<T, 4>());
  }

  std::vector<int32> strides_;
  std::vector<int32> rates_;
  Padding padding_;
};
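
// Usage sketch from Python (assuming the TF 1.x tf.nn.dilation2d API surface;
// names and defaults here are illustrative, not taken from this file):
//   y = tf.nn.dilation2d(x, f, strides=[1, 1, 1, 1],
//                        rates=[1, 1, 1, 1], padding="SAME")
// with x of shape [batch, rows, cols, depth] and f of shape
// [filter_rows, filter_cols, depth].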

// Partial specialization of Dilation functor for a CPUDevice.
namespace functor {
template <typename T>
struct Dilation<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T, 3>::ConstTensor filter, int stride_rows,
                  int stride_cols, int rate_rows, int rate_cols, int pad_top,
                  int pad_left, typename TTypes<T, 4>::Tensor output) {
    const int batch = input.dimension(0);
    const int input_rows = input.dimension(1);
    const int input_cols = input.dimension(2);
    const int depth = input.dimension(3);

    const int filter_rows = filter.dimension(0);
    const int filter_cols = filter.dimension(1);

    const int output_rows = output.dimension(1);
    const int output_cols = output.dimension(2);

    // This is a reference implementation, likely to be slow.
    // TODO(gpapan): Write multi-threaded implementation.
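    // Each output element is a morphological dilation of the input by the
    // filter, evaluated at one window position:
    //   output(b, h_out, w_out, d) =
    //       max_{h, w} { input(b, h_beg + h * rate_rows,
    //                          w_beg + w * rate_cols, d) + filter(h, w, d) }
    // where the max runs over all filter taps (h, w) that land inside the
    // input; taps that fall in the padded region are skipped.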
    for (int b = 0; b < batch; ++b) {
      for (int h_out = 0; h_out < output_rows; ++h_out) {
        int h_beg = h_out * stride_rows - pad_top;
        for (int w_out = 0; w_out < output_cols; ++w_out) {
          int w_beg = w_out * stride_cols - pad_left;
          for (int d = 0; d < depth; ++d) {
            T cur_val = Eigen::NumTraits<T>::lowest();
            for (int h = 0; h < filter_rows; ++h) {
              const int h_in = h_beg + h * rate_rows;
              if (h_in >= 0 && h_in < input_rows) {
                for (int w = 0; w < filter_cols; ++w) {
                  const int w_in = w_beg + w * rate_cols;
                  if (w_in >= 0 && w_in < input_cols) {
                    const T val = input(b, h_in, w_in, d) + filter(h, w, d);
                    if (val > cur_val) {
                      cur_val = val;
                    }
                  }
                }
              }
            }
            output(b, h_out, w_out, d) = cur_val;
          }
        }
      }
    }
  }
};
}  // namespace functor

template <typename Device, typename T>
class DilationBackpropInputOp : public OpKernel {
 public:
  explicit DilationBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    ParseAttributes(context, &strides_, &rates_, &padding_);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);

    // Determine relevant sizes from input and filters.
    int stride_rows = 0, stride_cols = 0;
    int rate_rows = 0, rate_cols = 0;
    int64 pad_top = 0, pad_left = 0;
    int64 out_rows = 0, out_cols = 0;
    ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols,
               &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows,
               &out_cols);

    // Verify that the incoming gradient tensor has the expected size
    // [ batch, out_rows, out_cols, depth ]
    const int batch = input.dim_size(0);
    const int depth = input.dim_size(3);
    OP_REQUIRES(context,
                batch == out_backprop.dim_size(0) &&
                    out_rows == out_backprop.dim_size(1) &&
                    out_cols == out_backprop.dim_size(2) &&
                    depth == out_backprop.dim_size(3),
                errors::InvalidArgument("out_backprop has incompatible size."));

    // The computed in_backprop has the same dimensions as the input:
    // [ batch, input_rows, input_cols, depth ]
    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &in_backprop));

    // If there is nothing to compute, return.
    if (input.shape().num_elements() == 0) {
      return;
    }

    functor::DilationBackpropInput<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(),
        filter.tensor<T, 3>(), out_backprop.tensor<T, 4>(), stride_rows,
        stride_cols, rate_rows, rate_cols, pad_top, pad_left,
        in_backprop->tensor<T, 4>());
  }

  std::vector<int32> strides_;
  std::vector<int32> rates_;
  Padding padding_;
};

// Partial specialization of DilationBackpropInput functor for a CPUDevice.
namespace functor {
template <typename T>
struct DilationBackpropInput<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T, 3>::ConstTensor filter,
                  typename TTypes<T, 4>::ConstTensor out_backprop,
                  int stride_rows, int stride_cols, int rate_rows,
                  int rate_cols, int pad_top, int pad_left,
                  typename TTypes<T, 4>::Tensor in_backprop) {
    const int batch = input.dimension(0);
    const int input_rows = input.dimension(1);
    const int input_cols = input.dimension(2);
    const int depth = input.dimension(3);

    const int filter_rows = filter.dimension(0);
    const int filter_cols = filter.dimension(1);

    const int output_rows = out_backprop.dimension(1);
    const int output_cols = out_backprop.dimension(2);

    // Initialize gradient with all zeros.
    in_backprop.setZero();

    // This is a reference implementation, likely to be slow.
    // TODO(gpapan): Write multi-threaded implementation.
    // In the case of multiple argmax branches, we only back-propagate along
    // the first branch encountered, i.e., the one with the smallest value of
    // `h * filter_cols + w` (the strict `>` below keeps the first maximum),
    // similarly to the max-pooling backward routines.
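    // Each output location takes its value from exactly one input location
    // (its argmax), so its incoming gradient is routed entirely to that input
    // location; `+=` accumulates contributions when several output windows
    // share the same argmax.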
    for (int b = 0; b < batch; ++b) {
      for (int h_out = 0; h_out < output_rows; ++h_out) {
        int h_beg = h_out * stride_rows - pad_top;
        for (int w_out = 0; w_out < output_cols; ++w_out) {
          int w_beg = w_out * stride_cols - pad_left;
          for (int d = 0; d < depth; ++d) {
            T cur_val = Eigen::NumTraits<T>::lowest();
            int h_in_max = (h_beg < 0) ? 0 : h_beg;
            int w_in_max = (w_beg < 0) ? 0 : w_beg;
            for (int h = 0; h < filter_rows; ++h) {
              const int h_in = h_beg + h * rate_rows;
              if (h_in >= 0 && h_in < input_rows) {
                for (int w = 0; w < filter_cols; ++w) {
                  const int w_in = w_beg + w * rate_cols;
                  if (w_in >= 0 && w_in < input_cols) {
                    const T val = input(b, h_in, w_in, d) + filter(h, w, d);
                    if (val > cur_val) {
                      cur_val = val;
                      h_in_max = h_in;
                      w_in_max = w_in;
                    }
                  }
                }
              }
            }
            in_backprop(b, h_in_max, w_in_max, d) +=
                out_backprop(b, h_out, w_out, d);
          }
        }
      }
    }
  }
};
}  // namespace functor

template <typename Device, typename T>
class DilationBackpropFilterOp : public OpKernel {
 public:
  explicit DilationBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    ParseAttributes(context, &strides_, &rates_, &padding_);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);

    // Determine relevant sizes from input and filters.
    int stride_rows = 0, stride_cols = 0;
    int rate_rows = 0, rate_cols = 0;
    int64 pad_top = 0, pad_left = 0;
    int64 out_rows = 0, out_cols = 0;
    ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols,
               &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows,
               &out_cols);

    // Verify that the incoming gradient tensor has the expected size
    // [ batch, out_rows, out_cols, depth ]
    const int batch = input.dim_size(0);
    const int depth = input.dim_size(3);
    OP_REQUIRES(context,
                batch == out_backprop.dim_size(0) &&
                    out_rows == out_backprop.dim_size(1) &&
                    out_cols == out_backprop.dim_size(2) &&
                    depth == out_backprop.dim_size(3),
                errors::InvalidArgument("out_backprop has incompatible size."));

    // The computed filter_backprop has the same dimensions as the filter:
    // [ filter_rows, filter_cols, depth ]
    Tensor* filter_backprop = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, filter.shape(), &filter_backprop));

    // If there is nothing to compute, return.
    if (filter.shape().num_elements() == 0) {
      return;
    }

    functor::DilationBackpropFilter<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(),
        filter.tensor<T, 3>(), out_backprop.tensor<T, 4>(), stride_rows,
        stride_cols, rate_rows, rate_cols, pad_top, pad_left,
        filter_backprop->tensor<T, 3>());
  }

  std::vector<int32> strides_;
  std::vector<int32> rates_;
  Padding padding_;
};

// Partial specialization of DilationBackpropFilter functor for a CPUDevice.
namespace functor {
template <typename T>
struct DilationBackpropFilter<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T, 3>::ConstTensor filter,
                  typename TTypes<T, 4>::ConstTensor out_backprop,
                  int stride_rows, int stride_cols, int rate_rows,
                  int rate_cols, int pad_top, int pad_left,
                  typename TTypes<T, 3>::Tensor filter_backprop) {
    const int batch = input.dimension(0);
    const int input_rows = input.dimension(1);
    const int input_cols = input.dimension(2);
    const int depth = input.dimension(3);

    const int filter_rows = filter.dimension(0);
    const int filter_cols = filter.dimension(1);

    const int output_rows = out_backprop.dimension(1);
    const int output_cols = out_backprop.dimension(2);

    // Initialize gradient with all zeros.
    filter_backprop.setZero();

    // This is a reference implementation, likely to be slow.
    // TODO(gpapan): Write multi-threaded implementation.
    // In the case of multiple argmax branches, we only back-propagate along
    // the first branch encountered, i.e., the one with the smallest value of
    // `h * filter_cols + w` (the strict `>` below keeps the first maximum),
    // similarly to the max-pooling backward routines.
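    // The gradient for a filter tap (h, w, d) accumulates the incoming
    // gradients of all output locations, across the entire batch, whose
    // argmax landed on that tap.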
    for (int b = 0; b < batch; ++b) {
      for (int h_out = 0; h_out < output_rows; ++h_out) {
        int h_beg = h_out * stride_rows - pad_top;
        for (int w_out = 0; w_out < output_cols; ++w_out) {
          int w_beg = w_out * stride_cols - pad_left;
          for (int d = 0; d < depth; ++d) {
            T cur_val = Eigen::NumTraits<T>::lowest();
            int h_max = 0;
            int w_max = 0;
            for (int h = 0; h < filter_rows; ++h) {
              const int h_in = h_beg + h * rate_rows;
              if (h_in >= 0 && h_in < input_rows) {
                for (int w = 0; w < filter_cols; ++w) {
                  const int w_in = w_beg + w * rate_cols;
                  if (w_in >= 0 && w_in < input_cols) {
                    const T val = input(b, h_in, w_in, d) + filter(h, w, d);
                    if (val > cur_val) {
                      cur_val = val;
                      h_max = h;
                      w_max = w;
                    }
                  }
                }
              }
            }
            filter_backprop(h_max, w_max, d) +=
                out_backprop(b, h_out, w_out, d);
          }
        }
      }
    }
  }
};
}  // namespace functor

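// TF_CALL_REAL_NUMBER_TYPES below expands REGISTER once per real numeric type
// supported by TensorFlow, registering the forward op and both gradient ops
// for that type on the CPU; TF_CALL_GPU_NUMBER_TYPES does the same for the
// GPU build further down.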
#define REGISTER(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Dilation2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DilationOp<CPUDevice, T>);                                    \
                                                                    \
  REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropInput")           \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T"),              \
                          DilationBackpropInputOp<CPUDevice, T>);   \
                                                                    \
  REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropFilter")          \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T"),              \
                          DilationBackpropFilterOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA

#define REGISTER(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Dilation2D").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DilationOp<GPUDevice, T>);                                    \
                                                                    \
  REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropInput")           \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<T>("T"),              \
                          DilationBackpropInputOp<GPUDevice, T>);   \
                                                                    \
  REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropFilter")          \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<T>("T"),              \
                          DilationBackpropFilterOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA

}  // namespace tensorflow