1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/type_util.h" 17 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 18 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 19 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 20 #include "tensorflow/compiler/xla/array4d.h" 21 #include "tensorflow/core/framework/kernel_def_builder.h" 22 #include "tensorflow/core/framework/register_types.h" 23 #include "tensorflow/core/lib/math/math_util.h" 24 25 namespace tensorflow { 26 namespace { 27 28 // We implement bilinear interpolation by upsampling followed by convolution. 29 // The basic idea is as follows. To scale from NxN to RxR: 30 // 31 // 1. S := (N - 1) / gcd(N-1, R-1) 32 // 2. k := (R - 1) / gcd(N-1, R-1) 33 // 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1) 34 // 35 // For example, to Scale from 7x7 -> 15x15: 36 // 37 // 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3 38 // 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7 39 // 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2) 40 // 41 // 42 // The 7x7 -> 15x15 case is much too large to write out in full as an 43 // example. The smallest interesting example is 3x3 -> 4x4. 44 // 45 // S := 2 46 // k := 3 47 // 48 // 00 03 06 00 00 00 00 00 00 00 00 00 00 00 00 02 04 06 49 // 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00 -> 06 08 10 12 50 // 18 21 24 00 00 00 00 00 03 00 00 06 00 00 12 14 16 18 51 // 00 00 00 00 00 00 00 00 00 00 00 18 20 22 24 52 // 00 00 00 00 00 00 00 00 00 00 00 53 // 00 00 09 00 00 12 00 00 15 00 00 54 // 00 00 00 00 00 00 00 00 00 00 00 55 // 00 00 00 00 00 00 00 00 00 00 00 56 // 00 00 18 00 00 21 00 00 24 00 00 57 // 00 00 00 00 00 00 00 00 00 00 00 58 // 00 00 00 00 00 00 00 00 00 00 00 59 // 60 // with the following convolutional kernel, with stride [2, 2]: 61 // 1 2 3 2 1 62 // 2 4 6 4 2 63 // 1/9 * 3 6 9 6 3 64 // 2 4 6 4 2 65 // 1 2 3 2 1 66 67 // Computes the size of the convolutional kernel and stride to use when resizing 68 // from in_size to out_size. 69 struct ResizeConvolutionDims { 70 // Size of the kernel to use. 71 std::vector<int64> kernel_size; 72 73 // Stride of the convolution to use. 74 std::vector<int64> stride; 75 }; 76 ResizeConvolutionDims ComputeResizeConvolutionParameters( 77 gtl::ArraySlice<int64> in_size, gtl::ArraySlice<int64> out_size) { 78 CHECK_EQ(in_size.size(), out_size.size()); 79 int num_spatial_dims = in_size.size(); 80 ResizeConvolutionDims dims; 81 dims.kernel_size.resize(num_spatial_dims); 82 dims.stride.resize(num_spatial_dims); 83 for (int i = 0; i < num_spatial_dims; ++i) { 84 if (in_size[i] == 1) { 85 // We must handle input size 1 specially because XLA convolution does 86 // not allow stride 0. 87 dims.stride[i] = dims.kernel_size[i] = 1; 88 } else if (out_size[i] == 1) { 89 // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first 90 // entry before resizing. 91 dims.stride[i] = dims.kernel_size[i] = 1; 92 } else { 93 int64 gcd = MathUtil::GCD(static_cast<uint64>(in_size[i] - 1), 94 static_cast<uint64>(out_size[i] - 1)); 95 dims.stride[i] = (in_size[i] - 1) / gcd; 96 dims.kernel_size[i] = (out_size[i] - 1) / gcd; 97 } 98 } 99 return dims; 100 } 101 102 xla::ComputationDataHandle MakeBilinearResizeKernel( 103 xla::ComputationBuilder* builder, gtl::ArraySlice<int64> kernel_size, 104 int64 channels) { 105 // Form a 2D convolution kernel like: 106 // 1 2 3 2 1 107 // 2 4 6 4 2 108 // 1/9 * 3 6 9 6 3 109 // 2 4 6 4 2 110 // 1 2 3 2 1 111 // by multiplying two 1D kernels of the form: 112 // 1/3 * [1 2 3 2 1] 113 auto make_1d_kernel = [](int64 n) { 114 std::vector<float> kernel(n * 2 - 1); 115 for (int64 i = 0; i < n; ++i) { 116 float v = (i + 1.0f) / n; 117 kernel[i] = v; 118 kernel[n * 2 - 2 - i] = v; 119 } 120 return kernel; 121 }; 122 123 xla::ComputationDataHandle channels_iota; 124 // DT_INT32 Iota will always return status::OK(). 125 TF_CHECK_OK( 126 XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); 127 128 auto diag = builder->ConvertElementType( 129 builder->Eq( 130 builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1, 131 2 * kernel_size[1] - 1, channels}), 132 channels_iota, /*broadcast_dimensions=*/{2}), 133 xla::PrimitiveType::F32); 134 return builder->Mul( 135 builder->Mul(diag, 136 builder->ConstantR1<float>(make_1d_kernel(kernel_size[1])), 137 /*broadcast_dimensions=*/{1}), 138 builder->ConstantR1<float>(make_1d_kernel(kernel_size[0])), 139 /*broadcast_dimensions=*/{0}); 140 } 141 142 xla::ComputationDataHandle ResizeUsingDilationAndConvolution( 143 xla::ComputationBuilder* builder, const xla::ComputationDataHandle& input, 144 const int num_spatial_dims, std::vector<int64> in_size, 145 std::vector<int64> out_size, const int64 channels) { 146 // Picture for a 1x3 to 1x4 resize: 147 // stride = 2, kernel size = 3 148 // Input: 149 // 3 6 9 150 // Input with dilation and padding: 151 // 0 0 3 0 0 6 0 0 9 0 0 152 // Convolution kernel: 153 // 1/3 * [1 2 3 2 1] 154 // Output: 155 // 3 5 7 9 156 xla::ConvolutionDimensionNumbers dimension_numbers; 157 dimension_numbers.set_input_batch_dimension(0); 158 dimension_numbers.set_output_batch_dimension(0); 159 dimension_numbers.set_input_feature_dimension(3); 160 dimension_numbers.set_output_feature_dimension(3); 161 for (int i = 0; i < num_spatial_dims; ++i) { 162 dimension_numbers.add_input_spatial_dimensions(1 + i); 163 dimension_numbers.add_output_spatial_dimensions(1 + i); 164 dimension_numbers.add_kernel_spatial_dimensions(i); 165 } 166 dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); 167 dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); 168 169 ResizeConvolutionDims dims = 170 ComputeResizeConvolutionParameters(in_size, out_size); 171 xla::ComputationDataHandle kernel = 172 MakeBilinearResizeKernel(builder, dims.kernel_size, channels); 173 xla::ComputationDataHandle output = builder->ConvGeneralDilated( 174 input, kernel, dims.stride, 175 /*padding=*/ 176 {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, 177 {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, 178 /*lhs_dilation=*/dims.kernel_size, 179 /*rhs_dilation=*/{1, 1}, dimension_numbers); 180 181 // Add broadcasts to handle expanding from a size == 1 dimension to a 182 // size > 1 dimension. 183 for (int i = 0; i < num_spatial_dims; ++i) { 184 if (in_size[i] == 1 && out_size[i] > 1) { 185 output = builder->Add(output, builder->ConstantR1<float>(out_size[i], 0), 186 /*broadcast_dimensions=*/{1 + i}); 187 } 188 } 189 return output; 190 } 191 192 xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp( 193 xla::ComputationBuilder* builder, const xla::ComputationDataHandle& grad, 194 const int num_spatial_dims, std::vector<int64> in_size, 195 std::vector<int64> grad_size, const int64 channels) { 196 ResizeConvolutionDims dims = 197 ComputeResizeConvolutionParameters(in_size, grad_size); 198 199 // To form the backward convolution, we keep the kernel unchanged (it is 200 // already symmetric) and swap the roles of strides and LHS dilation. 201 xla::ConvolutionDimensionNumbers dimension_numbers; 202 dimension_numbers.set_input_batch_dimension(0); 203 dimension_numbers.set_output_batch_dimension(0); 204 dimension_numbers.set_input_feature_dimension(3); 205 dimension_numbers.set_output_feature_dimension(3); 206 for (int i = 0; i < num_spatial_dims; ++i) { 207 dimension_numbers.add_input_spatial_dimensions(1 + i); 208 dimension_numbers.add_output_spatial_dimensions(1 + i); 209 dimension_numbers.add_kernel_spatial_dimensions(i); 210 } 211 dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); 212 dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); 213 xla::ComputationDataHandle kernel = 214 MakeBilinearResizeKernel(builder, dims.kernel_size, channels); 215 216 // Broadcast the input kernel where the forward op expanded from a size == 1 217 // dimension to a size > 1 dimension. This has the effect of summing the 218 // gradient contributions in that dimension. 219 for (int i = 0; i < num_spatial_dims; ++i) { 220 if (in_size[i] == 1 && grad_size[i] > 1) { 221 kernel = builder->Add(kernel, builder->ConstantR1<float>(grad_size[i], 0), 222 /*broadcast_dimensions=*/{i}); 223 } 224 } 225 226 xla::ComputationDataHandle output = builder->ConvGeneralDilated( 227 grad, kernel, /*window_strides=*/dims.kernel_size, 228 /*padding=*/ 229 {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, 230 {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, 231 /*lhs_dilation=*/dims.stride, 232 /*rhs_dilation=*/{1, 1}, dimension_numbers); 233 234 // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. 235 // Opposite of the slice performed by the forward op. 236 xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4); 237 bool pad_output = false; 238 for (int i = 0; i < num_spatial_dims; ++i) { 239 if (in_size[i] > 1 && grad_size[i] == 1) { 240 pad_output = true; 241 padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - 1); 242 } 243 } 244 if (pad_output) { 245 output = builder->Pad(output, builder->ConstantR0<float>(0.0f), padding); 246 } 247 return output; 248 } 249 250 class ResizeBilinearOp : public XlaOpKernel { 251 public: 252 explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 253 OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); 254 OP_REQUIRES( 255 ctx, align_corners_ == true, 256 errors::Unimplemented( 257 "ResizeBilinear with align_corners=False is not yet implemented")); 258 } 259 260 void Compile(XlaOpKernelContext* ctx) override { 261 xla::ComputationBuilder* b = ctx->builder(); 262 263 TensorShape input_shape = ctx->InputShape(0); 264 OP_REQUIRES(ctx, input_shape.dims() == 4, 265 errors::InvalidArgument("input must be 4-dimensional", 266 input_shape.DebugString())); 267 const int64 batch = input_shape.dim_size(0); 268 std::vector<int64> in_size = {input_shape.dim_size(1), 269 input_shape.dim_size(2)}; 270 const int64 channels = input_shape.dim_size(3); 271 OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, 272 errors::InvalidArgument("input size must be positive, got [", 273 in_size[0], ",", in_size[1], "]")); 274 275 std::vector<int64> out_size; 276 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); 277 OP_REQUIRES(ctx, out_size.size() == 2, 278 errors::InvalidArgument("output size must be length 2, got ", 279 out_size.size())); 280 OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, 281 errors::InvalidArgument("output size must be positive, got [", 282 out_size[0], ",", out_size[1], "]")); 283 284 const int num_spatial_dims = 2; 285 286 xla::ComputationDataHandle input = ctx->Input(0); 287 288 // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in 289 // dimension i. 290 std::vector<int64> slice_size = in_size; 291 bool slice_input = false; 292 for (int i = 0; i < num_spatial_dims; ++i) { 293 if (in_size[i] > 1 && out_size[i] == 1) { 294 // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first 295 // entry before resizing. 296 slice_input = true; 297 slice_size[i] = 1; 298 } 299 } 300 if (slice_input) { 301 input = b->Slice(input, {0, 0, 0, 0}, 302 {batch, slice_size[0], slice_size[1], channels}, 303 {1, 1, 1, 1}); 304 } 305 306 // Output is always type float. 307 input = b->ConvertElementType(input, xla::F32); 308 309 // Special Case: 310 // Instead of doing a ResizeUsingDilationAndConvolution directly, 311 // while (out_size[0]-1) = c * 2^x * (in_size[0]-1) for x>1 c>1, resize the 312 // image to 2*(in_size[0]-1)+1 x-times and then resize by scale c(int here). 313 // Instead of resizing directly we resize it iteratively. 314 // 315 // Since bilinear resize can be broken down as 2 sequential linear 316 // operations along different dimensions. 317 // Given sufficient numerical stability and a<e<c and b<f<d, bilinear resize 318 // from image of size axb -> cxd is same as resizing axb -> exf -> cxd. 319 // 320 // This makes the convolutions kernels smaller and the operation faster. 321 xla::ComputationDataHandle output = input; 322 while (in_size != out_size) { 323 if (in_size[0] != 1 && in_size[1] != 1) { 324 std::vector<float> k = { 325 (static_cast<float>(out_size[0]) - 1) / ((in_size[0] - 1) * 2), 326 (static_cast<float>(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; 327 if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && 328 k[0] > 1 && k[1] > 1) { 329 std::vector<int64> next_out_size = {(in_size[0] - 1) * 2 + 1, 330 (in_size[1] - 1) * 2 + 1}; 331 output = ResizeUsingDilationAndConvolution( 332 b, input, num_spatial_dims, in_size, next_out_size, channels); 333 input = output; 334 in_size = next_out_size; 335 } else { 336 output = ResizeUsingDilationAndConvolution( 337 b, input, num_spatial_dims, in_size, out_size, channels); 338 in_size = out_size; 339 } 340 } else { 341 output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, 342 in_size, out_size, channels); 343 in_size = out_size; 344 } 345 } 346 347 ctx->SetOutput(0, output); 348 } 349 350 private: 351 bool align_corners_; 352 }; 353 354 REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstInput("size"), 355 ResizeBilinearOp); 356 357 class ResizeBilinearGradOp : public XlaOpKernel { 358 public: 359 explicit ResizeBilinearGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 360 OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); 361 OP_REQUIRES( 362 ctx, align_corners_ == true, 363 errors::Unimplemented("ResizeBilinearGrad with align_corners=False is " 364 "not yet implemented")); 365 366 DataType output_dtype; 367 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); 368 OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_)); 369 } 370 371 void Compile(XlaOpKernelContext* ctx) override { 372 xla::ComputationBuilder* b = ctx->builder(); 373 374 TensorShape input_shape = ctx->InputShape(1); 375 OP_REQUIRES(ctx, input_shape.dims() == 4, 376 errors::InvalidArgument("input must be 4-dimensional", 377 input_shape.DebugString())); 378 const int64 batch = input_shape.dim_size(0); 379 std::vector<int64> in_size = {input_shape.dim_size(1), 380 input_shape.dim_size(2)}; 381 const int64 channels = input_shape.dim_size(3); 382 OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, 383 errors::InvalidArgument("input size must be positive, got [", 384 in_size[0], ",", in_size[1], "]")); 385 386 TensorShape grad_shape = ctx->InputShape(0); 387 OP_REQUIRES(ctx, grad_shape.dims() == 4, 388 errors::InvalidArgument("gradient must be 4-dimensional", 389 grad_shape.DebugString())); 390 const int64 grad_batch = grad_shape.dim_size(0); 391 const std::vector<int64> grad_size = {grad_shape.dim_size(1), 392 grad_shape.dim_size(2)}; 393 const int64 grad_channels = grad_shape.dim_size(3); 394 OP_REQUIRES(ctx, batch == grad_batch, 395 errors::InvalidArgument( 396 "activations and gradients must have the same batch size (", 397 batch, " vs. ", grad_batch, ")")); 398 OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0, 399 errors::InvalidArgument("gradient size must be positive, got [", 400 grad_size[0], ",", grad_size[1], "]")); 401 OP_REQUIRES( 402 ctx, channels == grad_channels, 403 errors::InvalidArgument( 404 "activations and gradients must have the same number of channels (", 405 channels, " vs. ", grad_channels, ")")); 406 407 const int num_spatial_dims = 2; 408 409 xla::ComputationDataHandle grad = ctx->Input(0); 410 411 xla::ComputationDataHandle output = grad; 412 while (in_size != grad_size) { 413 if (in_size[0] != 1 && in_size[1] != 1) { 414 std::vector<float> k = { 415 (static_cast<float>(grad_size[0]) - 1) / ((in_size[0] - 1) * 2), 416 (static_cast<float>(grad_size[1]) - 1) / ((in_size[1] - 1) * 2)}; 417 if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && 418 k[0] > 1 && k[1] > 1) { 419 std::vector<int64> next_grad_size = {(in_size[0] - 1) * 2 + 1, 420 (in_size[1] - 1) * 2 + 1}; 421 output = ResizeUsingDilationAndConvolutionGradOp( 422 b, grad, num_spatial_dims, in_size, next_grad_size, channels); 423 grad = output; 424 in_size = next_grad_size; 425 } else { 426 output = ResizeUsingDilationAndConvolutionGradOp( 427 b, grad, num_spatial_dims, in_size, grad_size, channels); 428 in_size = grad_size; 429 } 430 } else { 431 output = ResizeUsingDilationAndConvolutionGradOp( 432 b, grad, num_spatial_dims, in_size, grad_size, channels); 433 in_size = grad_size; 434 } 435 } 436 437 output = b->ConvertElementType(output, output_type_); 438 ctx->SetOutput(0, output); 439 } 440 441 private: 442 bool align_corners_; 443 xla::PrimitiveType output_type_; 444 }; 445 446 REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp); 447 448 } // namespace 449 } // namespace tensorflow 450