/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Concat Ops.

#include <limits>
#include <vector>

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {

// --------------------------------------------------------------------------
class ConcatBaseOp : public XlaOpKernel {
 public:
  ConcatBaseOp(OpKernelConstruction* c, int axis_index)
      : XlaOpKernel(c), axis_index_(axis_index) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_);
    OP_REQUIRES(
        ctx, IsLegacyScalar(concat_dim_tensor_shape),
        errors::InvalidArgument(
            "Concat dim tensor should be a scalar integer, but got shape ",
            concat_dim_tensor_shape.DebugString()));
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal));
    // TODO(annarev): add a helper to support int64 input.
    const int32 concat_dim = literal.Get<int>({});

    std::vector<xla::ComputationDataHandle> values;
    std::vector<TensorShape> shapes;
    OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes));
    const int N = values.size();
    const int input_dims = shapes[0].dims();
    const TensorShape& input_shape = shapes[0];

    int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
    OP_REQUIRES(ctx,
                (0 <= axis && axis < input_dims) ||
                    (allow_legacy_scalars() && concat_dim == 0),
                errors::InvalidArgument(
                    "ConcatOp : Expected concatenating dimensions in the range "
                    "[",
                    -input_dims, ", ", input_dims, "), but got ", concat_dim));

    // Make a vector holding the ComputationDataHandle for each input,
    // reshaping scalar inputs to 1-vectors.
    std::vector<xla::ComputationDataHandle> input_data;
    int output_concat_dim = 0;
    const bool input_is_scalar = IsLegacyScalar(input_shape);
    for (int i = 0; i < N; ++i) {
      xla::ComputationDataHandle handle = values[i];
      const TensorShape& in_shape = shapes[i];
      const bool in_is_scalar = IsLegacyScalar(in_shape);
      OP_REQUIRES(
          ctx,
          in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar),
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
              "] = ", in_shape.DebugString()));
      if (in_shape.dims() == 0) {
        // Inputs that come in as scalars must be reshaped to 1-vectors.
        input_data.push_back(ctx->builder()->Reshape(handle, {1}));
      } else {
        input_data.push_back(handle);
      }
      output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
    }

    VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
    ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis));
  }

 private:
  int axis_index_;
};

class ConcatOp : public ConcatBaseOp {
 public:
  explicit ConcatOp(OpKernelConstruction* c)
      : ConcatBaseOp(c, /* axis_index */ 0) {}
};

// The ConcatV2 operation is the same as Concat, except that 'concat_dim'
// is the last input instead of the first and is renamed to 'axis'.
class ConcatV2Op : public ConcatBaseOp {
 public:
  explicit ConcatV2Op(OpKernelConstruction* c)
      : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {}
};
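
// Illustrative example (not part of the op definitions above): concatenating
// three 2x3 tensors a, b, and c along dimension 1 would be expressed as
//   Concat(concat_dim=1, a, b, c)   // the axis is the first input
//   ConcatV2(a, b, c, axis=1)       // the axis is the last input
// and either form yields a 2x9 result.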

REGISTER_XLA_OP(Name("Concat").CompileTimeConstInput("concat_dim"), ConcatOp);
REGISTER_XLA_OP(Name("ConcatV2")
                    .TypeConstraint("Tidx", DT_INT32)
                    .CompileTimeConstInput("axis"),
                ConcatV2Op);

class ConcatOffsetOp : public XlaOpKernel {
 public:
  explicit ConcatOffsetOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape concat_dim_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, IsLegacyScalar(concat_dim_shape),
        errors::InvalidArgument(
            "Concat dim tensor should be a scalar integer, but got shape ",
            concat_dim_shape.DebugString()));
    for (int i = 1; i < ctx->num_inputs(); ++i) {
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)),
                  errors::InvalidArgument("input ", i,
                                          " should be a vector, but got shape ",
                                          ctx->InputShape(i).DebugString()));
    }
    // Suppose a Concat() op needs to concatenate N tensors, each of
    // which has the same number of dimensions. Their shapes match
    // except for the concat dimension.
    //
    // E.g., say we want to concatenate 3 tensors in the 2nd
    // dimension, and their shapes are:
    //
    //   [2, 2, 5, 7]
    //   [2, 3, 5, 7]
    //   [2, 4, 5, 7]
    //
    // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape
    // [2,9,5,7]. We will compute the cumulative sum along the 2nd
    // dimension to figure out each input's offset in the concatenated
    // output:
    //
    //   [0, 0, 0, 0]
    //   [0, 2, 0, 0]
    //   [0, 5, 0, 0]
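    //
    // Because the concat dimension and the shape inputs are compile-time
    // constants here, these offsets are computed during compilation and
    // emitted as constant outputs via SetConstantOutput below.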
    const int32 N = ctx->num_inputs() - 1;
    const TensorShape inp0_shape = ctx->InputShape(1);
    xla::Literal inp0_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal));
    const int64 dims = inp0_shape.num_elements();

    xla::Literal concat_dim_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal));
    const int64 cdim = concat_dim_literal.Get<int>({});

    VLOG(1) << "ConcatOffset " << cdim << "," << dims;
    int32 axis = cdim < 0 ? cdim + dims : cdim;
    OP_REQUIRES(ctx, FastBoundsCheck(axis, dims),
                errors::InvalidArgument("Concat dim is out of range: ", axis,
                                        " vs. ", dims));
    int32 offset = 0;
    for (int i = 0; i < N; ++i) {
      const TensorShape inp_shape = ctx->InputShape(1 + i);
      OP_REQUIRES(ctx, dims == inp_shape.num_elements(),
                  errors::InvalidArgument("input ", i, " should contain ", dims,
                                          " elements, but got ",
                                          inp_shape.num_elements()));
      xla::Literal inp_literal;
      OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal));

      Tensor out_constant(DT_INT32, TensorShape({dims}));
      auto out_vec = out_constant.vec<int32>();
      for (int64 j = 0; j < dims; ++j) {
        if (j == axis) {
          out_vec(j) = offset;
          offset += inp_literal.Get<int>({j});
        } else {
          const int32 inp0_element = inp0_literal.Get<int>({j});
          const int32 inp_element = inp_literal.Get<int>({j});
          OP_REQUIRES(ctx, (inp0_element == inp_element),
                      errors::InvalidArgument("input[", i, ",", j,
                                              "] mismatch: ", inp0_element,
                                              " vs. ", inp_element));
          out_vec(j) = 0;
        }
      }

      ctx->SetConstantOutput(i, out_constant);
    }
  }
};

REGISTER_XLA_OP(Name("ConcatOffset")
                    .CompileTimeConstInput("concat_dim")
                    .CompileTimeConstInput("shape"),
                ConcatOffsetOp);

}  // namespace
}  // namespace tensorflow