/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {
namespace {

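// Lowers SpaceToBatch/SpaceToBatchND to a Pad -> Reshape -> Transpose ->
// Reshape sequence of XLA ops (the numbered steps below). As a worked
// example, with values chosen purely for illustration: an input of shape
// [1, 2, 2, 1] with block_shape = [2, 2] and zero paddings is padded to
// [1, 2, 2, 1], reshaped to [1, 1, 2, 1, 2, 1], transposed to
// [2, 2, 1, 1, 1, 1], and finally reshaped to the output [4, 1, 1, 1].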
void SpaceToBatch(XlaOpKernelContext* ctx,
                  const xla::ComputationDataHandle& input, DataType input_dtype,
                  const TensorShape& input_tensor_shape,
                  gtl::ArraySlice<int64> block_shape,
                  const xla::Literal& paddings) {
  const int input_rank = input_tensor_shape.dims();
  const gtl::InlinedVector<int64, 4> input_shape =
      input_tensor_shape.dim_sizes();
  const int block_rank = block_shape.size();

  OP_REQUIRES(
      ctx, input_rank >= 1 + block_rank,
      errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
                              " instead of ", input_rank));
  gtl::ArraySlice<int64> remainder_shape(input_shape);
  remainder_shape.remove_prefix(1 + block_rank);

  OP_REQUIRES(
      ctx,
      xla::ShapeUtil::Rank(paddings.shape()) == 2 &&
          block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
          2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
      errors::InvalidArgument("paddings should have shape [", block_rank,
                              ", 2] instead of ",
                              xla::ShapeUtil::HumanString(paddings.shape())));

  xla::ComputationBuilder* b = ctx->builder();

  // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
  //    input according to `paddings` to produce `padded` of shape
  //    `padded_shape`.
  xla::PaddingConfig padding_config;
  std::vector<int64> padded_shape(input_shape.begin(), input_shape.end());
  int64 block_num_elems = 1LL;
  padding_config.add_dimensions();  // Don't pad the batch dimension.
  for (int i = 0; i < block_rank; ++i) {
    auto* dim = padding_config.add_dimensions();
    int64 pad_start = paddings.Get<int64>({i, 0});
    int64 pad_end = paddings.Get<int64>({i, 1});
    OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
                errors::InvalidArgument("Paddings must be non-negative"));
    dim->set_edge_padding_low(pad_start);
    dim->set_edge_padding_high(pad_end);
    padded_shape[1 + i] += pad_start + pad_end;
    block_num_elems *= block_shape[i];
  }
  // Don't pad the remainder dimensions.
  for (int i = 0; i < remainder_shape.size(); ++i) {
    padding_config.add_dimensions();
  }
  OP_REQUIRES(ctx, block_num_elems > 0,
              errors::InvalidArgument(
                  "The product of the block dimensions must be positive"));

  xla::ComputationDataHandle padded =
      b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);

  // 2. Reshape `padded` to `reshaped_padded` of shape:
  //
  //      [batch] +
  //      [padded_shape[1] / block_shape[0],
  //       block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1],
  //       block_shape[M-1]] +
  //      remaining_shape
  const int64 batch_size = input_shape[0];
  std::vector<int64> reshaped_padded_shape(input_rank + block_rank);
  reshaped_padded_shape[0] = batch_size;
  for (int i = 0; i < block_rank; ++i) {
    OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0,
                errors::InvalidArgument("padded_shape[", 1 + i,
                                        "]=", padded_shape[1 + i],
                                        " is not divisible by block_shape[", i,
                                        "]=", block_shape[i]));

    reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i];
    reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            reshaped_padded_shape.begin() + 1 + 2 * block_rank);

  xla::ComputationDataHandle reshaped_padded =
      b->Reshape(padded, reshaped_padded_shape);

  // 3. Permute dimensions of `reshaped_padded` to produce
  //    `permuted_reshaped_padded` of shape:
  //
  //      block_shape +
  //      [batch] +
  //      [padded_shape[1] / block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1]] +
  //      remaining_shape
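  // `reshaped_padded_shape` interleaves quotient and block dims as
  // [batch, q_0, b_0, ..., q_{M-1}, b_{M-1}] + remainder, so the permutation
  // below pulls the block dims (at offsets 2, 4, ..., 2M) to the front,
  // then the batch dim, then the quotient dims, leaving the remainder dims
  // in place.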
  std::vector<int64> permutation(reshaped_padded_shape.size());
  for (int i = 0; i < block_rank; ++i) {
    permutation[i] = 1 + 2 * i + 1;
    permutation[block_rank + 1 + i] = 1 + 2 * i;
  }
  permutation[block_rank] = 0;
  std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
            1 + block_rank * 2);
  xla::ComputationDataHandle permuted_reshaped_padded =
      b->Transpose(reshaped_padded, permutation);

  // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
  //    batch dimension, producing an output tensor of shape:
  //
  //      [batch * prod(block_shape)] +
  //      [padded_shape[1] / block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1]] +
  //      remaining_shape
  std::vector<int64> output_shape(input_rank);
  output_shape[0] = batch_size * block_num_elems;
  for (int i = 0; i < block_rank; ++i) {
    output_shape[1 + i] = padded_shape[1 + i] / block_shape[i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            output_shape.begin() + 1 + block_rank);

  xla::ComputationDataHandle output =
      b->Reshape(permuted_reshaped_padded, output_shape);
  ctx->SetOutput(0, output);
}

class SpaceToBatchNDOp : public XlaOpKernel {
 public:
  explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    std::vector<int64> block_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));

    xla::Literal paddings;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings));

    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 block_shape, paddings);
  }
};
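// `block_shape` and `paddings` determine the static shapes of every
// intermediate value above, so both inputs must be compile-time constants.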
REGISTER_XLA_OP(Name("SpaceToBatchND")
                    .CompileTimeConstInput("paddings")
                    .CompileTimeConstInput("block_shape"),
                SpaceToBatchNDOp);

class SpaceToBatchOp : public XlaOpKernel {
 public:
  explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        ctx, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::Literal paddings;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings));

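    // SpaceToBatch is the 2D special case of SpaceToBatchND, with a square
    // block of side `block_size` over the two spatial dimensions.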
    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 {block_size_, block_size_}, paddings);
  }

 private:
  int block_size_;
};
REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstInput("paddings"),
                SpaceToBatchOp);

}  // namespace
}  // namespace tensorflow