1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/array_ops.cc. 17 18 #include <limits> 19 #include <vector> 20 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/register_types.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/tensor_types.h" 26 #include "tensorflow/core/framework/types.h" 27 #include "tensorflow/core/kernels/bounds_check.h" 28 #include "tensorflow/core/kernels/concat_lib.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/platform/types.h" 31 32 namespace tensorflow { 33 34 typedef Eigen::ThreadPoolDevice CPUDevice; 35 #if GOOGLE_CUDA 36 typedef Eigen::GpuDevice GPUDevice; 37 #endif // GOOGLE_CUDA 38 #ifdef TENSORFLOW_USE_SYCL 39 typedef Eigen::SyclDevice SYCLDevice; 40 #endif // TENSORFLOW_USE_SYCL 41 42 enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM }; 43 44 // -------------------------------------------------------------------------- 45 template <typename Device, typename T, AxisArgumentName AxisArgName> 46 class ConcatBaseOp : public OpKernel { 47 public: 48 typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> 49 ConstMatrixVector; 50 51 explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {} 52 53 void Compute(OpKernelContext* c) override { 54 const Tensor* concat_dim_tensor; 55 const char* axis_attribute_name = 56 AxisArgName == NAME_IS_AXIS 57 ? "axis" 58 : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>"; 59 OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor)); 60 OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()), 61 errors::InvalidArgument( 62 axis_attribute_name, 63 " tensor should be a scalar integer, but got shape ", 64 concat_dim_tensor->shape().DebugString())); 65 const int32 concat_dim = 66 internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()()); 67 OpInputList values; 68 OP_REQUIRES_OK(c, c->input_list("values", &values)); 69 const int N = values.size(); 70 const int input_dims = values[0].dims(); 71 const TensorShape& input_shape = values[0].shape(); 72 73 int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; 74 OP_REQUIRES(c, 75 (0 <= axis && axis < input_dims) || 76 (allow_legacy_scalars() && concat_dim == 0), 77 errors::InvalidArgument( 78 "ConcatOp : Expected concatenating dimensions in the range " 79 "[", 80 -input_dims, ", ", input_dims, "), but got ", concat_dim)); 81 // Note that we reduce the concat of n-dimensional tensors into a two 82 // dimensional concat. Assuming the dimensions of any input/output 83 // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along 84 // the dimension indicated with size y0, we flatten it to {x, y}, where y = 85 // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1). 86 ConstMatrixVector inputs_flat; 87 inputs_flat.reserve(N); 88 int64 inputs_flat_dim0 = 1; 89 for (int d = 0; d < axis; ++d) { 90 inputs_flat_dim0 *= input_shape.dim_size(d); 91 } 92 int64 output_concat_dim = 0; 93 const bool input_is_scalar = IsLegacyScalar(input_shape); 94 for (int i = 0; i < N; ++i) { 95 const auto in = values[i]; 96 const bool in_is_scalar = IsLegacyScalar(in.shape()); 97 OP_REQUIRES( 98 c, in.dims() == input_dims || (input_is_scalar && in_is_scalar), 99 errors::InvalidArgument( 100 "ConcatOp : Ranks of all input tensors should match: shape[0] = ", 101 input_shape.DebugString(), " vs. shape[", i, 102 "] = ", in.shape().DebugString())); 103 for (int j = 0; j < input_dims; ++j) { 104 if (j == axis) { 105 continue; 106 } 107 OP_REQUIRES( 108 c, in.dim_size(j) == input_shape.dim_size(j), 109 errors::InvalidArgument( 110 "ConcatOp : Dimensions of inputs should match: shape[0] = ", 111 input_shape.DebugString(), " vs. shape[", i, 112 "] = ", in.shape().DebugString())); 113 } 114 if (in.NumElements() > 0) { 115 int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0; 116 inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( 117 in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1}))); 118 } 119 // TODO(irving): Remove check once !allow_legacy_scalars(). 120 output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1; 121 } 122 123 TensorShape output_shape(input_shape); 124 // TODO(irving): Remove rank 0 case once !allow_legacy_scalars(). 125 if (output_shape.dims() == 0) { 126 output_shape.AddDim(output_concat_dim); 127 } else { 128 output_shape.set_dim(axis, output_concat_dim); 129 } 130 Tensor* output = nullptr; 131 OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output)); 132 if (output->NumElements() > 0) { 133 int64 output_dim1 = output->NumElements() / inputs_flat_dim0; 134 auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1}); 135 #if GOOGLE_CUDA 136 if (std::is_same<Device, GPUDevice>::value) { 137 ConcatGPU<T>(c, inputs_flat, output, &output_flat); 138 return; 139 } 140 #endif // GOOGLE_CUDA 141 #ifdef TENSORFLOW_USE_SYCL 142 if (std::is_same<Device, SYCLDevice>::value) { 143 ConcatSYCL<T>(c->eigen_sycl_device(), inputs_flat, &output_flat); 144 return; 145 } 146 #endif // TENSORFLOW_USE_SYCL 147 ConcatCPU<T>(c->device(), inputs_flat, &output_flat); 148 } 149 } 150 }; 151 152 template <typename Device, typename T> 153 using ConcatOp = ConcatBaseOp<Device, T, NAME_IS_CONCAT_DIM>; 154 template <typename Device, typename T> 155 using ConcatV2Op = ConcatBaseOp<Device, T, NAME_IS_AXIS>; 156 157 #define REGISTER_CONCAT(type) \ 158 REGISTER_KERNEL_BUILDER(Name("Concat") \ 159 .Device(DEVICE_CPU) \ 160 .TypeConstraint<type>("T") \ 161 .HostMemory("concat_dim"), \ 162 ConcatOp<CPUDevice, type>) \ 163 REGISTER_KERNEL_BUILDER(Name("ConcatV2") \ 164 .Device(DEVICE_CPU) \ 165 .TypeConstraint<type>("T") \ 166 .TypeConstraint<int32>("Tidx") \ 167 .HostMemory("axis"), \ 168 ConcatV2Op<CPUDevice, type>) 169 170 TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT); 171 REGISTER_CONCAT(quint8); 172 REGISTER_CONCAT(qint8); 173 REGISTER_CONCAT(quint16); 174 REGISTER_CONCAT(qint16); 175 REGISTER_CONCAT(qint32); 176 177 #undef REGISTER_CONCAT 178 179 #if GOOGLE_CUDA 180 181 #define REGISTER_GPU(type) \ 182 REGISTER_KERNEL_BUILDER(Name("Concat") \ 183 .Device(DEVICE_GPU) \ 184 .TypeConstraint<type>("T") \ 185 .HostMemory("concat_dim"), \ 186 ConcatOp<GPUDevice, type>) \ 187 REGISTER_KERNEL_BUILDER(Name("ConcatV2") \ 188 .Device(DEVICE_GPU) \ 189 .TypeConstraint<type>("T") \ 190 .TypeConstraint<int32>("Tidx") \ 191 .HostMemory("axis"), \ 192 ConcatV2Op<GPUDevice, type>) 193 194 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); 195 REGISTER_GPU(bfloat16); 196 TF_CALL_complex64(REGISTER_GPU); 197 TF_CALL_complex128(REGISTER_GPU); 198 TF_CALL_int64(REGISTER_GPU); 199 REGISTER_GPU(bool); 200 #undef REGISTER_GPU 201 202 // A special GPU kernel for int32. 203 // TODO(b/25387198): Also enable int32 in device memory. This kernel 204 // registration requires all int32 inputs and outputs to be in host memory. 205 REGISTER_KERNEL_BUILDER(Name("Concat") 206 .Device(DEVICE_GPU) 207 .TypeConstraint<int32>("T") 208 .HostMemory("concat_dim") 209 .HostMemory("values") 210 .HostMemory("output"), 211 ConcatOp<CPUDevice, int32>); 212 REGISTER_KERNEL_BUILDER(Name("ConcatV2") 213 .Device(DEVICE_GPU) 214 .TypeConstraint<int32>("T") 215 .TypeConstraint<int32>("Tidx") 216 .HostMemory("values") 217 .HostMemory("axis") 218 .HostMemory("output"), 219 ConcatV2Op<CPUDevice, int32>); 220 221 #endif // GOOGLE_CUDA 222 223 #ifdef TENSORFLOW_USE_SYCL 224 #define REGISTER_SYCL(type) \ 225 REGISTER_KERNEL_BUILDER(Name("Concat") \ 226 .Device(DEVICE_SYCL) \ 227 .TypeConstraint<type>("T") \ 228 .HostMemory("concat_dim"), \ 229 ConcatOp<SYCLDevice, type>) \ 230 REGISTER_KERNEL_BUILDER(Name("ConcatV2") \ 231 .Device(DEVICE_SYCL) \ 232 .TypeConstraint<type>("T") \ 233 .TypeConstraint<int32>("Tidx") \ 234 .HostMemory("axis"), \ 235 ConcatV2Op<SYCLDevice, type>) 236 237 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL); 238 239 REGISTER_KERNEL_BUILDER(Name("Concat") 240 .Device(DEVICE_SYCL) 241 .TypeConstraint<int32>("T") 242 .HostMemory("concat_dim") 243 .HostMemory("values") 244 .HostMemory("output"), 245 ConcatOp<CPUDevice, int32>); 246 REGISTER_KERNEL_BUILDER(Name("ConcatV2") 247 .Device(DEVICE_SYCL) 248 .TypeConstraint<int32>("T") 249 .TypeConstraint<int32>("Tidx") 250 .HostMemory("values") 251 .HostMemory("axis") 252 .HostMemory("output"), 253 ConcatV2Op<CPUDevice, int32>); 254 255 #undef REGISTER_SYCL 256 #endif // TENSORFLOW_USE_SYCL 257 258 class ConcatOffsetOp : public OpKernel { 259 public: 260 explicit ConcatOffsetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 261 262 void Compute(OpKernelContext* ctx) override { 263 const Tensor& concat_dim = ctx->input(0); 264 OP_REQUIRES( 265 ctx, IsLegacyScalar(concat_dim.shape()), 266 errors::InvalidArgument( 267 "Concat dim tensor should be a scalar integer, but got shape ", 268 concat_dim.shape().DebugString())); 269 for (int i = 1; i < ctx->num_inputs(); ++i) { 270 const Tensor& inp = ctx->input(i); 271 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(inp.shape()), 272 errors::InvalidArgument("input ", i, 273 " should be a vector, but got shape ", 274 inp.shape().DebugString())); 275 } 276 // Suppose a Concat() op needs to Concatenate N tensors, each of 277 // which has the same number of dimensions. Their shapes match 278 // except the concat dimension. 279 // 280 // E.g., say, we want to concatenate 3 tensors in the 2nd 281 // dimension, and their shapes are: 282 // 283 // [2, 2, 5, 7] 284 // [2, 3, 5, 7] 285 // [2, 4, 5, 7] 286 // 287 // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape 288 // [2,9,5,7]. We will compute the cumulative sum along the 2nd 289 // dimension to figure out each input's offset in the concatenated 290 // output: 291 // [0, 0, 0, 0] 292 // [0, 2, 0, 0] 293 // [0, 5, 0, 0] 294 const int32 N = ctx->num_inputs() - 1; 295 const Tensor& inp0 = ctx->input(1); 296 auto inp0_vec = inp0.vec<int32>(); 297 const int64 cdim = internal::SubtleMustCopy(concat_dim.scalar<int32>()()); 298 const int64 dims = inp0.NumElements(); 299 int32 axis = cdim < 0 ? cdim + dims : cdim; 300 OP_REQUIRES(ctx, FastBoundsCheck(axis, dims), 301 errors::InvalidArgument("Concat dim is out of range: ", cdim, 302 " vs. ", dims)); 303 int32 offset = 0; 304 for (int i = 0; i < N; ++i) { 305 const Tensor& inp = ctx->input(1 + i); 306 OP_REQUIRES( 307 ctx, dims == inp.NumElements(), 308 errors::InvalidArgument("input ", i, " should contain ", dims, 309 " elements, but got ", inp.NumElements())); 310 auto inp_vec = inp.vec<int32>(); 311 Tensor* out = nullptr; 312 OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out)); 313 auto out_vec = out->vec<int32>(); 314 for (int64 j = 0; j < dims; ++j) { 315 if (j == axis) { 316 out_vec(j) = offset; 317 offset += inp_vec(j); 318 } else { 319 OP_REQUIRES(ctx, (inp0_vec(j) == inp_vec(j)), 320 errors::InvalidArgument( 321 "All dimensions except ", axis, " must match. Input ", 322 i, " has shape [", inp.SummarizeValue(10), 323 "] and doesn't match input 0 with shape [", 324 inp0.SummarizeValue(10), "].")); 325 out_vec(j) = 0; 326 } 327 } 328 } 329 } 330 331 bool IsExpensive() override { return false; } 332 }; 333 334 REGISTER_KERNEL_BUILDER(Name("ConcatOffset").Device(DEVICE_CPU), 335 ConcatOffsetOp); 336 337 REGISTER_KERNEL_BUILDER(Name("ConcatOffset") 338 .Device(DEVICE_GPU) 339 .HostMemory("concat_dim") 340 .HostMemory("shape") 341 .HostMemory("offset"), 342 ConcatOffsetOp); 343 344 #ifdef TENSORFLOW_USE_SYCL 345 REGISTER_KERNEL_BUILDER(Name("ConcatOffset") 346 .Device(DEVICE_SYCL) 347 .HostMemory("concat_dim") 348 .HostMemory("shape") 349 .HostMemory("offset"), 350 ConcatOffsetOp); 351 #endif // TENSORFLOW_USE_SYCL 352 } // namespace tensorflow 353