/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
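// BatchToSpaceND rearranges data from the batch dimension back into spatial
// block dimensions and then crops the result; for example, an input of shape
// [4, 1, 1, 1] with block_shape = [2, 2] and zero crops produces an output of
// shape [1, 2, 2, 1].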

#define EIGEN_USE_THREADS

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/kernels/spacetobatch_functor.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
static void BatchToSpaceOpCompute(OpKernelContext* context,
                                  const Tensor& orig_input_tensor,
                                  const Tensor& orig_block_shape,
                                  const Tensor& orig_crops) {
  const int input_dims = orig_input_tensor.dims();
  OP_REQUIRES(
      context, TensorShapeUtils::IsVector(orig_block_shape.shape()),
      errors::InvalidArgument("block_shape rank should be 1 instead of ",
                              orig_block_shape.dims()));

  const int block_dims = orig_block_shape.dim_size(0);
  OP_REQUIRES(
      context, orig_input_tensor.dims() >= 1 + block_dims,
      errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
                              " instead of ", orig_input_tensor.dims()));

  OP_REQUIRES(context,
              TensorShapeUtils::IsMatrix(orig_crops.shape()) &&
                  block_dims == orig_crops.dim_size(0) &&
                  2 == orig_crops.dim_size(1),
              errors::InvalidArgument("crops should have shape [", block_dims,
                                      ", 2] instead of ",
                                      orig_crops.shape().DebugString()));
  // To avoid out-of-bounds access in the case that the block_shape and/or
  // crops tensors are concurrently modified, we must copy the values.
  gtl::InlinedVector<int64, 4> block_shape;
  gtl::InlinedVector<int64, 8> crops;
  internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape);
  internal::spacetobatch::SubtleMustCopyFlat(orig_crops, &crops);

  // Determine the length of the prefix of block dims that can be combined
  // into the batch dimension due to having no padding and block_shape=1.
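  // For example, with block_shape = [1, 2] and crops = [[0, 0], [2, 0]], the
  // first block dimension has block size 1 and no cropping, so it can simply
  // be folded into the batch dimension.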
  int removed_prefix_block_dims = 0;
  for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
    const int dim = removed_prefix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

  // Determine the length of the suffix of block dims that can be combined
  // into the depth dimension due to having no padding and block_shape=1.
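  // For example, with block_shape = [2, 1] and crops = [[0, 0], [0, 0]], the
  // trailing block dimension has block size 1 and no cropping, so it can be
  // folded into the depth dimension.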
  int removed_suffix_block_dims = 0;
  for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims;
       ++removed_suffix_block_dims) {
    const int dim = block_dims - 1 - removed_suffix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

  // Compute the product of the block_shape values.
  int64 block_shape_product = 1;
  for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
    block_shape_product *= block_shape[block_dim];
  }
  OP_REQUIRES(
      context, block_shape_product > 0,
      errors::InvalidArgument("Product of block sizes must be positive, got ",
                              block_shape_product));

  const int64 orig_input_batch_size = orig_input_tensor.dim_size(0);
  OP_REQUIRES(
      context, orig_input_batch_size % block_shape_product == 0,
      errors::InvalidArgument("Input batch dimension (", orig_input_batch_size,
                              ") is not divisible by product of block sizes (",
                              block_shape_product, ")"));

  const int internal_block_dims =
      block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
  OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims,
              errors::InvalidArgument(
                  "Number of non-combined block dimensions is ",
                  internal_block_dims, " but must not exceed ",
                  kMaxSpaceToBatchBlockDims));

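  // If every block dimension was folded into the batch or depth dimension,
  // the operation reduces to the identity and the input can be forwarded
  // directly.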
  if (internal_block_dims == 0) {
    context->set_output(0, orig_input_tensor);
    return;
  }

  // For the purpose of computing the result, the input will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_input_shape;

  // For the purpose of computing the result, the output will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_output_shape;

  // The actual output shape exposed to callers.
  TensorShape external_output_shape;

  external_output_shape.AddDim(orig_input_batch_size / block_shape_product);

  int64 input_batch_size = orig_input_batch_size;
  for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) {
    const int64 size = orig_input_tensor.dim_size(block_dim + 1);
    input_batch_size *= size;
    external_output_shape.AddDim(size);
  }
  internal_input_shape.AddDim(input_batch_size);
  internal_output_shape.AddDim(input_batch_size / block_shape_product);

  for (int block_dim = removed_prefix_block_dims;
       block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
    const int64 crop_start = crops[2 * block_dim],
                crop_end = crops[2 * block_dim + 1];
    OP_REQUIRES(context, crop_start >= 0 && crop_end >= 0,
                errors::InvalidArgument("Crops must be non-negative"));
    const int64 input_size = orig_input_tensor.dim_size(block_dim + 1);
    const int64 block_shape_value = block_shape[block_dim];
    const int64 cropped_size =
        input_size * block_shape_value - crop_start - crop_end;
    OP_REQUIRES(context, cropped_size >= 0,
                errors::InvalidArgument("cropped_shape[", block_dim, "]=",
                                        cropped_size, " must be non-negative"));
    internal_input_shape.AddDim(input_size);
    internal_output_shape.AddDim(cropped_size);
    external_output_shape.AddDim(cropped_size);
  }

  int64 depth = 1;
  for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims;
       ++dim) {
    const int64 size = orig_input_tensor.dim_size(dim);
    external_output_shape.AddDim(size);
    depth *= size;
  }
  internal_input_shape.AddDim(depth);
  internal_output_shape.AddDim(depth);
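
  // Worked example: an input of shape [2, 3, 4, 5] with block_shape = [1, 2]
  // and crops = [[0, 0], [0, 0]] folds the first block dimension into the
  // batch, giving internal_input_shape [6, 4, 5], internal_output_shape
  // [3, 8, 5], and external_output_shape [1, 3, 8, 5].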

  // Allocate output tensor.
  Tensor* output_tensor = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape,
                                                   &output_tensor));

  const int64* internal_crops = &crops[2 * removed_prefix_block_dims];
  const int64* internal_block_shape = &block_shape[removed_prefix_block_dims];

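  // Dispatch on the number of non-combined block dimensions. The boolean
  // template argument of the shared SpaceToBatchFunctor selects the
  // direction of the transform; passing true here requests the inverse,
  // batch-to-space permutation.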
  switch (internal_block_dims) {
#define TF_BATCHTOSPACE_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS)                   \
  case NUM_BLOCK_DIMS: {                                                  \
    OP_REQUIRES_OK(                                                       \
        context,                                                          \
        (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, true>()( \
            context->eigen_device<Device>(),                              \
            output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>(                 \
                internal_output_shape.dim_sizes()),                       \
            internal_block_shape, internal_crops,                         \
            orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>(              \
                internal_input_shape.dim_sizes()))));                     \
  } break;                                                                \
    /**/
    TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_BATCHTOSPACE_BLOCK_DIMS_CASE)
#undef TF_BATCHTOSPACE_BLOCK_DIMS_CASE
  }
}

template <typename Device, typename T>
class BatchToSpaceNDOp : public OpKernel {
 public:
  explicit BatchToSpaceNDOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& orig_input_tensor = context->input(0);
    const Tensor& orig_block_shape = context->input(1);
    const Tensor& orig_crops = context->input(2);
    BatchToSpaceOpCompute<Device, T>(context, orig_input_tensor,
                                     orig_block_shape, orig_crops);
  }
};

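// The legacy BatchToSpace op handles only 4-D inputs with a single uniform
// block_size; it delegates to the N-D implementation above using a constant
// block_shape of [block_size, block_size].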
template <typename Device, typename T>
class BatchToSpaceOp : public OpKernel {
 public:
  explicit BatchToSpaceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        context, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
    // We don't use context->allocate_persistent because the allocation must
    // happen on the CPU regardless of Device.
    block_shape_ = Tensor(tensorflow::DT_INT64, TensorShape({2}));
    auto block_shape_vec = block_shape_.vec<int64>();
    block_shape_vec(0) = block_size_;
    block_shape_vec(1) = block_size_;
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();

    // Check on the input dimensions first.
    // The input is presumed to be [batch, height, width, depth]
    static const int kRequiredDims = 4;
    OP_REQUIRES(context, kRequiredDims == dims,
                errors::InvalidArgument("Input rank should be: ", kRequiredDims,
                                        " instead of: ", dims));
    BatchToSpaceOpCompute<Device, T>(context, in0, block_shape_, in1);
  }

 private:
  int block_size_;
  Tensor block_shape_;
};

#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND")           \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("crops"),        \
                          BatchToSpaceNDOp<CPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpace")             \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("crops"),        \
                          BatchToSpaceOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND")           \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("crops"),        \
                          BatchToSpaceNDOp<GPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpace")             \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("crops"),        \
                          BatchToSpaceOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);
#undef REGISTER
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow