/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/kernels/spacetobatch_functor.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

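// Shared implementation for the SpaceToBatch and SpaceToBatchND kernels:
// validates block_shape and paddings, folds trivial block dimensions into the
// batch and depth dimensions, computes the output shape, and dispatches to the
// SpaceToBatchFunctor specialization for the remaining block dimensions.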
template <typename Device, typename T>
void SpaceToBatchOpCompute(OpKernelContext* context,
                           const Tensor& orig_input_tensor,
                           const Tensor& orig_block_shape,
                           const Tensor& orig_paddings) {
  const int input_dims = orig_input_tensor.dims();
  OP_REQUIRES(
      context, TensorShapeUtils::IsVector(orig_block_shape.shape()),
      errors::InvalidArgument("block_shape rank should be 1 instead of ",
                              orig_block_shape.dims()));

  const int block_dims = orig_block_shape.dim_size(0);
  OP_REQUIRES(
      context, orig_input_tensor.dims() >= 1 + block_dims,
      errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
                              " instead of ", orig_input_tensor.dims()));

  OP_REQUIRES(context,
              TensorShapeUtils::IsMatrix(orig_paddings.shape()) &&
                  block_dims == orig_paddings.dim_size(0) &&
                  2 == orig_paddings.dim_size(1),
              errors::InvalidArgument("paddings should have shape [",
                                      block_dims, ", 2] instead of ",
                                      orig_paddings.shape().DebugString()));

  // To avoid out-of-bounds access in the case that the block_shape and/or
  // paddings tensors are concurrently modified, we must copy the values.
  gtl::InlinedVector<int64, 4> block_shape;
  gtl::InlinedVector<int64, 8> paddings;
  internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape);
  internal::spacetobatch::SubtleMustCopyFlat(orig_paddings, &paddings);

  // Determine the length of the prefix of block dims that can be combined
  // into the batch dimension due to having no padding and block_shape=1.
  int removed_prefix_block_dims = 0;
  for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
    const int dim = removed_prefix_block_dims;
    if (paddings[2 * dim] != 0 || paddings[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

  // Determine the length of the suffix of block dims that can be combined
  // into the depth dimension due to having no padding and block_shape=1.
  int removed_suffix_block_dims = 0;
  for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims;
       ++removed_suffix_block_dims) {
    const int dim = block_dims - 1 - removed_suffix_block_dims;
    if (paddings[dim * 2] != 0 || paddings[dim * 2 + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

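  // For example, with block_shape = [1, 2, 1] and zero paddings on the first
  // and last block dimensions, the first block dimension folds into the batch
  // dimension and the last folds into the depth dimension, leaving a single
  // non-trivial block dimension for the functor to process.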
  // Compute the product of the block_shape values.
  int64 block_shape_product = 1;
  for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
    block_shape_product *= block_shape[block_dim];
  }
  OP_REQUIRES(
      context, block_shape_product > 0,
      errors::InvalidArgument("Product of block sizes must be positive, got ",
                              block_shape_product));

  const int internal_block_dims =
      block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
  OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims,
              errors::InvalidArgument(
                  "Number of non-combined block dimensions is ",
                  internal_block_dims, " but must not exceed ",
                  kMaxSpaceToBatchBlockDims));

  if (internal_block_dims == 0) {
    context->set_output(0, orig_input_tensor);
    return;
  }

  // For the purpose of computing the result, the input will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_input_shape;

  // For the purpose of computing the result, the output will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_output_shape;

  // The actual output shape exposed to callers.
  TensorShape external_output_shape;

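  // For example, a [1, 2, 2, 1] input with block_shape = [2, 2] and zero
  // paddings yields internal_input_shape = [1, 2, 2, 1] and
  // internal_output_shape = external_output_shape = [4, 1, 1, 1].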
  external_output_shape.AddDim(orig_input_tensor.dim_size(0) *
                               block_shape_product);

  int64 input_batch_size = orig_input_tensor.dim_size(0);
  for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) {
    const int64 size = orig_input_tensor.dim_size(block_dim + 1);
    input_batch_size *= size;
    external_output_shape.AddDim(size);
  }
  internal_input_shape.AddDim(input_batch_size);
  internal_output_shape.AddDim(input_batch_size * block_shape_product);

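  // For each non-combined block dimension, validate the paddings, require the
  // padded input size to be evenly divisible by the block size, and record the
  // resulting per-dimension output size.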
  for (int block_dim = removed_prefix_block_dims;
       block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
    const int64 pad_start = paddings[2 * block_dim],
                pad_end = paddings[2 * block_dim + 1];
    OP_REQUIRES(context, pad_start >= 0 && pad_end >= 0,
                errors::InvalidArgument("Paddings must be non-negative"));
    const int64 input_size = orig_input_tensor.dim_size(block_dim + 1);
    const int64 block_shape_value = block_shape[block_dim];
    const int64 padded_size = input_size + pad_start + pad_end;
    OP_REQUIRES(
        context, padded_size % block_shape_value == 0,
        errors::InvalidArgument("padded_shape[", block_dim, "]=", padded_size,
                                " is not divisible by block_shape[", block_dim,
                                "]=", block_shape_value));
    internal_input_shape.AddDim(input_size);
    const int64 output_size = padded_size / block_shape_value;
    internal_output_shape.AddDim(output_size);
    external_output_shape.AddDim(output_size);
  }

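  // Fold the suffix-combined block dimensions and all remaining input
  // dimensions into a single depth dimension.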
  int64 depth = 1;
  for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims;
       ++dim) {
    const int64 size = orig_input_tensor.dim_size(dim);
    external_output_shape.AddDim(size);
    depth *= size;
  }
  internal_input_shape.AddDim(depth);
  internal_output_shape.AddDim(depth);

  // Allocate output tensor.
  Tensor* output_tensor = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape,
                                                   &output_tensor));

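  // Skip over the prefix block dimensions that were folded into the batch
  // dimension; the functor only sees the non-combined block dimensions.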
  const int64* internal_paddings = &paddings[2 * removed_prefix_block_dims];
  const int64* internal_block_shape = &block_shape[removed_prefix_block_dims];

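  // Dispatch to the SpaceToBatchFunctor specialization that matches the number
  // of non-combined block dimensions. TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS
  // expands the case macro once for each supported value of NUM_BLOCK_DIMS.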
  switch (internal_block_dims) {
#define TF_SPACETOBATCH_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS)                    \
  case NUM_BLOCK_DIMS: {                                                   \
    OP_REQUIRES_OK(                                                        \
        context,                                                           \
        (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, false>()( \
            context->eigen_device<Device>(),                               \
            orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>(               \
                internal_input_shape.dim_sizes()),                         \
            internal_block_shape, internal_paddings,                       \
            output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>(                  \
                internal_output_shape.dim_sizes()))));                     \
  } break;                                                                 \
    /**/
    TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_SPACETOBATCH_BLOCK_DIMS_CASE)
#undef TF_SPACETOBATCH_BLOCK_DIMS_CASE
  }
}

}  // namespace

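// N-D op: block_shape and paddings are provided as input tensors.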
template <typename Device, typename T>
class SpaceToBatchNDOp : public OpKernel {
 public:
  explicit SpaceToBatchNDOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& orig_input_tensor = context->input(0);
    const Tensor& orig_block_shape = context->input(1);
    const Tensor& orig_paddings = context->input(2);
    SpaceToBatchOpCompute<Device, T>(context, orig_input_tensor,
                                     orig_block_shape, orig_paddings);
  }
};

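// Legacy 2-D op: the scalar block_size attribute is expanded once, at
// construction time, into a constant block_shape of [block_size, block_size],
// and the paddings input is forwarded unchanged.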
template <typename Device, typename T>
class SpaceToBatchOp : public OpKernel {
 public:
  explicit SpaceToBatchOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        context, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
    // We don't use context->allocate_persistent because the allocation must
    // happen on the CPU regardless of Device.
    block_shape_ = Tensor(tensorflow::DT_INT64, TensorShape({2}));
    auto block_shape_vec = block_shape_.vec<int64>();
    block_shape_vec(0) = block_size_;
    block_shape_vec(1) = block_size_;
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();

    static const int kRequiredDims = 4;
    OP_REQUIRES(context, kRequiredDims == dims,
                errors::InvalidArgument("Input rank should be: ", kRequiredDims,
                                        " instead of: ", dims));
    SpaceToBatchOpCompute<Device, T>(context, in0, block_shape_, in1);
  }

 private:
  int block_size_;
  Tensor block_shape_;
};

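// Register CPU and GPU kernels. The block_shape and paddings inputs are
// constrained to host memory because the kernel reads their values directly
// when computing output shapes.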
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("SpaceToBatchND")           \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("paddings"),     \
                          SpaceToBatchNDOp<CPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("SpaceToBatch")             \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("paddings"),     \
                          SpaceToBatchOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("SpaceToBatchND")           \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("paddings"),     \
                          SpaceToBatchNDOp<GPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("SpaceToBatch")             \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("paddings"),     \
                          SpaceToBatchOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);
#undef REGISTER
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow