/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

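// Implements the SpaceToDepth op for XLA: rearranges non-overlapping
// block_size x block_size patches of the spatial dimensions into the depth
// dimension. For an NHWC input of shape [batch, height, width, depth], the
// output has shape
// [batch, height / block_size, width / block_size, depth * block_size^2].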
class SpaceToDepthOp : public XlaOpKernel {
 public:
  explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC,
                errors::InvalidArgument("Unsupported data format ",
                                        ToString(data_format_),
                                        "; expected formats NHWC or NCHW"));

    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        ctx, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_tensor_shape = ctx->InputShape(0);
    int input_rank = input_tensor_shape.dims();
    static const int kRequiredDims = 4;
    OP_REQUIRES(ctx, kRequiredDims == input_rank,
                errors::InvalidArgument("Input rank should be ", kRequiredDims,
                                        "; got ", input_rank));
    const gtl::InlinedVector<int64, 4> input_shape =
        input_tensor_shape.dim_sizes();

    xla::ComputationBuilder* b = ctx->builder();
    xla::ComputationDataHandle input = ctx->Input(0);

    int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
    int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_);
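    // For the rank-4 inputs accepted above, NHWC yields feature_dim = 3 and
    // NCHW yields feature_dim = 1; num_spatial_dims is 2 in both cases.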

    std::vector<int64> reshaped_shape;
    std::vector<int64> transpose_order;
    std::vector<int64> output_shape;
    reshaped_shape.reserve(input_rank);
    transpose_order.reserve(input_rank);
    output_shape.reserve(input_rank);
    if (data_format_ == FORMAT_NHWC) {
      int64 block_elems = 1;
      for (int i = 0; i < num_spatial_dims; ++i) {
        OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0,
                    errors::InvalidArgument(
                        "input shape[", 1 + i, "]=", input_shape[1 + i],
                        " is not divisible by block_size=", block_size_));
        block_elems *= block_size_;
      }

      reshaped_shape.push_back(input_shape[0]);
      for (int i = 0; i < num_spatial_dims; ++i) {
        reshaped_shape.push_back(input_shape[1 + i] / block_size_);
        reshaped_shape.push_back(block_size_);
      }
      reshaped_shape.push_back(input_shape[feature_dim]);

      transpose_order.push_back(0);
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 1);
      }
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 2);
      }
      transpose_order.push_back(feature_dim + num_spatial_dims);

      output_shape.push_back(input_shape[0]);
      for (int i = 0; i < num_spatial_dims; ++i) {
        output_shape.push_back(input_shape[1 + i] / block_size_);
      }
      output_shape.push_back(input_shape[feature_dim] * block_elems);
    } else {
      // FORMAT_NCHW
      int64 block_elems = 1;
      for (int i = 0; i < num_spatial_dims; ++i) {
        OP_REQUIRES(ctx, input_shape[2 + i] % block_size_ == 0,
                    errors::InvalidArgument(
                        "input shape[", 2 + i, "]=", input_shape[2 + i],
                        " is not divisible by block_size=", block_size_));
        block_elems *= block_size_;
      }

      reshaped_shape.push_back(input_shape[0]);
      reshaped_shape.push_back(input_shape[feature_dim]);
      for (int i = 0; i < num_spatial_dims; ++i) {
        reshaped_shape.push_back(input_shape[2 + i] / block_size_);
        reshaped_shape.push_back(block_size_);
      }

      transpose_order.push_back(0);
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 3);
      }
      transpose_order.push_back(feature_dim);
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 2);
      }

      output_shape.push_back(input_shape[0]);
      output_shape.push_back(input_shape[feature_dim] * block_elems);
      for (int i = 0; i < num_spatial_dims; ++i) {
        output_shape.push_back(input_shape[2 + i] / block_size_);
      }
    }
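
    // For example, with block_size_ = 2, an NHWC input of shape [1, 4, 4, 1]
    // gives
    //   reshaped_shape  = [1, 2, 2, 2, 2, 1]
    //   transpose_order = [0, 1, 3, 2, 4, 5]
    //   output_shape    = [1, 2, 2, 4]
    // while an NCHW input of shape [1, 1, 4, 4] gives
    //   reshaped_shape  = [1, 1, 2, 2, 2, 2]
    //   transpose_order = [0, 3, 5, 1, 2, 4]
    //   output_shape    = [1, 4, 2, 2]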

    // Note: comments are given in NHWC format; NCHW is similar with a different
    // dimension order.
    // 1. Reshape `input` to `reshaped` of shape:
    //
    //      [batch,
    //       input_shape[1] / block_size_, block_size_,
    //       input_shape[2] / block_size_, block_size_,
    //       depth]
    xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);

    // 2. Permute dimensions of `reshaped` to produce
    //    `permuted_reshaped` of shape:
    //
    //      [batch,
    //       input_shape[1] / block_size_,
    //       input_shape[2] / block_size_,
    //       block_size_, block_size_,
    //       depth]
    xla::ComputationDataHandle permuted_reshaped =
        b->Transpose(reshaped, transpose_order);

    // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the
    //    depth dimension, producing an output tensor of shape:
    //
    //      [batch,
    //       input_shape[1] / block_size_,
    //       input_shape[2] / block_size_,
    //       block_size_ * block_size_ * depth]
    //
    xla::ComputationDataHandle output =
        b->Reshape(permuted_reshaped, output_shape);

    ctx->SetOutput(0, output);
  }

 private:
  TensorFormat data_format_;
  int block_size_;
};
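// Register this kernel as the XLA implementation of the SpaceToDepth op.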
REGISTER_XLA_OP(Name("SpaceToDepth"), SpaceToDepthOp);

}  // namespace
}  // namespace tensorflow