/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {
namespace {

void SpaceToBatch(XlaOpKernelContext* ctx,
                  const xla::ComputationDataHandle& input, DataType input_dtype,
                  const TensorShape& input_tensor_shape,
                  gtl::ArraySlice<int64> block_shape,
                  const xla::Literal& paddings) {
  const int input_rank = input_tensor_shape.dims();
  const gtl::InlinedVector<int64, 4> input_shape =
      input_tensor_shape.dim_sizes();
  const int block_rank = block_shape.size();

  OP_REQUIRES(
      ctx, input_rank >= 1 + block_rank,
      errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
                              " instead of ", input_rank));
  gtl::ArraySlice<int64> remainder_shape(input_shape);
  remainder_shape.remove_prefix(1 + block_rank);

  OP_REQUIRES(
      ctx,
      xla::ShapeUtil::Rank(paddings.shape()) == 2 &&
          block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
          2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
      errors::InvalidArgument("paddings should have shape [", block_rank,
                              ", 2] instead of ",
                              xla::ShapeUtil::HumanString(paddings.shape())));

  xla::ComputationBuilder* b = ctx->builder();

  // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
  //    input according to `paddings` to produce `padded` of shape
  //    `padded_shape`.
  xla::PaddingConfig padding_config;
  std::vector<int64> padded_shape(input_shape.begin(), input_shape.end());
  int64 block_num_elems = 1LL;
  padding_config.add_dimensions();  // Don't pad the batch dimension.
  for (int i = 0; i < block_rank; ++i) {
    auto* dim = padding_config.add_dimensions();
    int64 pad_start = paddings.Get<int64>({i, 0});
    int64 pad_end = paddings.Get<int64>({i, 1});
    OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
                errors::InvalidArgument("Paddings must be non-negative"));
    dim->set_edge_padding_low(pad_start);
    dim->set_edge_padding_high(pad_end);
    padded_shape[1 + i] += pad_start + pad_end;
    block_num_elems *= block_shape[i];
  }
  // Don't pad the remainder dimensions.
  for (int i = 0; i < remainder_shape.size(); ++i) {
    padding_config.add_dimensions();
  }
  OP_REQUIRES(ctx, block_num_elems > 0,
              errors::InvalidArgument(
                  "The product of the block dimensions must be positive"));

  xla::ComputationDataHandle padded =
      b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
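
  // For example, for an input of shape [1, 2, 2, 1] with block_shape = {2, 2}
  // and zero paddings, `padded` keeps the shape [1, 2, 2, 1], and the
  // reshape/transpose/reshape steps below produce an output of shape
  // [4, 1, 1, 1], matching the documented behavior of tf.space_to_batch_nd.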

  // 2. Reshape `padded` to `reshaped_padded` of shape:
  //
  //      [batch] +
  //      [padded_shape[1] / block_shape[0],
  //        block_shape[0],
  //       ...,
  //        padded_shape[M] / block_shape[M-1],
  //        block_shape[M-1]] +
  //      remaining_shape
  const int64 batch_size = input_shape[0];
  std::vector<int64> reshaped_padded_shape(input_rank + block_rank);
  reshaped_padded_shape[0] = batch_size;
  for (int i = 0; i < block_rank; ++i) {
    OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0,
                errors::InvalidArgument("padded_shape[", 1 + i,
                                        "]=", padded_shape[1 + i],
                                        " is not divisible by block_shape[", i,
                                        "]=", block_shape[i]));

    reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i];
    reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            reshaped_padded_shape.begin() + 1 + 2 * block_rank);

  xla::ComputationDataHandle reshaped_padded =
      b->Reshape(padded, reshaped_padded_shape);

  // 3. Permute dimensions of `reshaped_padded` to produce
  //    `permuted_reshaped_padded` of shape:
  //
  //      block_shape +
  //      [batch] +
  //      [padded_shape[1] / block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1]] +
  //      remaining_shape
  std::vector<int64> permutation(reshaped_padded_shape.size());
  for (int i = 0; i < block_rank; ++i) {
    permutation[i] = 1 + 2 * i + 1;
    permutation[block_rank + 1 + i] = 1 + 2 * i;
  }
  permutation[block_rank] = 0;
  std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
            1 + block_rank * 2);
  xla::ComputationDataHandle permuted_reshaped_padded =
      b->Transpose(reshaped_padded, permutation);

  // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
  //    batch dimension, producing an output tensor of shape:
  //
  //      [batch * prod(block_shape)] +
  //      [padded_shape[1] / block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1]] +
  //      remaining_shape
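  // The reshapes in steps 2 and 4 preserve the logical element order; it is
  // the transpose in step 3 that rearranges the spatial blocks so that the
  // final reshape below can fold `block_shape` and the original batch into a
  // single leading batch dimension of size batch * prod(block_shape).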
  std::vector<int64> output_shape(input_rank);
  output_shape[0] = batch_size * block_num_elems;
  for (int i = 0; i < block_rank; ++i) {
    output_shape[1 + i] = padded_shape[1 + i] / block_shape[i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            output_shape.begin() + 1 + block_rank);

  xla::ComputationDataHandle output =
      b->Reshape(permuted_reshaped_padded, output_shape);
  ctx->SetOutput(0, output);
}

class SpaceToBatchNDOp : public XlaOpKernel {
 public:
  explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    std::vector<int64> block_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));

    xla::Literal paddings;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings));

    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 block_shape, paddings);
  }
};
REGISTER_XLA_OP(Name("SpaceToBatchND")
                    .CompileTimeConstInput("paddings")
                    .CompileTimeConstInput("block_shape"),
                SpaceToBatchNDOp);

class SpaceToBatchOp : public XlaOpKernel {
 public:
  explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        ctx, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::Literal paddings;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings));

    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 {block_size_, block_size_}, paddings);
  }

 private:
  int block_size_;
};
REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstInput("paddings"),
                SpaceToBatchOp);

}  // namespace
}  // namespace tensorflow