1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific Concat Ops. 17 18 #include <limits> 19 #include <vector> 20 21 #include "tensorflow/compiler/tf2xla/type_util.h" 22 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 23 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 25 #include "tensorflow/compiler/xla/client/xla_builder.h" 26 #include "tensorflow/compiler/xla/literal_util.h" 27 #include "tensorflow/core/framework/bounds_check.h" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/framework/register_types.h" 30 #include "tensorflow/core/framework/tensor.h" 31 #include "tensorflow/core/framework/tensor_shape.h" 32 #include "tensorflow/core/framework/tensor_types.h" 33 #include "tensorflow/core/framework/types.h" 34 #include "tensorflow/core/lib/core/status.h" 35 #include "tensorflow/core/platform/types.h" 36 37 namespace tensorflow { 38 namespace { 39 40 // -------------------------------------------------------------------------- 41 class ConcatBaseOp : public XlaOpKernel { 42 public: 43 ConcatBaseOp(OpKernelConstruction* c, int axis_index) 44 : XlaOpKernel(c), axis_index_(axis_index) {} 45 46 void Compile(XlaOpKernelContext* ctx) override { 47 const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_); 48 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_tensor_shape), 49 errors::InvalidArgument( 50 "Concat dim tensor should be a scalar, but got shape ", 51 concat_dim_tensor_shape.DebugString())); 52 int64 concat_dim; 53 OP_REQUIRES_OK(ctx, 54 ctx->ConstantInputAsIntScalar(axis_index_, &concat_dim)); 55 56 std::vector<xla::XlaOp> values; 57 std::vector<TensorShape> shapes; 58 OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes)); 59 const int N = values.size(); 60 const int input_dims = shapes[0].dims(); 61 const TensorShape& input_shape = shapes[0]; 62 63 int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; 64 OP_REQUIRES(ctx, 0 <= axis && axis < input_dims, 65 errors::InvalidArgument( 66 "ConcatOp : Expected concatenating dimensions in the range " 67 "[", 68 -input_dims, ", ", input_dims, "), but got ", concat_dim)); 69 70 // Make a vector holding the XlaOp for each of the inputs that has non-zero 71 // elements. 72 std::vector<xla::XlaOp> input_data; 73 int output_concat_dim = 0; 74 for (int i = 0; i < N; ++i) { 75 xla::XlaOp handle = values[i]; 76 const TensorShape& in_shape = shapes[i]; 77 OP_REQUIRES( 78 ctx, in_shape.dims() == input_dims, 79 errors::InvalidArgument( 80 "ConcatOp : Ranks of all input tensors should match: shape[0] = ", 81 input_shape.DebugString(), " vs. shape[", i, 82 "] = ", in_shape.DebugString())); 83 if (in_shape.dims() == 0) { 84 // Inputs that come in as scalars must be reshaped to 1-vectors. 85 input_data.push_back(xla::Reshape(handle, {1})); 86 } else { 87 input_data.push_back(handle); 88 } 89 output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; 90 } 91 92 VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; 93 ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); 94 } 95 96 private: 97 int axis_index_; 98 }; 99 100 class ConcatOp : public ConcatBaseOp { 101 public: 102 explicit ConcatOp(OpKernelConstruction* c) 103 : ConcatBaseOp(c, /* axis_index */ 0) {} 104 }; 105 106 // ConcatV2 operation is the same as Concat except 'concat_dim' 107 // is the last input instead of the first and renamed to 'axis'. 108 class ConcatV2Op : public ConcatBaseOp { 109 public: 110 explicit ConcatV2Op(OpKernelConstruction* c) 111 : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {} 112 }; 113 114 REGISTER_XLA_OP(Name("Concat").CompileTimeConstantInput("concat_dim"), 115 ConcatOp); 116 REGISTER_XLA_OP(Name("ConcatV2") 117 .TypeConstraint("Tidx", DT_INT32) 118 .CompileTimeConstantInput("axis"), 119 ConcatV2Op); 120 121 class ConcatOffsetOp : public XlaOpKernel { 122 public: 123 explicit ConcatOffsetOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 124 125 void Compile(XlaOpKernelContext* ctx) override { 126 const TensorShape concat_dim_shape = ctx->InputShape(0); 127 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_shape), 128 errors::InvalidArgument( 129 "Concat dim tensor should be a scalar, but got shape ", 130 concat_dim_shape.DebugString())); 131 for (int i = 1; i < ctx->num_inputs(); ++i) { 132 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)), 133 errors::InvalidArgument("input ", i, 134 " should be a vector, but got shape ", 135 ctx->InputShape(i).DebugString())); 136 } 137 // Suppose a Concat() op needs to Concatenate N tensors, each of 138 // which has the same number of dimensions. Their shapes match 139 // except the concat dimension. 140 // 141 // E.g., say, we want to concatenate 3 tensors in the 2nd 142 // dimension, and their shapes are: 143 // 144 // [2, 2, 5, 7] 145 // [2, 3, 5, 7] 146 // [2, 4, 5, 7] 147 // 148 // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape 149 // [2,9,5,7]. We will compute the cumulative sum along the 2nd 150 // dimension to figure out each input's offset in the concatenated 151 // output: 152 // [0, 0, 0, 0] 153 // [0, 2, 0, 0] 154 // [0, 5, 0, 0] 155 const int32 N = ctx->num_inputs() - 1; 156 const TensorShape inp0_shape = ctx->InputShape(1); 157 std::vector<int64> inp0_dims; 158 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &inp0_dims)); 159 const int64 inp0_rank = inp0_shape.num_elements(); 160 161 int64 cdim; 162 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &cdim)); 163 164 VLOG(1) << "ConcatOffset " << cdim << "," << inp0_rank; 165 int32 axis = cdim < 0 ? cdim + inp0_rank : cdim; 166 OP_REQUIRES(ctx, FastBoundsCheck(axis, inp0_rank), 167 errors::InvalidArgument("Concat dim is out of range: ", axis, 168 " vs. ", inp0_rank)); 169 int32 offset = 0; 170 for (int i = 0; i < N; ++i) { 171 const TensorShape inp_shape = ctx->InputShape(1 + i); 172 OP_REQUIRES(ctx, inp0_rank == inp_shape.num_elements(), 173 errors::InvalidArgument("input ", i, " should contain ", 174 inp0_rank, " elements, but got ", 175 inp_shape.num_elements())); 176 std::vector<int64> inp_dims; 177 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1 + i, &inp_dims)); 178 179 Tensor out_constant(DT_INT32, TensorShape({inp0_rank})); 180 auto out_vec = out_constant.vec<int32>(); 181 for (int64 j = 0; j < inp0_rank; ++j) { 182 if (j == axis) { 183 out_vec(j) = offset; 184 offset += inp_dims[j]; 185 } else { 186 const int32 inp0_element = inp0_dims[j]; 187 const int32 inp_element = inp_dims[j]; 188 OP_REQUIRES(ctx, inp0_element == inp_element, 189 errors::InvalidArgument("input[", i, ",", j, 190 "] mismatch: ", inp0_element, 191 " vs. ", inp_element)); 192 out_vec(j) = 0; 193 } 194 } 195 196 ctx->SetConstantOutput(i, out_constant); 197 } 198 } 199 }; 200 201 REGISTER_XLA_OP(Name("ConcatOffset") 202 .CompileTimeConstantInput("concat_dim") 203 .CompileTimeConstantInput("shape"), 204 ConcatOffsetOp); 205 206 } // namespace 207 } // namespace tensorflow 208