/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/kernels/depthtospace_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

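// DepthToSpace rearranges data from the depth dimension into blocks of
// spatial data. A minimal NHWC example: an input of shape [1, 1, 1, 4]
// holding [a, b, c, d] with block_size = 2 becomes an output of shape
// [1, 2, 2, 1] laid out as the 2x2 spatial block
//   [[a, b],
//    [c, d]].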
template <typename Device, typename T>
class DepthToSpaceOp : public OpKernel {
 public:
  explicit DepthToSpaceOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format_str;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
    OP_REQUIRES(context, block_size_ > 1,
                errors::InvalidArgument("Block size should be > 1, but was: ",
                                        block_size_));

    if (std::is_same<Device, CPUDevice>::value) {
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Only NHWC data_format supported on CPU. Got ", data_format_str));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const int dims = input.dims();

    // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
    constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
    OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
                errors::InvalidArgument(
                    "qint8 should be used with data_format NCHW_VECT_C."));

    constexpr int kVect = is_int8x4 ? 4 : 1;
    constexpr int kDims = is_int8x4 ? 5 : 4;
    OP_REQUIRES(context, kDims == dims,
                errors::InvalidArgument("Input rank should be: ", kDims,
                                        " instead of: ", dims));

    constexpr int kNumSpatialDims = 2;
    const int batch_size =
        input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'N'));
    const int input_height =
        input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'H'));
    const int input_width =
        input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'W'));
    const int input_depth =
        input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'C')) *
        kVect;

    const int block_size_sq = block_size_ * block_size_;

    // The depth must be divisible by block_size_ * block_size_.
    OP_REQUIRES(
        context, input_depth % block_size_sq == 0,
        errors::InvalidArgument("Input depth dimension ", input_depth,
                                " should be divisible by: ", block_size_sq));

    const int output_depth = input_depth / block_size_sq;
    const int output_width = input_width * block_size_;
    const int output_height = input_height * block_size_;
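    // Example: input_depth = 12 with block_size = 2 yields output_depth = 3,
    // while height and width each grow by a factor of block_size.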

    // Allocate output tensor.
    Tensor* outputs_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0,
                       ShapeFromFormat(data_format_, batch_size, output_height,
                                       output_width, output_depth),
                       &outputs_tensor));
    auto Tinput = input.tensor<T, kDims>();
    auto Toutput = outputs_tensor->tensor<T, kDims>();

    if (std::is_same<Device, GPUDevice>::value) {
      if (is_int8x4) {
        // NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
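        // Reinterpreting the last dimension moves each int8x4 group as a
        // single 32-bit element, so the generic NCHW kernel applies directly.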
        auto Tinput_v = input.template reinterpret_last_dimension<int32, 4>();
        auto Toutput_v = outputs_tensor->reinterpret_last_dimension<int32, 4>();
        functor::DepthToSpaceOpFunctor<GPUDevice, int32, FORMAT_NCHW> functor;
        functor(context->eigen_device<GPUDevice>(), Tinput_v, block_size_,
                Toutput_v);
        return;
      } else if (data_format_ == FORMAT_NCHW) {
        functor::DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NCHW> functor;
        functor(context->eigen_device<GPUDevice>(), Tinput, block_size_,
                Toutput);
        return;
      }
    }

    // NOTE: Assumes data_format_ == FORMAT_NHWC here, since we have rejected
    // (CPU && data_format_ != FORMAT_NHWC) in the constructor.

    if (!is_int8x4) {
      functor::DepthToSpaceOpFunctor<Device, T, FORMAT_NHWC> functor;
      functor(context->eigen_device<Device>(), Tinput, block_size_, Toutput);
    }
  }

 private:
  int block_size_;
  TensorFormat data_format_;
};

// Partial specialization of DepthToSpaceOpFunctor for a CPUDevice
// with FORMAT_NHWC.
namespace functor {
template <typename T>
struct DepthToSpaceOpFunctor<CPUDevice, T, FORMAT_NHWC> {
  void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
                  int block_size, typename TTypes<T, 4>::Tensor output) {
    const int batch_size = output.dimension(0);
    const int output_height = output.dimension(1);
    const int output_width = output.dimension(2);
    const int output_depth = output.dimension(3);

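    // Index mapping implemented by the loop nest below:
    //   output(b, h, w, d) =
    //       input(b, h / block_size, w / block_size,
    //             d + ((h % block_size) * block_size + w % block_size) *
    //                     output_depth)
    // i.e. the spatial offset within each block selects a contiguous slice
    // of the input depth.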
    for (int b = 0; b < batch_size; ++b) {
      for (int h = 0; h < output_height; ++h) {
        const int in_h = h / block_size;
        const int offset_h = (h % block_size);
        for (int w = 0; w < output_width; ++w) {
          const int in_w = w / block_size;
          const int offset_w = (w % block_size);
          const int offset_d =
              (offset_h * block_size + offset_w) * output_depth;
          for (int d = 0; d < output_depth; ++d) {
            const int in_d = d + offset_d;
            output(b, h, w, d) = input(b, in_h, in_w, in_d);
          }
        }
      }
    }
  }
};
}  // namespace functor

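// Register the NHWC CPU kernel for each element type covered by
// TF_CALL_ALL_TYPES.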
#define REGISTER(type)                                                   \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("DepthToSpace").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      DepthToSpaceOp<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("DepthToSpace").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    DepthToSpaceOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("DepthToSpace").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
    DepthToSpaceOp<GPUDevice, qint8>);
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow