1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/image_ops.cc 17 #define EIGEN_USE_THREADS 18 19 #include "tensorflow/core/kernels/resize_nearest_neighbor_op.h" 20 21 #include <memory> 22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/kernels/image_resizer_state.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/platform/logging.h" 31 32 namespace tensorflow { 33 34 typedef Eigen::ThreadPoolDevice CPUDevice; 35 typedef Eigen::GpuDevice GPUDevice; 36 37 template <typename Device, typename T> 38 class ResizeNearestNeighborOp : public OpKernel { 39 public: 40 explicit ResizeNearestNeighborOp(OpKernelConstruction* context) 41 : OpKernel(context) { 42 OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_)); 43 } 44 45 void Compute(OpKernelContext* context) override { 46 const Tensor& input = context->input(0); 47 ImageResizerState st(align_corners_); 48 st.ValidateAndCreateOutput(context, input); 49 50 if (!context->status().ok()) return; 51 52 OP_REQUIRES(context, st.in_height < (1 << 24) && st.in_width < (1 << 24), 53 errors::InvalidArgument("nearest neighbor requires max height " 54 "& width of 2^24")); 55 56 // Return if the output is empty. 57 if (st.output->NumElements() == 0) return; 58 59 typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>()); 60 typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>()); 61 62 bool status; 63 if (align_corners_) { 64 status = 65 functor::ResizeNearestNeighbor<Device, T, /*align_corners=*/true>()( 66 context->eigen_device<Device>(), input_data, st.height_scale, 67 st.width_scale, output_data); 68 } else { 69 status = 70 functor::ResizeNearestNeighbor<Device, T, /*align_corners=*/false>()( 71 context->eigen_device<Device>(), input_data, st.height_scale, 72 st.width_scale, output_data); 73 } 74 if (!status) { 75 context->SetStatus( 76 errors::Internal("Failed launching ResizeNearestNeighbor")); 77 } 78 } 79 80 private: 81 bool align_corners_; 82 }; 83 84 // Partial specialization of ResizeNearestNeighbor functor for a CPUDevice. 85 namespace functor { 86 template <typename T, bool align_corners> 87 struct ResizeNearestNeighbor<CPUDevice, T, align_corners> { 88 bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input, 89 const float height_scale, const float width_scale, 90 typename TTypes<T, 4>::Tensor output) { 91 const int batch_size = input.dimension(0); 92 const int64 in_height = input.dimension(1); 93 const int64 in_width = input.dimension(2); 94 const int channels = input.dimension(3); 95 96 const int64 out_height = output.dimension(1); 97 const int64 out_width = output.dimension(2); 98 99 for (int b = 0; b < batch_size; ++b) { 100 for (int y = 0; y < out_height; ++y) { 101 const int64 in_y = std::min( 102 (align_corners) ? static_cast<int64>(roundf(y * height_scale)) 103 : static_cast<int64>(floorf(y * height_scale)), 104 in_height - 1); 105 for (int x = 0; x < out_width; ++x) { 106 const int64 in_x = std::min( 107 (align_corners) ? static_cast<int64>(roundf(x * width_scale)) 108 : static_cast<int64>(floorf(x * width_scale)), 109 in_width - 1); 110 std::copy_n(&input(b, in_y, in_x, 0), channels, &output(b, y, x, 0)); 111 } 112 } 113 } 114 return true; 115 } 116 }; 117 } // namespace functor 118 119 template <typename Device, typename T> 120 class ResizeNearestNeighborOpGrad : public OpKernel { 121 public: 122 explicit ResizeNearestNeighborOpGrad(OpKernelConstruction* context) 123 : OpKernel(context) { 124 OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_)); 125 } 126 127 void Compute(OpKernelContext* context) override { 128 // Grab and validate the input: 129 const Tensor& input = context->input(0); 130 OP_REQUIRES(context, input.dims() == 4, 131 errors::InvalidArgument("input must be 4-dimensional", 132 input.shape().DebugString())); 133 134 // Grab and validate the output shape: 135 const Tensor& shape_t = context->input(1); 136 OP_REQUIRES(context, shape_t.dims() == 1, 137 errors::InvalidArgument("shape_t must be 1-dimensional", 138 shape_t.shape().DebugString())); 139 OP_REQUIRES(context, shape_t.NumElements() == 2, 140 errors::InvalidArgument("shape_t must have two elements", 141 shape_t.shape().DebugString())); 142 143 auto sizes = shape_t.vec<int32>(); 144 OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0, 145 errors::InvalidArgument("shape_t's elements must be positive")); 146 147 const int64 batch_size = input.dim_size(0); 148 const int64 in_height = input.dim_size(1); 149 const int64 in_width = input.dim_size(2); 150 const int64 channels = input.dim_size(3); 151 152 const int64 out_height = sizes(0); 153 const int64 out_width = sizes(1); 154 155 Tensor* output = nullptr; 156 OP_REQUIRES_OK( 157 context, 158 context->allocate_output( 159 0, TensorShape({batch_size, out_height, out_width, channels}), 160 &output)); 161 162 // Return if the output is empty. 163 if (output->NumElements() == 0) return; 164 165 typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>()); 166 typename TTypes<T, 4>::Tensor output_data(output->tensor<T, 4>()); 167 168 const float height_scale = 169 CalculateResizeScale(out_height, in_height, align_corners_); 170 const float width_scale = 171 CalculateResizeScale(out_width, in_width, align_corners_); 172 173 bool status; 174 if (align_corners_) { 175 status = functor::ResizeNearestNeighborGrad<Device, T, 176 /*align_corners=*/true>()( 177 context->eigen_device<Device>(), input_data, height_scale, 178 width_scale, output_data); 179 } else { 180 status = functor::ResizeNearestNeighborGrad<Device, T, 181 /*align_corners=*/false>()( 182 context->eigen_device<Device>(), input_data, height_scale, 183 width_scale, output_data); 184 } 185 if (!status) { 186 context->SetStatus( 187 errors::Internal("Failed launching ResizeNearestNeighborGrad")); 188 } 189 } 190 191 private: 192 bool align_corners_; 193 }; 194 195 // Partial specialization of ResizeNearestNeighborGrad functor for a CPUDevice. 196 namespace functor { 197 template <typename T, bool align_corners> 198 struct ResizeNearestNeighborGrad<CPUDevice, T, align_corners> { 199 bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input, 200 const float height_scale, const float width_scale, 201 typename TTypes<T, 4>::Tensor output) { 202 const int batch_size = input.dimension(0); 203 const int64 in_height = input.dimension(1); 204 const int64 in_width = input.dimension(2); 205 const int channels = input.dimension(3); 206 207 const int64 out_height = output.dimension(1); 208 const int64 out_width = output.dimension(2); 209 210 output.setZero(); 211 212 for (int y = 0; y < in_height; ++y) { 213 const int64 out_y = std::min( 214 (align_corners) ? static_cast<int64>(roundf(y * height_scale)) 215 : static_cast<int64>(floorf(y * height_scale)), 216 out_height - 1); 217 for (int x = 0; x < in_width; ++x) { 218 const int64 out_x = std::min( 219 (align_corners) ? static_cast<int64>(roundf(x * width_scale)) 220 : static_cast<int64>(floorf(x * width_scale)), 221 out_width - 1); 222 for (int b = 0; b < batch_size; ++b) { 223 for (int c = 0; c < channels; ++c) { 224 output(b, out_y, out_x, c) += input(b, y, x, c); 225 } 226 } 227 } 228 } 229 return true; 230 } 231 }; 232 } // namespace functor 233 234 #define REGISTER_KERNEL(T) \ 235 REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor") \ 236 .Device(DEVICE_CPU) \ 237 .TypeConstraint<T>("T") \ 238 .HostMemory("size"), \ 239 ResizeNearestNeighborOp<CPUDevice, T>); \ 240 REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighborGrad") \ 241 .Device(DEVICE_CPU) \ 242 .TypeConstraint<T>("T") \ 243 .HostMemory("size"), \ 244 ResizeNearestNeighborOpGrad<CPUDevice, T>); 245 246 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); 247 248 #undef REGISTER_KERNEL 249 250 #if GOOGLE_CUDA 251 252 #define REGISTER_KERNEL(T) \ 253 REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor") \ 254 .Device(DEVICE_GPU) \ 255 .TypeConstraint<T>("T") \ 256 .HostMemory("size"), \ 257 ResizeNearestNeighborOp<GPUDevice, T>); \ 258 REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighborGrad") \ 259 .Device(DEVICE_GPU) \ 260 .TypeConstraint<T>("T") \ 261 .HostMemory("size"), \ 262 ResizeNearestNeighborOpGrad<GPUDevice, T>); 263 264 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL); 265 266 #undef REGISTER_KERNEL 267 268 #endif // GOOGLE_CUDA 269 270 } // namespace tensorflow 271