1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/array_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include <memory> 21 #include <string> 22 #include <utility> 23 24 #include "tensorflow/core/kernels/spacetobatch_functor.h" 25 26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 27 #include "tensorflow/core/framework/op.h" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/framework/register_types.h" 30 #include "tensorflow/core/framework/tensor.h" 31 #include "tensorflow/core/framework/tensor_shape.h" 32 #include "tensorflow/core/framework/tensor_types.h" 33 #include "tensorflow/core/framework/types.h" 34 #include "tensorflow/core/platform/logging.h" 35 #include "tensorflow/core/platform/types.h" 36 37 namespace tensorflow { 38 39 typedef Eigen::ThreadPoolDevice CPUDevice; 40 typedef Eigen::GpuDevice GPUDevice; 41 42 template <typename Device, typename T> 43 static void BatchToSpaceOpCompute(OpKernelContext* context, 44 const Tensor& orig_input_tensor, 45 const Tensor& orig_block_shape, 46 const Tensor& orig_crops) { 47 const int input_dims = orig_input_tensor.dims(); 48 OP_REQUIRES( 49 context, TensorShapeUtils::IsVector(orig_block_shape.shape()), 50 errors::InvalidArgument("block_shape rank should be 1 instead of ", 51 orig_block_shape.dims())); 52 53 const int block_dims = orig_block_shape.dim_size(0); 54 OP_REQUIRES( 55 context, orig_input_tensor.dims() >= 1 + block_dims, 56 errors::InvalidArgument("input rank should be >= ", 1 + block_dims, 57 " instead of ", orig_input_tensor.dims())); 58 59 OP_REQUIRES(context, 60 TensorShapeUtils::IsMatrix(orig_crops.shape()) && 61 block_dims == orig_crops.dim_size(0) && 62 2 == orig_crops.dim_size(1), 63 errors::InvalidArgument("crops should have shape [", block_dims, 64 ", 2] instead of ", 65 orig_crops.shape().DebugString())); 66 // To avoid out-of-bounds access in the case that the block_shape and/or 67 // crops tensors are concurrently modified, we must copy the values. 68 gtl::InlinedVector<int64, 4> block_shape; 69 gtl::InlinedVector<int64, 8> crops; 70 internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape); 71 internal::spacetobatch::SubtleMustCopyFlat(orig_crops, &crops); 72 73 // Determine the length of the prefix of block dims that can be combined 74 // into the batch dimension due to having no padding and block_shape=1. 75 int removed_prefix_block_dims = 0; 76 for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) { 77 const int dim = removed_prefix_block_dims; 78 if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 || 79 block_shape[dim] != 1) { 80 break; 81 } 82 } 83 84 // Determine the length of the suffix of block dims that can be combined 85 // into the depth dimension due to having no padding and block_shape=1. 86 int removed_suffix_block_dims = 0; 87 for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims; 88 ++removed_suffix_block_dims) { 89 const int dim = block_dims - 1 - removed_suffix_block_dims; 90 if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 || 91 block_shape[dim] != 1) { 92 break; 93 } 94 } 95 96 // Compute the product of the block_shape values. 97 int64 block_shape_product = 1; 98 for (int block_dim = 0; block_dim < block_dims; ++block_dim) { 99 block_shape_product *= block_shape[block_dim]; 100 } 101 OP_REQUIRES( 102 context, block_shape_product > 0, 103 errors::InvalidArgument("Product of block sizes must be positive, got ", 104 block_shape_product)); 105 106 const int64 orig_input_batch_size = orig_input_tensor.dim_size(0); 107 OP_REQUIRES( 108 context, orig_input_batch_size % block_shape_product == 0, 109 errors::InvalidArgument("Input batch dimension (", orig_input_batch_size, 110 ") is not divisible by product of block sizes (", 111 block_shape_product, ")")); 112 113 const int internal_block_dims = 114 block_dims - removed_prefix_block_dims - removed_suffix_block_dims; 115 OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims, 116 errors::InvalidArgument( 117 "Maximum number of non-combined block dimensions is ", 118 internal_block_dims, " but must not exceed ", 119 kMaxSpaceToBatchBlockDims)); 120 121 if (internal_block_dims == 0) { 122 context->set_output(0, orig_input_tensor); 123 return; 124 } 125 126 // For the purpose of computing the result, the input will be treated as 127 // having this shape, of rank 2 + internal_block_dims. 128 TensorShape internal_input_shape; 129 130 // For the purpose of computing the result, the output will be treated as 131 // having this shape, of rank 2 + internal_block_dims. 132 TensorShape internal_output_shape; 133 134 // The actual output shape exposed to callers. 135 TensorShape external_output_shape; 136 137 external_output_shape.AddDim(orig_input_batch_size / block_shape_product); 138 139 int64 input_batch_size = orig_input_batch_size; 140 for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) { 141 const int64 size = orig_input_tensor.dim_size(block_dim + 1); 142 input_batch_size *= size; 143 external_output_shape.AddDim(size); 144 } 145 internal_input_shape.AddDim(input_batch_size); 146 internal_output_shape.AddDim(input_batch_size / block_shape_product); 147 148 for (int block_dim = removed_prefix_block_dims; 149 block_dim < block_dims - removed_suffix_block_dims; ++block_dim) { 150 const int64 crop_start = crops[2 * block_dim], 151 crop_end = crops[2 * block_dim + 1]; 152 OP_REQUIRES(context, crop_start >= 0 && crop_end >= 0, 153 errors::InvalidArgument("Crops must be non-negative")); 154 const int64 input_size = orig_input_tensor.dim_size(block_dim + 1); 155 const int64 block_shape_value = block_shape[block_dim]; 156 const int64 cropped_size = 157 input_size * block_shape_value - crop_start - crop_end; 158 OP_REQUIRES(context, cropped_size >= 0, 159 errors::InvalidArgument("cropped_shape[", block_dim, "]=", 160 cropped_size, " must be non-negative")); 161 internal_input_shape.AddDim(input_size); 162 internal_output_shape.AddDim(cropped_size); 163 external_output_shape.AddDim(cropped_size); 164 } 165 166 int64 depth = 1; 167 for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims; 168 ++dim) { 169 const int64 size = orig_input_tensor.dim_size(dim); 170 external_output_shape.AddDim(size); 171 depth *= size; 172 } 173 internal_input_shape.AddDim(depth); 174 internal_output_shape.AddDim(depth); 175 176 // Allocate output tensor. 177 Tensor* output_tensor = nullptr; 178 OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape, 179 &output_tensor)); 180 181 const int64* internal_crops = &crops[2 * removed_prefix_block_dims]; 182 const int64* internal_block_shape = &block_shape[removed_prefix_block_dims]; 183 184 switch (internal_block_dims) { 185 #define TF_BATCHTOSPACE_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS) \ 186 case NUM_BLOCK_DIMS: { \ 187 OP_REQUIRES_OK( \ 188 context, \ 189 (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, true>()( \ 190 context->eigen_device<Device>(), \ 191 output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>( \ 192 internal_output_shape.dim_sizes()), \ 193 internal_block_shape, internal_crops, \ 194 orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>( \ 195 internal_input_shape.dim_sizes())))); \ 196 } break; \ 197 /**/ 198 TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_BATCHTOSPACE_BLOCK_DIMS_CASE) 199 #undef TF_BATCHTOSPACE_BLOCK_DIMS_CASE 200 } 201 } 202 203 template <typename Device, typename T> 204 class BatchToSpaceNDOp : public OpKernel { 205 public: 206 explicit BatchToSpaceNDOp(OpKernelConstruction* context) 207 : OpKernel(context) {} 208 209 void Compute(OpKernelContext* context) override { 210 const Tensor& orig_input_tensor = context->input(0); 211 const Tensor& orig_block_shape = context->input(1); 212 const Tensor& orig_crops = context->input(2); 213 BatchToSpaceOpCompute<Device, T>(context, orig_input_tensor, 214 orig_block_shape, orig_crops); 215 } 216 }; 217 218 template <typename Device, typename T> 219 class BatchToSpaceOp : public OpKernel { 220 public: 221 explicit BatchToSpaceOp(OpKernelConstruction* context) : OpKernel(context) { 222 OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_)); 223 OP_REQUIRES( 224 context, block_size_ > 1, 225 errors::InvalidArgument("Block size should be > 1: ", block_size_)); 226 // We don't use context->allocate_persistent because the allocation must 227 // happen on the CPU regardless of Device. 228 block_shape_ = Tensor(tensorflow::DT_INT64, TensorShape({2})); 229 auto block_shape_vec = block_shape_.vec<int64>(); 230 block_shape_vec(0) = block_size_; 231 block_shape_vec(1) = block_size_; 232 } 233 234 void Compute(OpKernelContext* context) override { 235 const Tensor& in0 = context->input(0); 236 const Tensor& in1 = context->input(1); 237 const int dims = in0.dims(); 238 239 // Check on the input dimensions first. 240 // The input is presumed to be [batch, height, width, depth] 241 static const int kRequiredDims = 4; 242 OP_REQUIRES(context, kRequiredDims == dims, 243 errors::InvalidArgument("Input rank should be: ", kRequiredDims, 244 "instead of: ", dims)); 245 BatchToSpaceOpCompute<Device, T>(context, in0, block_shape_, in1); 246 } 247 248 private: 249 int block_size_; 250 Tensor block_shape_; 251 }; 252 253 #define REGISTER(T) \ 254 REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND") \ 255 .Device(DEVICE_CPU) \ 256 .TypeConstraint<T>("T") \ 257 .HostMemory("block_shape") \ 258 .HostMemory("crops"), \ 259 BatchToSpaceNDOp<CPUDevice, T>); \ 260 REGISTER_KERNEL_BUILDER(Name("BatchToSpace") \ 261 .Device(DEVICE_CPU) \ 262 .TypeConstraint<T>("T") \ 263 .HostMemory("crops"), \ 264 BatchToSpaceOp<CPUDevice, T>); 265 266 TF_CALL_REAL_NUMBER_TYPES(REGISTER); 267 #undef REGISTER 268 269 #if GOOGLE_CUDA 270 #define REGISTER(T) \ 271 REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND") \ 272 .Device(DEVICE_GPU) \ 273 .TypeConstraint<T>("T") \ 274 .HostMemory("block_shape") \ 275 .HostMemory("crops"), \ 276 BatchToSpaceNDOp<GPUDevice, T>); \ 277 REGISTER_KERNEL_BUILDER(Name("BatchToSpace") \ 278 .Device(DEVICE_GPU) \ 279 .TypeConstraint<T>("T") \ 280 .HostMemory("crops"), \ 281 BatchToSpaceOp<GPUDevice, T>); 282 283 TF_CALL_GPU_NUMBER_TYPES(REGISTER); 284 #undef REGISTER 285 #endif // GOOGLE_CUDA 286 287 } // end namespace tensorflow 288