Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/nn_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #include "tensorflow/core/kernels/pad_op.h"
     21 
     22 #include <memory>
     23 #include <string>
     24 #include <utility>
     25 
     26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     27 #include "tensorflow/core/framework/op.h"
     28 #include "tensorflow/core/framework/op_kernel.h"
     29 #include "tensorflow/core/framework/register_types.h"
     30 #include "tensorflow/core/framework/tensor.h"
     31 #include "tensorflow/core/framework/tensor_shape.h"
     32 #include "tensorflow/core/framework/tensor_types.h"
     33 #include "tensorflow/core/framework/types.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 #include "tensorflow/core/platform/types.h"
     36 
     37 namespace tensorflow {
     38 
     39 typedef Eigen::ThreadPoolDevice CPUDevice;
     40 typedef Eigen::GpuDevice GPUDevice;
     41 #ifdef TENSORFLOW_USE_SYCL
     42 typedef Eigen::SyclDevice SYCLDevice;
     43 #endif  // TENSORFLOW_USE_SYCL
     44 
     45 template <typename Device, typename T, typename Tpadding>
     46 class PadOp : public OpKernel {
     47  public:
     48   explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {}
     49 
     50   void Compute(OpKernelContext* context) override {
     51     const Tensor& in0 = context->input(0);
     52     const Tensor& in1 = context->input(1);
     53     const int dims = in0.dims();
     54     static const int kMinDims = 0;
     55     static const int kMaxDims = 6;
     56     OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
     57                 errors::Unimplemented("inputs rank not in [", kMinDims, ",",
     58                                       kMaxDims, "]: ", dims));
     59     OP_REQUIRES(
     60         context,
     61         TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2,
     62         errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
     63                                 in1.shape().DebugString()));
     64     const int fixed_dims =
     65         (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1
     66                                                                       : dims;
     67     OP_REQUIRES(
     68         context, fixed_dims == in1.dim_size(0),
     69         errors::InvalidArgument(
     70             "The first dimension of paddings must be the rank of inputs",
     71             in1.shape().DebugString(), " ", in0.shape().DebugString()));
     72 
     73     T pad_value(0);
     74     if (context->num_inputs() == 3) {
     75       const Tensor& constant_values = context->input(2);
     76       OP_REQUIRES(
     77           context, TensorShapeUtils::IsScalar(constant_values.shape()),
     78           errors::InvalidArgument("constant_values must be a scalar. Found: ",
     79                                   constant_values.shape().DebugString()));
     80       pad_value = context->input(2).scalar<T>()();
     81     }
     82 
     83     // Compute the shape of the output tensor, and allocate it.
     84     TensorShape output_shape;
     85     typename TTypes<Tpadding>::ConstMatrix paddings = in1.matrix<Tpadding>();
     86     for (int d = 0; d < fixed_dims; ++d) {
     87       const Tpadding before_d =
     88           paddings(d, 0);                       // Pad before existing elements.
     89       const Tpadding after_d = paddings(d, 1);  // Pad after existing elements.
     90       OP_REQUIRES(context, before_d >= 0 && after_d >= 0,
     91                   errors::InvalidArgument("Paddings must be non-negative: ",
     92                                           before_d, " ", after_d));
     93       const int64 size_d =
     94           (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d);
     95       output_shape.AddDim(before_d + size_d + after_d);
     96     }
     97 
     98     // If there is no padding to be done, forward the input to output.
     99     if (output_shape.num_elements() == in0.NumElements()) {
    100       // When num_elements == 0, shape may have changed.
    101       Tensor out;
    102       CHECK(out.CopyFrom(in0, output_shape));
    103       context->set_output(0, out);
    104       return;
    105     }
    106 
    107     Tensor* output = nullptr;
    108     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    109 
    110     // Invoke the dims-specific implementation.
    111     switch (fixed_dims) {
    112       case 0:
    113         Operate<0>(context, in0.tensor<T, 0>(), paddings, pad_value, output);
    114         break;
    115       case 1:
    116         // TODO(irving): Once Pad doesn't need a scalar special case,
    117         // change flat to tensor.  That is, once !allow_legacy_scalars().
    118         Operate<1>(context, in0.flat<T>(), paddings, pad_value, output);
    119         break;
    120       case 2:
    121         Operate<2>(context, in0.tensor<T, 2>(), paddings, pad_value, output);
    122         break;
    123       case 3:
    124         Operate<3>(context, in0.tensor<T, 3>(), paddings, pad_value, output);
    125         break;
    126       case 4:
    127         Operate<4>(context, in0.tensor<T, 4>(), paddings, pad_value, output);
    128         break;
    129       case 5:
    130         Operate<5>(context, in0.tensor<T, 5>(), paddings, pad_value, output);
    131         break;
    132       case 6:
    133         Operate<6>(context, in0.tensor<T, 6>(), paddings, pad_value, output);
    134         break;
    135       default:
    136         OP_REQUIRES(context, false,
    137                     errors::InvalidArgument("Only ranks up to 6 supported: ",
    138                                             in0.shape().DebugString()));
    139     }
    140   }
    141 
    142  private:
    143   template <int Dims>
    144   void Operate(OpKernelContext* context,
    145                typename TTypes<T, Dims>::ConstTensor input,
    146                typename TTypes<Tpadding>::ConstMatrix paddings, T pad_value,
    147                Tensor* output) {
    148     CHECK_EQ(Dims, paddings.dimension(0));
    149     CHECK_EQ(2, paddings.dimension(1));
    150     Eigen::array<Eigen::IndexPair<Tpadding>, Dims> paddings_array;
    151     for (int i = 0; i < Dims; ++i) {
    152       paddings_array[i] = {paddings(i, 0), paddings(i, 1)};
    153     }
    154     functor::Pad<Device, T, Tpadding, Dims> functor;
    155     functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input,
    156             paddings_array, pad_value);
    157   }
    158 };
    159 
    160 #define REGISTER_KERNEL(type)                                     \
    161   REGISTER_KERNEL_BUILDER(Name("Pad")                             \
    162                               .Device(DEVICE_CPU)                 \
    163                               .TypeConstraint<type>("T")          \
    164                               .TypeConstraint<int32>("Tpaddings") \
    165                               .HostMemory("paddings"),            \
    166                           PadOp<CPUDevice, type, int32>);         \
    167   REGISTER_KERNEL_BUILDER(Name("Pad")                             \
    168                               .Device(DEVICE_CPU)                 \
    169                               .TypeConstraint<type>("T")          \
    170                               .TypeConstraint<int64>("Tpaddings") \
    171                               .HostMemory("paddings"),            \
    172                           PadOp<CPUDevice, type, int64>);         \
    173   REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
    174                               .Device(DEVICE_CPU)                 \
    175                               .TypeConstraint<type>("T")          \
    176                               .TypeConstraint<int32>("Tpaddings") \
    177                               .HostMemory("paddings")             \
    178                               .HostMemory("constant_values"),     \
    179                           PadOp<CPUDevice, type, int32>);         \
    180   REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
    181                               .Device(DEVICE_CPU)                 \
    182                               .TypeConstraint<type>("T")          \
    183                               .TypeConstraint<int64>("Tpaddings") \
    184                               .HostMemory("paddings")             \
    185                               .HostMemory("constant_values"),     \
    186                           PadOp<CPUDevice, type, int64>);
    187 
    188 TF_CALL_POD_TYPES(REGISTER_KERNEL);
    189 #undef REGISTER_KERNEL
    190 
    191 #if GOOGLE_CUDA
    192 // Forward declarations of the functor specializations for GPU.
    193 namespace functor {
    194 #define DECLARE_GPU_SPEC(T, Dims)                                         \
    195   template <>                                                             \
    196   void Pad<GPUDevice, T, int32, Dims>::operator()(                        \
    197       const GPUDevice& d, typename TTypes<T, Dims>::Tensor output,        \
    198       typename TTypes<T, Dims>::ConstTensor input,                        \
    199       Eigen::array<Eigen::IndexPair<int32>, Dims> paddings, T pad_value); \
    200   extern template struct Pad<GPUDevice, T, int32, Dims>;                  \
    201   template <>                                                             \
    202   void Pad<GPUDevice, T, int64, Dims>::operator()(                        \
    203       const GPUDevice& d, typename TTypes<T, Dims>::Tensor output,        \
    204       typename TTypes<T, Dims>::ConstTensor input,                        \
    205       Eigen::array<Eigen::IndexPair<int64>, Dims> paddings, T pad_value); \
    206   extern template struct Pad<GPUDevice, T, int64, Dims>;
    207 
    208 #define DECLARE_GPU_SPECS(T) \
    209   DECLARE_GPU_SPEC(T, 0);    \
    210   DECLARE_GPU_SPEC(T, 1);    \
    211   DECLARE_GPU_SPEC(T, 2);    \
    212   DECLARE_GPU_SPEC(T, 3);    \
    213   DECLARE_GPU_SPEC(T, 4);    \
    214   DECLARE_GPU_SPEC(T, 5);    \
    215   DECLARE_GPU_SPEC(T, 6);
    216 
    217 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
    218 }  // namespace functor
    219 
    220 // Registration of the GPU implementations.
    221 #define REGISTER_GPU_KERNEL(T)                                    \
    222   REGISTER_KERNEL_BUILDER(Name("Pad")                             \
    223                               .Device(DEVICE_GPU)                 \
    224                               .TypeConstraint<T>("T")             \
    225                               .TypeConstraint<int32>("Tpaddings") \
    226                               .HostMemory("paddings"),            \
    227                           PadOp<GPUDevice, T, int32>);            \
    228   REGISTER_KERNEL_BUILDER(Name("Pad")                             \
    229                               .Device(DEVICE_GPU)                 \
    230                               .TypeConstraint<T>("T")             \
    231                               .TypeConstraint<int64>("Tpaddings") \
    232                               .HostMemory("paddings"),            \
    233                           PadOp<GPUDevice, T, int64>);            \
    234   REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
    235                               .Device(DEVICE_GPU)                 \
    236                               .TypeConstraint<T>("T")             \
    237                               .TypeConstraint<int32>("Tpaddings") \
    238                               .HostMemory("paddings")             \
    239                               .HostMemory("constant_values"),     \
    240                           PadOp<GPUDevice, T, int32>)             \
    241   REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
    242                               .Device(DEVICE_GPU)                 \
    243                               .TypeConstraint<T>("T")             \
    244                               .TypeConstraint<int64>("Tpaddings") \
    245                               .HostMemory("paddings")             \
    246                               .HostMemory("constant_values"),     \
    247                           PadOp<GPUDevice, T, int64>)
    248 
    249 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
    250 
    251 // A special GPU kernel for int32.
    252 // TODO(b/25387198): Also enable int32 in device memory. This kernel
    253 // registration requires all int32 inputs and outputs to be in host memory.
    254 REGISTER_KERNEL_BUILDER(Name("Pad")
    255                             .Device(DEVICE_GPU)
    256                             .TypeConstraint<int32>("T")
    257                             .TypeConstraint<int32>("Tpaddings")
    258                             .HostMemory("input")
    259                             .HostMemory("paddings")
    260                             .HostMemory("output"),
    261                         PadOp<CPUDevice, int32, int32>);
    262 REGISTER_KERNEL_BUILDER(Name("Pad")
    263                             .Device(DEVICE_GPU)
    264                             .TypeConstraint<int32>("T")
    265                             .TypeConstraint<int64>("Tpaddings")
    266                             .HostMemory("input")
    267                             .HostMemory("paddings")
    268                             .HostMemory("output"),
    269                         PadOp<CPUDevice, int32, int64>);
    270 REGISTER_KERNEL_BUILDER(Name("PadV2")
    271                             .Device(DEVICE_GPU)
    272                             .TypeConstraint<int32>("T")
    273                             .TypeConstraint<int32>("Tpaddings")
    274                             .HostMemory("input")
    275                             .HostMemory("paddings")
    276                             .HostMemory("constant_values")
    277                             .HostMemory("output"),
    278                         PadOp<CPUDevice, int32, int32>);
    279 REGISTER_KERNEL_BUILDER(Name("PadV2")
    280                             .Device(DEVICE_GPU)
    281                             .TypeConstraint<int32>("T")
    282                             .TypeConstraint<int64>("Tpaddings")
    283                             .HostMemory("input")
    284                             .HostMemory("paddings")
    285                             .HostMemory("constant_values")
    286                             .HostMemory("output"),
    287                         PadOp<CPUDevice, int32, int64>);
    288 #endif
    289 
    290 #ifdef TENSORFLOW_USE_SYCL
    291 // Registration of the GPU implementations.
    292 #define REGISTER_SYCL_KERNEL(T)                                   \
    293   REGISTER_KERNEL_BUILDER(Name("Pad")                             \
    294                               .Device(DEVICE_SYCL)                \
    295                               .TypeConstraint<T>("T")             \
    296                               .TypeConstraint<int32>("Tpaddings") \
    297                               .HostMemory("paddings"),            \
    298                           PadOp<SYCLDevice, T, int32>);           \
    299   REGISTER_KERNEL_BUILDER(Name("Pad")                             \
    300                               .Device(DEVICE_SYCL)                \
    301                               .TypeConstraint<T>("T")             \
    302                               .TypeConstraint<int64>("Tpaddings") \
    303                               .HostMemory("paddings"),            \
    304                           PadOp<SYCLDevice, T, int64>);           \
    305   REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
    306                               .Device(DEVICE_SYCL)                \
    307                               .TypeConstraint<T>("T")             \
    308                               .TypeConstraint<int32>("Tpaddings") \
    309                               .HostMemory("paddings")             \
    310                               .HostMemory("constant_values"),     \
    311                           PadOp<SYCLDevice, T, int32>)            \
    312   REGISTER_KERNEL_BUILDER(Name("PadV2")                           \
    313                               .Device(DEVICE_SYCL)                \
    314                               .TypeConstraint<T>("T")             \
    315                               .TypeConstraint<int64>("Tpaddings") \
    316                               .HostMemory("paddings")             \
    317                               .HostMemory("constant_values"),     \
    318                           PadOp<SYCLDevice, T, int64>)
    319 
    320 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
    321 REGISTER_KERNEL_BUILDER(Name("Pad")
    322                             .Device(DEVICE_SYCL)
    323                             .TypeConstraint<int32>("T")
    324                             .TypeConstraint<int32>("Tpaddings")
    325                             .HostMemory("input")
    326                             .HostMemory("paddings")
    327                             .HostMemory("output"),
    328                         PadOp<CPUDevice, int32, int32>);
    329 REGISTER_KERNEL_BUILDER(Name("Pad")
    330                             .Device(DEVICE_SYCL)
    331                             .TypeConstraint<int32>("T")
    332                             .TypeConstraint<int64>("Tpaddings")
    333                             .HostMemory("input")
    334                             .HostMemory("paddings")
    335                             .HostMemory("output"),
    336                         PadOp<CPUDevice, int32, int64>);
    337 REGISTER_KERNEL_BUILDER(Name("PadV2")
    338                             .Device(DEVICE_SYCL)
    339                             .TypeConstraint<int32>("T")
    340                             .TypeConstraint<int32>("Tpaddings")
    341                             .HostMemory("input")
    342                             .HostMemory("paddings")
    343                             .HostMemory("constant_values")
    344                             .HostMemory("output"),
    345                         PadOp<CPUDevice, int32, int32>);
    346 REGISTER_KERNEL_BUILDER(Name("PadV2")
    347                             .Device(DEVICE_SYCL)
    348                             .TypeConstraint<int32>("T")
    349                             .TypeConstraint<int64>("Tpaddings")
    350                             .HostMemory("input")
    351                             .HostMemory("paddings")
    352                             .HostMemory("constant_values")
    353                             .HostMemory("output"),
    354                         PadOp<CPUDevice, int32, int64>);
    355 #undef REGISTER_SYCL_KERNEL
    356 #endif  // TENSORFLOW_USE_SYCL
    357 
    358 }  // end namespace tensorflow
    359