/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/mirror_pad_op.h"

#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
namespace tensorflow {

template <typename Device, typename T, typename Tpaddings>
class MirrorPadOp : public OpKernel {
 public:
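  // A worked 1-D example of the two modes: padding [1, 2, 3] with
  // paddings [[1, 1]] yields [1, 1, 2, 3, 3] in SYMMETRIC mode (the border
  // element is repeated) and [2, 1, 2, 3, 2] in REFLECT mode (it is not).
  // offset_ records that difference: REFLECT starts mirroring one element
  // past the border.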
  explicit MirrorPadOp(OpKernelConstruction* context) : OpKernel(context) {
    MirrorPadMode mode;
    OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));

    switch (mode) {
      case MirrorPadMode::SYMMETRIC: {
        offset_ = 0;
        break;
      }
      case MirrorPadMode::REFLECT: {
        offset_ = 1;
        break;
      }
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "mode must be either REFLECT or SYMMETRIC."));
    }
  }

  ~MirrorPadOp() override = default;

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();
    constexpr int kMinDims = 0;
    constexpr int kMaxDims = 5;
    OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
                errors::Unimplemented("inputs rank not in [", kMinDims, ",",
                                      kMaxDims, "]: ", dims));
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                in1.shape().DebugString()));
    OP_REQUIRES(
        context, dims == in1.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            in1.shape().DebugString(), ", ", in0.shape().DebugString()));

    // Compute the shape of the output tensor, and allocate it.
    TensorShape output_shape;
    typename TTypes<Tpaddings>::ConstMatrix paddings = in1.matrix<Tpaddings>();
    for (int d = 0; d < dims; ++d) {
      const Tpaddings before = paddings(d, 0);  // Pad before existing elements.
      const Tpaddings after = paddings(d, 1);   // Pad after existing elements.
      OP_REQUIRES(context, before >= 0 && after >= 0,
                  errors::InvalidArgument(
                      "paddings must be non-negative: ", before, ", ", after));
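      // SYMMETRIC mode may mirror up to the whole dimension (the border
      // element itself is repeated), so its bound below is inclusive;
      // REFLECT mode excludes the border element, so it can mirror at most
      // dim_size - 1 elements and its bound is strict.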
      if (offset_ == 0) {  // SYMMETRIC mode.
        OP_REQUIRES(context,
                    before <= in0.dim_size(d) && after <= in0.dim_size(d),
                    errors::InvalidArgument("paddings must be no greater "
                                            "than the dimension size: ",
                                            before, ", ", after,
                                            " greater than ", in0.dim_size(d)));
      } else if (offset_ == 1) {  // REFLECT mode.
        OP_REQUIRES(
            context, before < in0.dim_size(d) && after < in0.dim_size(d),
            errors::InvalidArgument("paddings must be less than"
                                    " the dimension size: ",
                                    before, ", ", after, " not less than ",
                                    in0.dim_size(d)));
      }

      output_shape.AddDim(before + in0.dim_size(d) + after);
    }

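    // Fast path: if padding does not change the element count, the output
    // holds exactly the input's elements, so alias the input buffer rather
    // than copying it.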
    if (output_shape.num_elements() == in0.NumElements()) {
      // The element counts match, but the shape may still differ (possible
      // only when num_elements == 0), so propagate output_shape explicitly.
      Tensor out;
      CHECK(out.CopyFrom(in0, output_shape));
      context->set_output(0, out);
      return;
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

#define MIRROR_PAD_CASE(i)                                                \
  case i: {                                                               \
    functor::MirrorPad<Device, T, Tpaddings, i>()(                        \
        context->eigen_device<Device>(), To32Bit(output->tensor<T, i>()), \
        To32Bit(in0.tensor<T, i>()), paddings, offset_);                  \
    break;                                                                \
  }

    // Invoke the dims-specific implementation.
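    // Rank 0 never reaches this switch: a scalar admits no padding (its
    // paddings input is a 0 x 2 matrix), so it always takes the fast path
    // above. To32Bit switches the tensors to 32-bit indexing, which is
    // generally faster, notably on GPU.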
    switch (dims) {
      MIRROR_PAD_CASE(1)
      MIRROR_PAD_CASE(2)
      MIRROR_PAD_CASE(3)
      MIRROR_PAD_CASE(4)
      MIRROR_PAD_CASE(5)
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Unsupported rank: ",
                                            in0.shape().DebugString()));
    }
#undef MIRROR_PAD_CASE
  }

 private:
  int offset_;
};

using CpuDevice = Eigen::ThreadPoolDevice;
using GpuDevice = Eigen::GpuDevice;

namespace functor {
// Forward declarations of the functor specializations defined in the sharded
// files.
#define DECLARE_CPU_SPEC(T, Tpaddings, i)                     \
  template <>                                                 \
  void MirrorPad<CpuDevice, T, Tpaddings, i>::operator()(     \
      const CpuDevice&, typename TTypes<T, i, int32>::Tensor, \
      typename TTypes<T, i, int32>::ConstTensor,              \
      TTypes<Tpaddings>::ConstMatrix, int);                   \
  extern template struct MirrorPad<CpuDevice, T, Tpaddings, i>;

#define DECLARE_CPU_SPECS(T)     \
  DECLARE_CPU_SPEC(T, int32, 1); \
  DECLARE_CPU_SPEC(T, int32, 2); \
  DECLARE_CPU_SPEC(T, int32, 3); \
  DECLARE_CPU_SPEC(T, int32, 4); \
  DECLARE_CPU_SPEC(T, int32, 5); \
  DECLARE_CPU_SPEC(T, int64, 1); \
  DECLARE_CPU_SPEC(T, int64, 2); \
  DECLARE_CPU_SPEC(T, int64, 3); \
  DECLARE_CPU_SPEC(T, int64, 4); \
  DECLARE_CPU_SPEC(T, int64, 5);

TF_CALL_POD_TYPES(DECLARE_CPU_SPECS);

#undef DECLARE_CPU_SPEC
#undef DECLARE_CPU_SPECS
}  // namespace functor

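// The paddings input is read on the host while the output shape is computed,
// hence HostMemory("paddings") on every registration below.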
#define REGISTER_KERNEL(type)                                     \
  REGISTER_KERNEL_BUILDER(Name("MirrorPad")                       \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<type>("T")          \
                              .TypeConstraint<int32>("Tpaddings") \
                              .HostMemory("paddings"),            \
                          MirrorPadOp<CpuDevice, type, int32>);   \
  REGISTER_KERNEL_BUILDER(Name("MirrorPad")                       \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<type>("T")          \
                              .TypeConstraint<int64>("Tpaddings") \
                              .HostMemory("paddings"),            \
                          MirrorPadOp<CpuDevice, type, int64>);

// Note that bool is registered here via TF_CALL_POD_TYPES, but the gradient
// op below uses TF_CALL_NUMBER_TYPES and therefore has no bool kernel.
TF_CALL_POD_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
// Forward declarations of the functor specializations for GPU.
#define DECLARE_GPU_SPEC(T, Tpaddings, i)                     \
  template <>                                                 \
  void MirrorPad<GpuDevice, T, Tpaddings, i>::operator()(     \
      const GpuDevice&, typename TTypes<T, i, int32>::Tensor, \
      typename TTypes<T, i, int32>::ConstTensor,              \
      TTypes<Tpaddings>::ConstMatrix, int);                   \
  extern template struct MirrorPad<GpuDevice, T, Tpaddings, i>;

#define DECLARE_GPU_SPECS(T)     \
  DECLARE_GPU_SPEC(T, int32, 1); \
  DECLARE_GPU_SPEC(T, int32, 2); \
  DECLARE_GPU_SPEC(T, int32, 3); \
  DECLARE_GPU_SPEC(T, int32, 4); \
  DECLARE_GPU_SPEC(T, int32, 5); \
  DECLARE_GPU_SPEC(T, int64, 1); \
  DECLARE_GPU_SPEC(T, int64, 2); \
  DECLARE_GPU_SPEC(T, int64, 3); \
  DECLARE_GPU_SPEC(T, int64, 4); \
  DECLARE_GPU_SPEC(T, int64, 5);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("MirrorPad")                       \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<T>("T")             \
                              .TypeConstraint<int32>("Tpaddings") \
                              .HostMemory("paddings"),            \
                          MirrorPadOp<GpuDevice, T, int32>);      \
  REGISTER_KERNEL_BUILDER(Name("MirrorPad")                       \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<T>("T")             \
                              .TypeConstraint<int64>("Tpaddings") \
                              .HostMemory("paddings"),            \
                          MirrorPadOp<GpuDevice, T, int64>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

// Gradient op.
template <typename Device, typename T, typename Tpaddings>
class MirrorPadGradOp : public OpKernel {
 public:
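  // A worked 1-D example in REFLECT mode: the forward op maps [x0, x1, x2]
  // with paddings [[1, 1]] to [x1, x0, x1, x2, x1], so the gradient folds
  // the incoming values back onto their sources:
  // dx = [dy[1], dy[0] + dy[2] + dy[4], dy[3]].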
  explicit MirrorPadGradOp(OpKernelConstruction* context) : OpKernel(context) {
    MirrorPadMode mode;
    OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));

    switch (mode) {
      case MirrorPadMode::SYMMETRIC: {
        offset_ = 0;
        break;
      }
      case MirrorPadMode::REFLECT: {
        offset_ = 1;
        break;
      }
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "mode must be either REFLECT or SYMMETRIC."));
    }
  }

  ~MirrorPadGradOp() override = default;

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();
    constexpr int kMinDims = 0;
    constexpr int kMaxDims = 5;
    OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
                errors::Unimplemented("inputs rank not in [", kMinDims, ",",
                                      kMaxDims, "]: ", dims));
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                in1.shape().DebugString()));
    OP_REQUIRES(
        context, dims == in1.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            in1.shape().DebugString(), ", ", in0.shape().DebugString()));

    // Compute the shape of the output tensor, and allocate it.
    TensorShape output_shape;
    typename TTypes<Tpaddings>::ConstMatrix paddings = in1.matrix<Tpaddings>();
    for (int d = 0; d < dims; ++d) {
      const Tpaddings before = paddings(d, 0);  // Pad before existing elements.
      const Tpaddings after = paddings(d, 1);   // Pad after existing elements.
      OP_REQUIRES(context, before >= 0 && after >= 0,
                  errors::InvalidArgument(
                      "paddings must be non-negative: ", before, ", ", after));

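      // The gradient op's input is the forward op's padded output, so the
      // un-padded result dimension is the input dimension minus the total
      // padding along that axis.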
      const int64 out_size = in0.dim_size(d) - (before + after);
      if (offset_ == 0) {  // SYMMETRIC mode.
        OP_REQUIRES(context, before <= out_size && after <= out_size,
                    errors::InvalidArgument("paddings must be no greater "
                                            "than the output dimension size: ",
                                            before, ", ", after,
                                            " greater than ", out_size));
      } else if (offset_ == 1) {  // REFLECT mode.
        OP_REQUIRES(context, before < out_size && after < out_size,
                    errors::InvalidArgument("paddings must be less than"
                                            " the output dimension size: ",
                                            before, ", ", after,
                                            " not less than ", out_size));
      }
      output_shape.AddDim(out_size);
    }

    if (output_shape == in0.shape()) {
      context->set_output(0, in0);
      return;
    }

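    // Scratch holds a same-shaped copy of the incoming gradient in which the
    // MirrorPadGrad functor (see mirror_pad_op.h) folds each mirrored region
    // back onto its source, accumulating; the interior slice of scratch then
    // becomes the output.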
    Tensor scratch;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
                                                   in0.shape(), &scratch));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

#define MIRROR_PAD_GRAD_CASE(k)                                           \
  case k: {                                                               \
    functor::MirrorPadGrad<Device, T, Tpaddings, k>()(                    \
        context->eigen_device<Device>(), To32Bit(output->tensor<T, k>()), \
        To32Bit(in0.tensor<T, k>()), paddings, offset_,                   \
        To32Bit(scratch.tensor<T, k>()));                                 \
    break;                                                                \
  }

    // Invoke the dims-specific implementation.
    switch (dims) {
      MIRROR_PAD_GRAD_CASE(1)
      MIRROR_PAD_GRAD_CASE(2)
      MIRROR_PAD_GRAD_CASE(3)
      MIRROR_PAD_GRAD_CASE(4)
      MIRROR_PAD_GRAD_CASE(5)
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Unsupported rank: ",
                                            in0.shape().DebugString()));
    }
#undef MIRROR_PAD_GRAD_CASE
  }

 private:
  int offset_;
};

namespace functor {
// Forward declarations of the functor specializations defined in the sharded
// files.
#define DECLARE_CPU_SPEC(T, Tpaddings, k)                     \
  template <>                                                 \
  void MirrorPadGrad<CpuDevice, T, Tpaddings, k>::operator()( \
      const CpuDevice&, typename TTypes<T, k, int32>::Tensor, \
      typename TTypes<T, k, int32>::ConstTensor,              \
      TTypes<Tpaddings>::ConstMatrix, int,                    \
      typename TTypes<T, k, int32>::Tensor);                  \
  extern template struct MirrorPadGrad<CpuDevice, T, Tpaddings, k>;

#define DECLARE_CPU_SPECS(T)     \
  DECLARE_CPU_SPEC(T, int32, 1); \
  DECLARE_CPU_SPEC(T, int32, 2); \
  DECLARE_CPU_SPEC(T, int32, 3); \
  DECLARE_CPU_SPEC(T, int32, 4); \
  DECLARE_CPU_SPEC(T, int32, 5); \
  DECLARE_CPU_SPEC(T, int64, 1); \
  DECLARE_CPU_SPEC(T, int64, 2); \
  DECLARE_CPU_SPEC(T, int64, 3); \
  DECLARE_CPU_SPEC(T, int64, 4); \
  DECLARE_CPU_SPEC(T, int64, 5);

TF_CALL_NUMBER_TYPES(DECLARE_CPU_SPECS);
#undef DECLARE_CPU_SPECS
#undef DECLARE_CPU_SPEC
}  // namespace functor

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("MirrorPadGrad")                     \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("Tpaddings")   \
                              .HostMemory("paddings"),              \
                          MirrorPadGradOp<CpuDevice, type, int32>); \
  REGISTER_KERNEL_BUILDER(Name("MirrorPadGrad")                     \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64>("Tpaddings")   \
                              .HostMemory("paddings"),              \
                          MirrorPadGradOp<CpuDevice, type, int64>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
// Forward declarations of the functor specializations for GPU.
#define DECLARE_GPU_SPEC(T, Tpaddings, k)                     \
  template <>                                                 \
  void MirrorPadGrad<GpuDevice, T, Tpaddings, k>::operator()( \
      const GpuDevice&, typename TTypes<T, k, int32>::Tensor, \
      typename TTypes<T, k, int32>::ConstTensor,              \
      TTypes<Tpaddings>::ConstMatrix, int,                    \
      typename TTypes<T, k, int32>::Tensor);                  \
  extern template struct MirrorPadGrad<GpuDevice, T, Tpaddings, k>;

#define DECLARE_GPU_SPECS(T)     \
  DECLARE_GPU_SPEC(T, int32, 1); \
  DECLARE_GPU_SPEC(T, int32, 2); \
  DECLARE_GPU_SPEC(T, int32, 3); \
  DECLARE_GPU_SPEC(T, int32, 4); \
  DECLARE_GPU_SPEC(T, int32, 5); \
  DECLARE_GPU_SPEC(T, int64, 1); \
  DECLARE_GPU_SPEC(T, int64, 2); \
  DECLARE_GPU_SPEC(T, int64, 3); \
  DECLARE_GPU_SPEC(T, int64, 4); \
  DECLARE_GPU_SPEC(T, int64, 5);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("MirrorPadGrad")                   \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<T>("T")             \
                              .TypeConstraint<int32>("Tpaddings") \
                              .HostMemory("paddings"),            \
                          MirrorPadGradOp<GpuDevice, T, int32>);  \
  REGISTER_KERNEL_BUILDER(Name("MirrorPadGrad")                   \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<T>("T")             \
                              .TypeConstraint<int64>("Tpaddings") \
                              .HostMemory("paddings"),            \
                          MirrorPadGradOp<GpuDevice, T, int64>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

}  // namespace tensorflow