1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific Ops for split. 17 18 #include "tensorflow/compiler/tf2xla/type_util.h" 19 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22 #include "tensorflow/compiler/xla/literal_util.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 28 namespace tensorflow { 29 namespace { 30 31 class SplitOp : public XlaOpKernel { 32 public: 33 explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 34 35 void Compile(XlaOpKernelContext* ctx) override { 36 const int32 num_split = num_outputs(); 37 const TensorShape index_shape = ctx->InputShape(0); 38 const TensorShape input_shape = ctx->InputShape(1); 39 40 xla::Literal literal_index; 41 OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); 42 43 int32 split_dim_orig; 44 if (index_shape.dims() == 0) { 45 split_dim_orig = literal_index.Get<int>({}); 46 } else { 47 OP_REQUIRES( 48 ctx, index_shape.dims() == 1, 49 errors::InvalidArgument("split_index input to Split Op must be a " 50 "scalar or a vector with 1 element")); 51 OP_REQUIRES( 52 ctx, index_shape.dim_size(0) == 1, 53 errors::InvalidArgument("split_index input to Split Op must be a " 54 "scalar or a vector with 1 element")); 55 split_dim_orig = literal_index.Get<int>({0}); 56 } 57 int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() 58 : split_dim_orig; 59 OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), 60 errors::InvalidArgument("-input rank(-", input_shape.dims(), 61 ") <= split_dim < input rank (", 62 input_shape.dims(), "), but got ", 63 split_dim_orig)); 64 65 OP_REQUIRES( 66 ctx, num_split > 0, 67 errors::InvalidArgument( 68 "Number of ways to split should be > 0, but got ", num_split)); 69 70 OP_REQUIRES( 71 ctx, input_shape.dim_size(split_dim) % num_split == 0, 72 errors::InvalidArgument( 73 "Number of ways to split should evenly divide the split " 74 "dimension, but got split_dim ", 75 split_dim_orig, " (size = ", input_shape.dim_size(split_dim), ") ", 76 "and num_split ", num_split)); 77 78 // All the slices are the same size: this is the size along the 79 // split dimension. 80 const int32 slice_size = input_shape.dim_size(split_dim) / num_split; 81 82 // The vectors we will use to define the slice. The entry for the 83 // split dimensions varies for each output. 84 std::vector<int64> begin(input_shape.dims(), 0); 85 std::vector<int64> limits(input_shape.dims()); 86 std::vector<int64> strides(input_shape.dims(), 1); 87 for (int i = 0; i < input_shape.dims(); ++i) { 88 // Initially set up the limits to be the full size of the input: 89 // the split dimension is filled in below. 90 int64 dim = input_shape.dim_size(i); 91 limits[i] = dim; 92 } 93 94 auto input = ctx->Input(1); 95 96 // Create each of the outputs. 97 for (int i = 0; i < num_split; ++i) { 98 // Slice out the ith split from the split dimension. 99 begin[split_dim] = i * slice_size; 100 limits[split_dim] = (i + 1) * slice_size; 101 ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); 102 } 103 } 104 }; 105 106 REGISTER_XLA_OP(Name("Split").CompileTimeConstInput("split_dim"), SplitOp); 107 108 class SplitVOp : public XlaOpKernel { 109 public: 110 explicit SplitVOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 111 112 void Compile(XlaOpKernelContext* ctx) override { 113 const int32 num_split = num_outputs(); 114 const TensorShape index_shape = ctx->InputShape(2); 115 xla::Literal literal_index; 116 OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal_index)); 117 118 int32 split_dim; 119 OP_REQUIRES(ctx, index_shape.dims() == 0, 120 errors::InvalidArgument("split_dim input to Split Op must be a " 121 "scalar")); 122 split_dim = literal_index.Get<int>({}); 123 124 xla::ComputationDataHandle input = ctx->Input(0); 125 const TensorShape input_shape = ctx->InputShape(0); 126 127 OP_REQUIRES(ctx, input_shape.dims() > 0, 128 errors::InvalidArgument("Can't split a 0 dimensional input")); 129 130 OP_REQUIRES( 131 ctx, 0 <= split_dim && split_dim < input_shape.dims(), 132 errors::InvalidArgument("0 <= split_dim < number of input dimensions (", 133 input_shape.dims(), "), but got ", split_dim)); 134 135 OP_REQUIRES( 136 ctx, num_split > 0, 137 errors::InvalidArgument( 138 "Number of ways to split should be > 0, but got ", num_split)); 139 140 // check that sizes are correct 141 int total_split_size = 0; 142 int neg_one_dim = -1; 143 std::vector<int64> split_sizes_vec(num_split, -1); 144 const TensorShape split_size_shape = ctx->InputShape(1); 145 OP_REQUIRES(ctx, 146 split_size_shape.dims() == 1 && 147 split_size_shape.num_elements() == num_split, 148 errors::InvalidArgument( 149 "shape of tensor describing " 150 " the output must have dimension 1 and the same " 151 " number of elements as the output. Got ", 152 split_size_shape.dims(), "-D and ", 153 split_size_shape.num_elements(), " elements")); 154 // get the dimension of this split 155 xla::Literal split_size_literal; 156 OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); 157 158 for (int i = 0; i < num_split; ++i) { 159 int slice_size; 160 slice_size = split_size_literal.Get<int>({i}); 161 if (slice_size == -1) { 162 OP_REQUIRES( 163 ctx, neg_one_dim == -1, 164 errors::InvalidArgument("Only one dimensions can have a value of" 165 "-1. Second one found at dimension ", 166 i)); 167 neg_one_dim = i; 168 } else { 169 split_sizes_vec[i] = slice_size; 170 total_split_size += slice_size; 171 } 172 } 173 174 OP_REQUIRES( 175 ctx, 176 (neg_one_dim == -1 && 177 total_split_size == input_shape.dim_size(split_dim)) || 178 (neg_one_dim >= 0 && 179 total_split_size <= input_shape.dim_size(split_dim)), 180 errors::InvalidArgument("Determined shape must either match " 181 "input shape along split_dim exactly if " 182 "fully specified, or be less than the size of " 183 "the input along split_dim if not fully " 184 "specified. Got: ", 185 total_split_size)); 186 187 if (neg_one_dim >= 0) { 188 split_sizes_vec[neg_one_dim] = 189 input_shape.dim_size(split_dim) - total_split_size; 190 } 191 192 // The vectors we will use to define the slice. The entry for the 193 // split dimensions varies for each output. 194 std::vector<int64> begin(input_shape.dims(), 0); 195 auto dim_sizes = input_shape.dim_sizes(); 196 std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end()); 197 std::vector<int64> strides(input_shape.dims(), 1); 198 for (int i = 0; i < num_split; ++i) { 199 TensorShape output_shape(input_shape); 200 int slice_size = split_sizes_vec[i]; 201 output_shape.set_dim(split_dim, slice_size); 202 203 // Slice out the ith split from the split dimension. 204 limits[split_dim] = begin[split_dim] + slice_size; 205 ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); 206 begin[split_dim] = limits[split_dim]; 207 } 208 } 209 }; 210 211 REGISTER_XLA_OP(Name("SplitV") 212 .CompileTimeConstInput("split_dim") 213 .CompileTimeConstInput("size_splits"), 214 SplitVOp); 215 216 } // namespace 217 } // namespace tensorflow 218