/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {
namespace {

// We implement bilinear interpolation by upsampling followed by convolution.
// The basic idea is as follows. To scale from NxN to RxR:
//
//    1. S := (N - 1) / gcd(N-1, R-1)
//    2. k := (R - 1) / gcd(N-1, R-1)
//    3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1)
//
// For example, to scale from 7x7 -> 15x15:
//
//    1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3
//    2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7
//    3. Convolution(13x13, stride=3, lhs_dilation=7, padding=6)
//
// The 7x7 -> 15x15 case is much too large to write out in full as an
// example. The smallest interesting example is 3x3 -> 4x4.
//
// S := 2
// k := 3
//
// 00 03 06    00 00 00 00 00 00 00 00 00 00 00      00 02 04 06
// 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00   -> 06 08 10 12
// 18 21 24    00 00 00 00 00 03 00 00 06 00 00      12 14 16 18
//             00 00 00 00 00 00 00 00 00 00 00      18 20 22 24
//             00 00 00 00 00 00 00 00 00 00 00
//             00 00 09 00 00 12 00 00 15 00 00
//             00 00 00 00 00 00 00 00 00 00 00
//             00 00 00 00 00 00 00 00 00 00 00
//             00 00 18 00 00 21 00 00 24 00 00
//             00 00 00 00 00 00 00 00 00 00 00
//             00 00 00 00 00 00 00 00 00 00 00
//
// with the following convolutional kernel, with stride [2, 2]:
//       1 2 3 2 1
//       2 4 6 4 2
// 1/9 * 3 6 9 6 3
//       2 4 6 4 2
//       1 2 3 2 1
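//
// As a sanity check, take the output entry 08 at position (1, 1): with
// stride 2 it covers rows 2..6 and columns 2..6 of the upsampled array, so
// its value is 03 * 2/9 + 09 * 2/9 + 12 * 4/9 = 2/3 + 2 + 16/3 = 8.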

// Computes the size of the convolutional kernel and stride to use when resizing
// from in_size to out_size.
struct ResizeConvolutionDims {
  // Size of the kernel to use.
  std::vector<int64> kernel_size;

  // Stride of the convolution to use.
  std::vector<int64> stride;
};
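// For example, for the 3x3 -> 4x4 resize pictured above: in_size = {3, 3},
// out_size = {4, 4}, gcd(3-1, 4-1) = 1, so stride = {2, 2} and
// kernel_size = {3, 3}.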
ResizeConvolutionDims ComputeResizeConvolutionParameters(
    gtl::ArraySlice<int64> in_size, gtl::ArraySlice<int64> out_size) {
  CHECK_EQ(in_size.size(), out_size.size());
  int num_spatial_dims = in_size.size();
  ResizeConvolutionDims dims;
  dims.kernel_size.resize(num_spatial_dims);
  dims.stride.resize(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] == 1) {
      // We must handle input size 1 specially because XLA convolution does
      // not allow stride 0.
      dims.stride[i] = dims.kernel_size[i] = 1;
    } else if (out_size[i] == 1) {
      // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
      // entry before resizing.
      dims.stride[i] = dims.kernel_size[i] = 1;
    } else {
      int64 gcd = MathUtil::GCD(static_cast<uint64>(in_size[i] - 1),
                                static_cast<uint64>(out_size[i] - 1));
      dims.stride[i] = (in_size[i] - 1) / gcd;
      dims.kernel_size[i] = (out_size[i] - 1) / gcd;
    }
  }
  return dims;
}

xla::ComputationDataHandle MakeBilinearResizeKernel(
    xla::ComputationBuilder* builder, gtl::ArraySlice<int64> kernel_size,
    int64 channels) {
  // Form a 2D convolution kernel like:
  //       1 2 3 2 1
  //       2 4 6 4 2
  // 1/9 * 3 6 9 6 3
  //       2 4 6 4 2
  //       1 2 3 2 1
  // by multiplying two 1D kernels of the form:
  // 1/3 * [1 2 3 2 1]
  auto make_1d_kernel = [](int64 n) {
    std::vector<float> kernel(n * 2 - 1);
    for (int64 i = 0; i < n; ++i) {
      float v = (i + 1.0f) / n;
      kernel[i] = v;
      kernel[n * 2 - 2 - i] = v;
    }
    return kernel;
  };
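  // For example, make_1d_kernel(3) evaluates to {1/3, 2/3, 1, 2/3, 1/3},
  // i.e. the 1/3 * [1 2 3 2 1] kernel pictured above.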

  xla::ComputationDataHandle channels_iota;
  // DT_INT32 Iota will always return Status::OK().
  TF_CHECK_OK(
      XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));

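  // The kernel built below has shape
  // [2 * kernel_size[0] - 1, 2 * kernel_size[1] - 1, channels, channels] and
  // is diagonal across the two feature dimensions, so each channel is
  // resized independently of the others.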
  auto diag = builder->ConvertElementType(
      builder->Eq(
          builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1,
                                             2 * kernel_size[1] - 1, channels}),
          channels_iota, /*broadcast_dimensions=*/{2}),
      xla::PrimitiveType::F32);
  return builder->Mul(
      builder->Mul(diag,
                   builder->ConstantR1<float>(make_1d_kernel(kernel_size[1])),
                   /*broadcast_dimensions=*/{1}),
      builder->ConstantR1<float>(make_1d_kernel(kernel_size[0])),
      /*broadcast_dimensions=*/{0});
}

xla::ComputationDataHandle ResizeUsingDilationAndConvolution(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& input,
    const int num_spatial_dims, std::vector<int64> in_size,
    std::vector<int64> out_size, const int64 channels) {
  // Picture for a 1x3 to 1x4 resize:
  // stride = 2, kernel size = 3
  // Input:
  // 3 6 9
  // Input with dilation and padding:
  // 0 0 3 0 0 6 0 0 9 0 0
  // Convolution kernel:
  // 1/3 * [1 2 3 2 1]
  // Output:
  // 3 5 7 9
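  //
  // For example, output[1] is the window [3 0 0 6 0] of the dilated input
  // dotted with the kernel: (1*3 + 2*0 + 3*0 + 2*6 + 1*0) / 3 = 5.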
  xla::ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(0);
  dimension_numbers.set_output_batch_dimension(0);
  dimension_numbers.set_input_feature_dimension(3);
  dimension_numbers.set_output_feature_dimension(3);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(1 + i);
    dimension_numbers.add_output_spatial_dimensions(1 + i);
    dimension_numbers.add_kernel_spatial_dimensions(i);
  }
  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);

  ResizeConvolutionDims dims =
      ComputeResizeConvolutionParameters(in_size, out_size);
  xla::ComputationDataHandle kernel =
      MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
  xla::ComputationDataHandle output = builder->ConvGeneralDilated(
      input, kernel, dims.stride,
      /*padding=*/
      {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
       {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
      /*lhs_dilation=*/dims.kernel_size,
      /*rhs_dilation=*/{1, 1}, dimension_numbers);

  // Add broadcasts to handle expanding from a size == 1 dimension to a
  // size > 1 dimension.
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] == 1 && out_size[i] > 1) {
      output = builder->Add(output, builder->ConstantR1<float>(out_size[i], 0),
                            /*broadcast_dimensions=*/{1 + i});
    }
  }
  return output;
}

xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& grad,
    const int num_spatial_dims, std::vector<int64> in_size,
    std::vector<int64> grad_size, const int64 channels) {
  ResizeConvolutionDims dims =
      ComputeResizeConvolutionParameters(in_size, grad_size);

  // To form the backward convolution, we keep the kernel unchanged (it is
  // already symmetric) and swap the roles of strides and LHS dilation.
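  // This is the usual transposed-convolution identity: the adjoint of a
  // convolution with window stride S and input dilation k is a convolution
  // with window stride k, input dilation S, and a spatially reversed kernel;
  // the reversal is a no-op here because the bilinear kernel is symmetric.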
  xla::ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(0);
  dimension_numbers.set_output_batch_dimension(0);
  dimension_numbers.set_input_feature_dimension(3);
  dimension_numbers.set_output_feature_dimension(3);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(1 + i);
    dimension_numbers.add_output_spatial_dimensions(1 + i);
    dimension_numbers.add_kernel_spatial_dimensions(i);
  }
  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
  xla::ComputationDataHandle kernel =
      MakeBilinearResizeKernel(builder, dims.kernel_size, channels);

  // Broadcast the input kernel where the forward op expanded from a size == 1
  // dimension to a size > 1 dimension. This has the effect of summing the
  // gradient contributions in that dimension.
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] == 1 && grad_size[i] > 1) {
      kernel = builder->Add(kernel, builder->ConstantR1<float>(grad_size[i], 0),
                            /*broadcast_dimensions=*/{i});
    }
  }

  xla::ComputationDataHandle output = builder->ConvGeneralDilated(
      grad, kernel, /*window_strides=*/dims.kernel_size,
      /*padding=*/
      {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
       {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
      /*lhs_dilation=*/dims.stride,
      /*rhs_dilation=*/{1, 1}, dimension_numbers);

  // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
  // Opposite of the slice performed by the forward op.
  xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4);
  bool pad_output = false;
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] > 1 && grad_size[i] == 1) {
      pad_output = true;
      padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - 1);
    }
  }
  if (pad_output) {
    output = builder->Pad(output, builder->ConstantR0<float>(0.0f), padding);
  }
  return output;
}

class ResizeBilinearOp : public XlaOpKernel {
 public:
  explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
    OP_REQUIRES(
        ctx, align_corners_ == true,
        errors::Unimplemented(
            "ResizeBilinear with align_corners=False is not yet implemented"));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();

    TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, input_shape.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input_shape.DebugString()));
    const int64 batch = input_shape.dim_size(0);
    std::vector<int64> in_size = {input_shape.dim_size(1),
                                  input_shape.dim_size(2)};
    const int64 channels = input_shape.dim_size(3);
    OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
                errors::InvalidArgument("input size must be positive, got [",
                                        in_size[0], ",", in_size[1], "]"));

    std::vector<int64> out_size;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size));
    OP_REQUIRES(ctx, out_size.size() == 2,
                errors::InvalidArgument("output size must be length 2, got ",
                                        out_size.size()));
    OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0,
                errors::InvalidArgument("output size must be positive, got [",
                                        out_size[0], ",", out_size[1], "]"));

    const int num_spatial_dims = 2;

    xla::ComputationDataHandle input = ctx->Input(0);

    // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in
    // dimension i.
    std::vector<int64> slice_size = in_size;
    bool slice_input = false;
    for (int i = 0; i < num_spatial_dims; ++i) {
      if (in_size[i] > 1 && out_size[i] == 1) {
        // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
        // entry before resizing.
        slice_input = true;
        slice_size[i] = 1;
      }
    }
    if (slice_input) {
      input = b->Slice(input, {0, 0, 0, 0},
                       {batch, slice_size[0], slice_size[1], channels},
                       {1, 1, 1, 1});
    }

    // Output is always type float.
    input = b->ConvertElementType(input, xla::F32);

    // Special case:
    // If (out_size[0] - 1) = c * 2^x * (in_size[0] - 1) for integers c > 1
    // and x > 1 (and similarly for dimension 1), we do not resize in a single
    // ResizeUsingDilationAndConvolution. Instead we resize the image to
    // 2 * (in_size[0] - 1) + 1 repeatedly, x times, and then resize by the
    // integer scale c. That is, instead of resizing directly we resize
    // iteratively.
    //
    // This is possible because a bilinear resize can be broken down into two
    // sequential linear operations along different dimensions: given
    // sufficient numerical stability, and a < e < c and b < f < d, resizing
    // an image of size axb to cxd is the same as resizing axb -> exf -> cxd.
    //
    // Resizing iteratively keeps the convolution kernels small, which makes
    // the operation faster.
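    //
    // For example, a 3x3 -> 9x9 resize goes 3x3 -> 5x5 first, since
    // k = (9-1) / ((3-1)*2) = 2 is an integer greater than 1, and then
    // 5x5 -> 9x9 in one step, since (9-1) / ((5-1)*2) = 1.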
    xla::ComputationDataHandle output = input;
    while (in_size != out_size) {
      if (in_size[0] != 1 && in_size[1] != 1) {
        std::vector<float> k = {
            (static_cast<float>(out_size[0]) - 1) / ((in_size[0] - 1) * 2),
            (static_cast<float>(out_size[1]) - 1) / ((in_size[1] - 1) * 2)};
        if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
            k[0] > 1 && k[1] > 1) {
          std::vector<int64> next_out_size = {(in_size[0] - 1) * 2 + 1,
                                              (in_size[1] - 1) * 2 + 1};
          output = ResizeUsingDilationAndConvolution(
              b, input, num_spatial_dims, in_size, next_out_size, channels);
          input = output;
          in_size = next_out_size;
        } else {
          output = ResizeUsingDilationAndConvolution(
              b, input, num_spatial_dims, in_size, out_size, channels);
          in_size = out_size;
        }
      } else {
        output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims,
                                                   in_size, out_size, channels);
        in_size = out_size;
      }
    }

    ctx->SetOutput(0, output);
  }

 private:
  bool align_corners_;
};

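// The "size" input must be a compile-time constant because the kernel and
// stride shapes computed above depend on it. This is the op produced by,
// e.g., tf.image.resize_bilinear with align_corners=True.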
REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstInput("size"),
                ResizeBilinearOp);

class ResizeBilinearGradOp : public XlaOpKernel {
 public:
  explicit ResizeBilinearGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
    OP_REQUIRES(
        ctx, align_corners_ == true,
        errors::Unimplemented("ResizeBilinearGrad with align_corners=False is "
                              "not yet implemented"));

    DataType output_dtype;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();

    TensorShape input_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, input_shape.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input_shape.DebugString()));
    const int64 batch = input_shape.dim_size(0);
    std::vector<int64> in_size = {input_shape.dim_size(1),
                                  input_shape.dim_size(2)};
    const int64 channels = input_shape.dim_size(3);
    OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
                errors::InvalidArgument("input size must be positive, got [",
                                        in_size[0], ",", in_size[1], "]"));

    TensorShape grad_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, grad_shape.dims() == 4,
                errors::InvalidArgument("gradient must be 4-dimensional",
                                        grad_shape.DebugString()));
    const int64 grad_batch = grad_shape.dim_size(0);
    const std::vector<int64> grad_size = {grad_shape.dim_size(1),
                                          grad_shape.dim_size(2)};
    const int64 grad_channels = grad_shape.dim_size(3);
    OP_REQUIRES(ctx, batch == grad_batch,
                errors::InvalidArgument(
                    "activations and gradients must have the same batch size (",
                    batch, " vs. ", grad_batch, ")"));
    OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0,
                errors::InvalidArgument("gradient size must be positive, got [",
                                        grad_size[0], ",", grad_size[1], "]"));
    OP_REQUIRES(
        ctx, channels == grad_channels,
        errors::InvalidArgument(
            "activations and gradients must have the same number of channels (",
            channels, " vs. ", grad_channels, ")"));

    const int num_spatial_dims = 2;

    xla::ComputationDataHandle grad = ctx->Input(0);

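    // Apply the gradient of the same iterative decomposition used in
    // ResizeBilinearOp::Compile above, one doubling stage at a time.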
    xla::ComputationDataHandle output = grad;
    while (in_size != grad_size) {
      if (in_size[0] != 1 && in_size[1] != 1) {
        std::vector<float> k = {
            (static_cast<float>(grad_size[0]) - 1) / ((in_size[0] - 1) * 2),
            (static_cast<float>(grad_size[1]) - 1) / ((in_size[1] - 1) * 2)};
        if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
            k[0] > 1 && k[1] > 1) {
          std::vector<int64> next_grad_size = {(in_size[0] - 1) * 2 + 1,
                                               (in_size[1] - 1) * 2 + 1};
          output = ResizeUsingDilationAndConvolutionGradOp(
              b, grad, num_spatial_dims, in_size, next_grad_size, channels);
          grad = output;
          in_size = next_grad_size;
        } else {
          output = ResizeUsingDilationAndConvolutionGradOp(
              b, grad, num_spatial_dims, in_size, grad_size, channels);
          in_size = grad_size;
        }
      } else {
        output = ResizeUsingDilationAndConvolutionGradOp(
            b, grad, num_spatial_dims, in_size, grad_size, channels);
        in_size = grad_size;
      }
    }

    output = b->ConvertElementType(output, output_type_);
    ctx->SetOutput(0, output);
  }

 private:
  bool align_corners_;
  xla::PrimitiveType output_type_;
};

REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp);

}  // namespace
}  // namespace tensorflow