Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/nn_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #include "tensorflow/core/kernels/data_format_ops.h"
     21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/register_types.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 
     26 namespace tensorflow {
     27 
     28 typedef Eigen::ThreadPoolDevice CPUDevice;
     29 typedef Eigen::GpuDevice GPUDevice;
     30 
     31 template <typename Device, typename T>
     32 class DataFormatDimMapOp : public OpKernel {
     33  public:
     34   explicit DataFormatDimMapOp(OpKernelConstruction* context)
     35       : OpKernel(context) {
     36     string src_format;
     37     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
     38     string dst_format;
     39     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
     40     OP_REQUIRES(
     41         context, src_format == "NHWC",
     42         errors::InvalidArgument(strings::StrCat(
     43             "Current implementation doesn't support source data format ",
     44             src_format)));
     45     OP_REQUIRES(context, dst_format == "NCHW",
     46                 errors::InvalidArgument(strings::StrCat(
     47                     "Current implementation doesn't support dst data format ",
     48                     dst_format)));
     49   }
     50 
     51   void Compute(OpKernelContext* context) override {
     52     const Tensor& input = context->input(0);
     53     Tensor* output = nullptr;
     54     OP_REQUIRES_OK(context,
     55                    context->allocate_output(0, input.shape(), &output));
     56     functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
     57                                            input.flat<T>(), output->flat<T>());
     58   }
     59 };
     60 
     61 template <typename Device, typename T>
     62 class DataFormatVecPermuteOp : public OpKernel {
     63  public:
     64   explicit DataFormatVecPermuteOp(OpKernelConstruction* context)
     65       : OpKernel(context) {
     66     string src_format;
     67     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
     68     string dst_format;
     69     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
     70     OP_REQUIRES(context,
     71                 (src_format == "NHWC" && dst_format == "NCHW") ||
     72                     (src_format == "NCHW" && dst_format == "NHWC"),
     73                 errors::InvalidArgument(strings::StrCat(
     74                     "Current implementation only supports NCHW-to-NHWC and "
     75                     "NHWC-to-NCHW format conversion; got source format ",
     76                     src_format, " and destination format ", dst_format)));
     77     nhwc_to_nchw_ = (src_format == "NHWC") ? true : false;
     78   }
     79 
     80   void Compute(OpKernelContext* context) override {
     81     const Tensor& input = context->input(0);
     82     OP_REQUIRES(context, input.dims() == 1 || input.dims() == 2,
     83                 errors::InvalidArgument(
     84                     "input must be a vector or 2D tensor, but got shape ",
     85                     input.shape().DebugString()));
     86     if (input.dims() == 1) {
     87       OP_REQUIRES(
     88           context, input.NumElements() == 4,
     89           errors::InvalidArgument("1D input must be of size 4, but got shape ",
     90                                   input.shape().DebugString()));
     91     } else if (input.dims() == 2) {
     92       OP_REQUIRES(
     93           context, input.dim_size(0) == 4,
     94           errors::InvalidArgument(
     95               "First dimension of 2D input must be of size 4, but got shape ",
     96               input.shape().DebugString()));
     97       OP_REQUIRES(
     98           context, input.dim_size(1) == 2,
     99           errors::InvalidArgument(
    100               "Second dimension of 2D input must be of size 2, but got shape ",
    101               input.shape().DebugString()));
    102     }
    103 
    104     Tensor* output = nullptr;
    105     OP_REQUIRES_OK(context,
    106                    context->allocate_output(0, input.shape(), &output));
    107     functor::DataFormatVecPermute<Device, T>()(
    108         context->eigen_device<Device>(), input.flat<T>(), output->flat<T>(),
    109         nhwc_to_nchw_);
    110   }
    111 
    112  private:
    113   bool nhwc_to_nchw_;
    114 };
    115 
    116 #define REGISTER_KERNEL(T)                                                \
    117   REGISTER_KERNEL_BUILDER(                                                \
    118       Name("DataFormatDimMap").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
    119       DataFormatDimMapOp<CPUDevice, T>);
    120 TF_CALL_int32(REGISTER_KERNEL);
    121 TF_CALL_int64(REGISTER_KERNEL);
    122 #undef REGISTER_KERNEL
    123 
    124 #define REGISTER_KERNEL(T)                                                    \
    125   REGISTER_KERNEL_BUILDER(                                                    \
    126       Name("DataFormatVecPermute").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
    127       DataFormatVecPermuteOp<CPUDevice, T>);
    128 TF_CALL_int32(REGISTER_KERNEL);
    129 TF_CALL_int64(REGISTER_KERNEL);
    130 #undef REGISTER_KERNEL
    131 
    132 #if GOOGLE_CUDA
    133 // Forward declarations of the functor specializations for GPU.
    134 namespace functor {
    135 #define DECLARE_GPU_SPEC(T)                                \
    136   template <>                                              \
    137   void DataFormatDimMap<GPUDevice, T>::operator()(         \
    138       const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
    139       typename TTypes<T>::Flat y);                         \
    140   extern template struct DataFormatDimMap<GPUDevice, T>;
    141 #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
    142 TF_CALL_int32(DECLARE_GPU_SPECS);
    143 TF_CALL_int64(DECLARE_GPU_SPECS);
    144 #undef DECLARE_GPU_SPEC
    145 
    146 #define DECLARE_GPU_SPEC(T)                                \
    147   template <>                                              \
    148   void DataFormatVecPermute<GPUDevice, T>::operator()(     \
    149       const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
    150       typename TTypes<T>::Vec y, bool nhwc_to_nchw);       \
    151   extern template struct DataFormatVecPermute<GPUDevice, T>;
    152 #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
    153 TF_CALL_int32(DECLARE_GPU_SPECS);
    154 TF_CALL_int64(DECLARE_GPU_SPECS);
    155 #undef DECLARE_GPU_SPEC
    156 }  // namespace functor
    157 
    158 // Registration of the GPU implementations.
    159 #define REGISTER_GPU_KERNEL(T)                                            \
    160   REGISTER_KERNEL_BUILDER(                                                \
    161       Name("DataFormatDimMap").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
    162       DataFormatDimMapOp<GPUDevice, T>);
    163 TF_CALL_int32(REGISTER_GPU_KERNEL);
    164 TF_CALL_int64(REGISTER_GPU_KERNEL);
    165 #undef REGISTER_GPU_KERNEL
    166 
    167 #define REGISTER_GPU_KERNEL(T)                                                \
    168   REGISTER_KERNEL_BUILDER(                                                    \
    169       Name("DataFormatVecPermute").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
    170       DataFormatVecPermuteOp<GPUDevice, T>);
    171 TF_CALL_int32(REGISTER_GPU_KERNEL);
    172 TF_CALL_int64(REGISTER_GPU_KERNEL);
    173 #undef REGISTER_GPU_KERNEL
    174 #endif  // GOOGLE_CUDA
    175 
    176 }  // namespace tensorflow
    177