1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/nn_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include "tensorflow/core/kernels/data_format_ops.h" 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/register_types.h" 24 #include "tensorflow/core/framework/tensor.h" 25 26 namespace tensorflow { 27 28 typedef Eigen::ThreadPoolDevice CPUDevice; 29 typedef Eigen::GpuDevice GPUDevice; 30 31 template <typename Device, typename T> 32 class DataFormatDimMapOp : public OpKernel { 33 public: 34 explicit DataFormatDimMapOp(OpKernelConstruction* context) 35 : OpKernel(context) { 36 string src_format; 37 OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); 38 string dst_format; 39 OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); 40 OP_REQUIRES( 41 context, src_format == "NHWC", 42 errors::InvalidArgument(strings::StrCat( 43 "Current implementation doesn't support source data format ", 44 src_format))); 45 OP_REQUIRES(context, dst_format == "NCHW", 46 errors::InvalidArgument(strings::StrCat( 47 "Current implementation doesn't support dst data format ", 48 dst_format))); 49 } 50 51 void Compute(OpKernelContext* context) override { 52 const Tensor& input = context->input(0); 53 Tensor* output = nullptr; 54 OP_REQUIRES_OK(context, 55 context->allocate_output(0, input.shape(), &output)); 56 functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(), 57 input.flat<T>(), output->flat<T>()); 58 } 59 }; 60 61 template <typename Device, typename T> 62 class DataFormatVecPermuteOp : public OpKernel { 63 public: 64 explicit DataFormatVecPermuteOp(OpKernelConstruction* context) 65 : OpKernel(context) { 66 string src_format; 67 OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); 68 string dst_format; 69 OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); 70 OP_REQUIRES(context, 71 (src_format == "NHWC" && dst_format == "NCHW") || 72 (src_format == "NCHW" && dst_format == "NHWC"), 73 errors::InvalidArgument(strings::StrCat( 74 "Current implementation only supports NCHW-to-NHWC and " 75 "NHWC-to-NCHW format conversion; got source format ", 76 src_format, " and destination format ", dst_format))); 77 nhwc_to_nchw_ = (src_format == "NHWC") ? true : false; 78 } 79 80 void Compute(OpKernelContext* context) override { 81 const Tensor& input = context->input(0); 82 OP_REQUIRES(context, input.dims() == 1 || input.dims() == 2, 83 errors::InvalidArgument( 84 "input must be a vector or 2D tensor, but got shape ", 85 input.shape().DebugString())); 86 if (input.dims() == 1) { 87 OP_REQUIRES( 88 context, input.NumElements() == 4, 89 errors::InvalidArgument("1D input must be of size 4, but got shape ", 90 input.shape().DebugString())); 91 } else if (input.dims() == 2) { 92 OP_REQUIRES( 93 context, input.dim_size(0) == 4, 94 errors::InvalidArgument( 95 "First dimension of 2D input must be of size 4, but got shape ", 96 input.shape().DebugString())); 97 OP_REQUIRES( 98 context, input.dim_size(1) == 2, 99 errors::InvalidArgument( 100 "Second dimension of 2D input must be of size 2, but got shape ", 101 input.shape().DebugString())); 102 } 103 104 Tensor* output = nullptr; 105 OP_REQUIRES_OK(context, 106 context->allocate_output(0, input.shape(), &output)); 107 functor::DataFormatVecPermute<Device, T>()( 108 context->eigen_device<Device>(), input.flat<T>(), output->flat<T>(), 109 nhwc_to_nchw_); 110 } 111 112 private: 113 bool nhwc_to_nchw_; 114 }; 115 116 #define REGISTER_KERNEL(T) \ 117 REGISTER_KERNEL_BUILDER( \ 118 Name("DataFormatDimMap").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 119 DataFormatDimMapOp<CPUDevice, T>); 120 TF_CALL_int32(REGISTER_KERNEL); 121 TF_CALL_int64(REGISTER_KERNEL); 122 #undef REGISTER_KERNEL 123 124 #define REGISTER_KERNEL(T) \ 125 REGISTER_KERNEL_BUILDER( \ 126 Name("DataFormatVecPermute").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 127 DataFormatVecPermuteOp<CPUDevice, T>); 128 TF_CALL_int32(REGISTER_KERNEL); 129 TF_CALL_int64(REGISTER_KERNEL); 130 #undef REGISTER_KERNEL 131 132 #if GOOGLE_CUDA 133 // Forward declarations of the functor specializations for GPU. 134 namespace functor { 135 #define DECLARE_GPU_SPEC(T) \ 136 template <> \ 137 void DataFormatDimMap<GPUDevice, T>::operator()( \ 138 const GPUDevice& d, typename TTypes<T>::ConstFlat x, \ 139 typename TTypes<T>::Flat y); \ 140 extern template struct DataFormatDimMap<GPUDevice, T>; 141 #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); 142 TF_CALL_int32(DECLARE_GPU_SPECS); 143 TF_CALL_int64(DECLARE_GPU_SPECS); 144 #undef DECLARE_GPU_SPEC 145 146 #define DECLARE_GPU_SPEC(T) \ 147 template <> \ 148 void DataFormatVecPermute<GPUDevice, T>::operator()( \ 149 const GPUDevice& d, typename TTypes<T>::ConstFlat x, \ 150 typename TTypes<T>::Vec y, bool nhwc_to_nchw); \ 151 extern template struct DataFormatVecPermute<GPUDevice, T>; 152 #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); 153 TF_CALL_int32(DECLARE_GPU_SPECS); 154 TF_CALL_int64(DECLARE_GPU_SPECS); 155 #undef DECLARE_GPU_SPEC 156 } // namespace functor 157 158 // Registration of the GPU implementations. 159 #define REGISTER_GPU_KERNEL(T) \ 160 REGISTER_KERNEL_BUILDER( \ 161 Name("DataFormatDimMap").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ 162 DataFormatDimMapOp<GPUDevice, T>); 163 TF_CALL_int32(REGISTER_GPU_KERNEL); 164 TF_CALL_int64(REGISTER_GPU_KERNEL); 165 #undef REGISTER_GPU_KERNEL 166 167 #define REGISTER_GPU_KERNEL(T) \ 168 REGISTER_KERNEL_BUILDER( \ 169 Name("DataFormatVecPermute").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ 170 DataFormatVecPermuteOp<GPUDevice, T>); 171 TF_CALL_int32(REGISTER_GPU_KERNEL); 172 TF_CALL_int64(REGISTER_GPU_KERNEL); 173 #undef REGISTER_GPU_KERNEL 174 #endif // GOOGLE_CUDA 175 176 } // namespace tensorflow 177