/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include <limits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"

using mkldnn::concat;
using mkldnn::stream;
#endif

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

// List of TensorShape objects. Used in Concat/Split layers.
typedef std::vector<TensorShape> TensorShapeList;

enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };

// TODO(intelft) Check if we can reuse existing EigenConcatOp using Mutable
// reference inputs.
// --------------------------------------------------------------------------
// Eigen Concat Op
// --------------------------------------------------------------------------
template <typename Device, typename T, AxisArgumentName AxisArgName>
class EigenConcatBaseOp : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}

  // Although we add an overload of Compute below that accepts extra
  // parameters, we still need this empty Compute because the base class
  // declares Compute as a pure virtual function.
  void Compute(OpKernelContext* c) {}

#ifdef INTEL_MKL_ML

  void Compute(OpKernelContext* c, const std::vector<Tensor>& values) {
    const Tensor* concat_dim_tensor;
    const char* axis_attribute_name =
        AxisArgName == NAME_IS_AXIS
            ? "axis"
            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
    OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
                errors::InvalidArgument(
                    axis_attribute_name,
                    " tensor should be a scalar integer, but got shape ",
                    concat_dim_tensor->shape().DebugString()));
    const int32 concat_dim =
        internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
    // Instead of accessing values from context, we use input to Compute.
    const int N = values.size();
    const int input_dims = values[0].dims();
    const TensorShape& input_shape = values[0].shape();

    int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
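    // E.g., with input_dims = 4, a concat_dim of -1 normalizes to axis = 3
    // (the last dimension), matching tf.concat's negative-axis convention.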
    OP_REQUIRES(c,
                (0 <= axis && axis < input_dims) ||
                    (allow_legacy_scalars() && concat_dim == 0),
                errors::InvalidArgument(
                    "ConcatOp : Expected concatenating dimensions in the range "
                    "[",
                    -input_dims, ", ", input_dims, "), but got ", concat_dim));
    // Note that we reduce the concat of n-dimensional tensors into a two
    // dimensional concat. Assuming the dimensions of any input/output
    // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
    // the dimension indicated with size y0, we flatten it to {x, y}, where y =
    // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(N);
    int64 inputs_flat_dim0 = 1;
    for (int d = 0; d < axis; ++d) {
      inputs_flat_dim0 *= input_shape.dim_size(d);
    }
    int64 output_concat_dim = 0;
    const bool input_is_scalar = IsLegacyScalar(input_shape);
    for (int i = 0; i < N; ++i) {
      const auto in = values[i];
      const bool in_is_scalar = IsLegacyScalar(in.shape());
      OP_REQUIRES(
          c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
              "] = ", in.shape().DebugString()));
      for (int j = 0; j < input_dims; ++j) {
        if (j == axis) {
          continue;
        }
        OP_REQUIRES(
            c, in.dim_size(j) == input_shape.dim_size(j),
            errors::InvalidArgument(
                "ConcatOp : Dimensions of inputs should match: shape[0] = ",
                input_shape.DebugString(), " vs. shape[", i,
                "] = ", in.shape().DebugString()));
      }
      if (in.NumElements() > 0) {
        int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
      }
      // TODO(irving): Remove check once !allow_legacy_scalars().
      output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
    }

    TensorShape output_shape(input_shape);
    // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
    if (output_shape.dims() == 0) {
      output_shape.AddDim(output_concat_dim);
    } else {
      output_shape.set_dim(axis, output_concat_dim);
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() > 0) {
      int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

#else  // MKL_DNN

  void Compute(OpKernelContext* c, const std::vector<Tensor>& values,
               const TensorShapeList& input_shapes) {
    const Tensor* concat_dim_tensor;
    const char* axis_attribute_name =
        AxisArgName == NAME_IS_AXIS
            ? "axis"
            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
    OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
                errors::InvalidArgument(
                    axis_attribute_name,
                    " tensor should be a scalar integer, but got shape ",
                    concat_dim_tensor->shape().DebugString()));
    const int32 concat_dim =
        internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
    // Instead of accessing values from context, we use input to Compute.
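    // Note: unlike the variant above, this one takes the TF shapes as a
    // separate list; for inputs that arrived in Mkl layout, the Tensor
    // object's own shape describes the opaque Mkl buffer rather than the
    // logical TF shape.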
    const int N = values.size();
    const int input_dims = input_shapes[0].dims();
    const TensorShape& input_shape = input_shapes[0];

    int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
    OP_REQUIRES(c,
                (0 <= axis && axis < input_dims) ||
                    (allow_legacy_scalars() && concat_dim == 0),
                errors::InvalidArgument(
                    "ConcatOp : Expected concatenating dimensions in the range "
                    "[",
                    -input_dims, ", ", input_dims, "), but got ", concat_dim));
    // Note that we reduce the concat of n-dimensional tensors into a two
    // dimensional concat. Assuming the dimensions of any input/output
    // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
    // the dimension indicated with size y0, we flatten it to {x, y}, where y =
    // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(N);
    int64 inputs_flat_dim0 = 1;
    for (int d = 0; d < axis; ++d) {
      inputs_flat_dim0 *= input_shape.dim_size(d);
    }
    int64 output_concat_dim = 0;
    const bool input_is_scalar = IsLegacyScalar(input_shape);
    for (int i = 0; i < N; ++i) {
      const auto in = values[i];
      const bool in_is_scalar = IsLegacyScalar(input_shapes[i]);
      OP_REQUIRES(
          c,
          (input_shapes[i].dims() == input_dims) ||
              (input_is_scalar && in_is_scalar),
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
              "] = ", input_shapes[i].DebugString()));
      if (in.NumElements() > 0) {
        int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
      }
      output_concat_dim +=
          input_shapes[i].dims() > 0 ? input_shapes[i].dim_size(axis) : 1;
    }

    TensorShape output_shape(input_shape);
    if (output_shape.dims() == 0) {
      output_shape.AddDim(output_concat_dim);
    } else {
      output_shape.set_dim(axis, output_concat_dim);
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() > 0) {
      int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

#endif
};

#ifdef INTEL_MKL_ML

// --------------------------------------------------------------------------
// Mkl Concat Op
// --------------------------------------------------------------------------

template <typename Device, typename T, AxisArgumentName AxisArgName>
class MklConcatOp : public OpKernel {
 private:
  TensorFormat data_format_;
  EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;

 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit MklConcatOp(OpKernelConstruction* c)
      : OpKernel(c), eigen_concat_op_(c) {}

  void Compute(OpKernelContext* context) override {
    MklConcatOpContext mkl_context;

    // Get input tensors.
    OpInputList input_tensors;
    GetMklInputList(context, "values", &input_tensors);
    const int N = input_tensors.size();
    // Get MKL shapes.
    MklShapeList input_shapes(N);
    GetMklShapeList(context, "values", &input_shapes);
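    // Each input_shapes[i] records whether input_tensors[i] arrived in Mkl
    // layout and, if so, carries the layout metadata deserialized from the
    // op's paired metadata input.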
    // If this is Concat, then concat_dim is the 0th input.
    // If this is ConcatV2, then axis is the Nth input.
    const Tensor& concat_dim_tensor = AxisArgName == NAME_IS_CONCAT_DIM
                                          ? MklGetInput(context, 0)
                                          : MklGetInput(context, N);

    // Sanity checks.
    OP_REQUIRES(
        context, IsLegacyScalar(concat_dim_tensor.shape()),
        errors::InvalidArgument(
            "Concat dim tensor should be a scalar integer, but got shape ",
            concat_dim_tensor.shape().DebugString()));
    int32 concat_dim =
        internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());

    MklShape& inpshape0 = input_shapes[0];

    // Check that all tensors are in Mkl layout; if not, call the Eigen
    // version.
    bool invoke_eigen = false;
    bool is_concat_dim_channel = true;
    if (!AreAllMklTensors(input_shapes)) {
      invoke_eigen = true;
    }

    // Check that the total number of dimensions is 4; if not, call Eigen.
    if (!invoke_eigen) {
      for (auto& s : input_shapes) {
        if (s.GetDimension() != 4) {
          invoke_eigen = true;
          break;
        }
      }
    }

    // Check that concat_dim is the channel dimension; if not, call the
    // Eigen version.
    if (!invoke_eigen) {
      for (auto& s : input_shapes) {
        if (!s.IsMklChannelDim(concat_dim)) {
          invoke_eigen = true;
          is_concat_dim_channel = false;
          break;
        }
      }
    }

    if (invoke_eigen) {
      string msg = std::string("Invoking Eigen version of Concat. Reason: ") +
                   (!is_concat_dim_channel
                        ? std::string("Concat dimension is not channel")
                        : std::string("Not all tensors are in Mkl layout"));
      VLOG(1) << "_MklConcatOp: " << msg;
      CallEigenVersion(context, input_tensors, input_shapes);
      return;
    }

    // For the MKL format, the channel is dimension number 2, so if we are
    // concatenating over channel and _all_ inputs are in MKL format, we set
    // concat_dim to 2. Since we have reached this point, we know we are
    // concatenating over channel.
    concat_dim = MklDims::C;

    // One more sanity check: check that the ranks of all tensors match
    // and that their shapes match except for concat_dim.
    int i = 0;
    for (auto& s : input_shapes) {
      size_t exp_dims = inpshape0.GetDimension();
      OP_REQUIRES(context, s.GetDimension() == exp_dims,
                  errors::InvalidArgument(
                      "_MklConcatOp : Ranks of all input tensors should match:"
                      " input dimensions = ",
                      s.GetDimension(), " vs. expected rank = ", exp_dims));

      for (int d = 0; d < exp_dims; ++d) {
        if (d == concat_dim) {
          continue;
        }

        size_t exp_size = inpshape0.GetSizes()[d];
        OP_REQUIRES(
            context, exp_size == s.GetSizes()[d],
            errors::InvalidArgument("_MklConcatOp : Dimensions of inputs "
                                    "should match: shape[0][",
                                    d, "] = ", exp_size, " vs. shape[", i, "][",
                                    d, "] = ", s.GetSizes()[d]));
      }
      ++i;
    }
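    // For example, concatenating two 4-D Mkl inputs with channel sizes 3
    // and 5 (all other dimensions equal, as verified above) makes the
    // accumulation below produce output_concat_dim_size = 8.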
    // Use the input MKL layouts instead of creating new layouts.
    int64 output_concat_dim_size = 0;
    for (auto& s : input_shapes) {
      output_concat_dim_size +=
          s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1;
    }
    mkl_context.MklCreateInputLayouts(context, input_shapes);
    OP_REQUIRES_OK(context, context->status());

    CHECK_EQ(dnnConcatCreate_F32(&mkl_context.prim_concat, NULL, N,
                                 &mkl_context.lt_inputs[0]),
             E_SUCCESS);

    // Calculate output sizes and strides.
    TensorFormat data_format;
    if (inpshape0.IsTensorInNHWCFormat()) {
      data_format = FORMAT_NHWC;
    } else {
      OP_REQUIRES(
          context, inpshape0.IsTensorInNCHWFormat(),
          errors::InvalidArgument(
              "_MklConcat only supports all inputs in NCHW or NHWC format "));
      data_format = FORMAT_NCHW;
    }

    // Since all tensors are in Mkl layout, we copy sizes from the input
    // tensor.
    mkl_context.out_sizes[MklDims::W] = inpshape0.GetSizes()[MklDims::W];
    mkl_context.out_sizes[MklDims::H] = inpshape0.GetSizes()[MklDims::H];
    mkl_context.out_sizes[MklDims::C] = output_concat_dim_size;
    mkl_context.out_sizes[MklDims::N] = inpshape0.GetSizes()[MklDims::N];
    GetStridesFromSizes(data_format, mkl_context.out_strides,
                        mkl_context.out_sizes);

    // Set the output Mkl shape.
    int64 dim = 4;
    MklShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(true);
    mkl_output_mkl_shape.SetMklLayout(mkl_context.prim_concat, dnnResourceDst);
    mkl_output_mkl_shape.SetTfLayout(dim, mkl_context.out_sizes,
                                     mkl_context.out_strides);
    mkl_output_mkl_shape.SetTfDimOrder(dim, inpshape0.GetTfToMklDimMap());

    TensorShape mkl_output_tf_shape;
    mkl_output_tf_shape.AddDim(1);
    mkl_output_tf_shape.AddDim(
        dnnLayoutGetMemorySize_F32(
            static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
        sizeof(T));

    Tensor* output = nullptr;
    AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
                              mkl_output_mkl_shape);

    // Set the destination resource.
    mkl_context.concat_res[dnnResourceDst] =
        const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));

    mkl_context.mkl_tmp_tensors.resize(N);
    mkl_context.MklPrepareConcatInputs(context, input_tensors);
    OP_REQUIRES_OK(context, context->status());

    // Execute the concat primitive.
    CHECK_EQ(dnnExecute_F32(mkl_context.prim_concat, mkl_context.concat_res),
             E_SUCCESS);

    mkl_context.MklCleanup();
    OP_REQUIRES_OK(context, context->status());
  }

 private:
  typedef struct {
    TensorFormat data_format;
    size_t out_sizes[4];
    size_t out_strides[4];
    dnnPrimitive_t prim_concat;
    void* concat_res[dnnResourceNumber];
    std::vector<dnnLayout_t> lt_inputs;
    std::vector<Tensor> mkl_tmp_tensors;
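    // The fields above persist across the helper calls below: lt_inputs
    // holds layouts borrowed from the input MklShapes (hence only nulled,
    // never deleted, in MklCleanup), concat_res is the resource array fed
    // to dnnExecute_F32, and mkl_tmp_tensors keeps any converted input
    // buffers alive until the primitive has run.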
    // Create MKL dnnLayout_t objects for the tensors coming into the layer.
    // We only support the case where all input tensors are in Mkl layout.
    void MklCreateInputLayouts(OpKernelContext* context,
                               MklShapeList& input_shapes) {
      for (auto& is : input_shapes) {
        CHECK_EQ(is.IsMklTensor(), true);
        lt_inputs.push_back((dnnLayout_t)is.GetCurLayout());
      }
    }

    void MklPrepareConcatInputs(OpKernelContext* context,
                                OpInputList& input_tensors) {
      CHECK_EQ(lt_inputs.size(), mkl_tmp_tensors.size());

      for (int i = 0; i < lt_inputs.size(); ++i) {
        dnnPrimitive_t mkl_prim_convert_input;
        dnnLayout_t mkl_lt_internal_input;
        void* mkl_buf_convert_input = nullptr;

        CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
                     &mkl_lt_internal_input, prim_concat,
                     (dnnResourceType_t)(dnnResourceMultipleSrc + i)),
                 E_SUCCESS);

        if (!dnnLayoutCompare_F32(lt_inputs[i], mkl_lt_internal_input)) {
          CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
                                           lt_inputs[i], mkl_lt_internal_input),
                   E_SUCCESS);

          AllocTmpBuffer(context, &mkl_tmp_tensors[i], mkl_lt_internal_input,
                         &mkl_buf_convert_input);

          CHECK_EQ(dnnConversionExecute_F32(
                       mkl_prim_convert_input,
                       const_cast<void*>(static_cast<const void*>(
                           input_tensors[i].flat<T>().data())),
                       mkl_buf_convert_input),
                   E_SUCCESS);

          concat_res[dnnResourceMultipleSrc + i] = mkl_buf_convert_input;
          CHECK_EQ(dnnDelete_F32(mkl_prim_convert_input), E_SUCCESS);
        } else {
          concat_res[dnnResourceMultipleSrc + i] = const_cast<void*>(
              static_cast<const void*>(input_tensors[i].flat<T>().data()));
        }

        CHECK_EQ(dnnLayoutDelete_F32(mkl_lt_internal_input), E_SUCCESS);
      }
    }

    void MklCleanup() {
      for (auto& lt : lt_inputs) {
        lt = nullptr;
      }
      CHECK_EQ(dnnDelete_F32(prim_concat), E_SUCCESS);
    }
  } MklConcatOpContext;

  void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
                        const MklShapeList& input_shapes) {
    // Before calling the Eigen version, we need to convert Mkl tensors to
    // TF layout. First check that the number of input tensors and the
    // number of Mkl shapes match.
    CHECK_EQ(values.size(), input_shapes.size());

    std::vector<Tensor> converted_values;
    for (int i = 0; i < input_shapes.size(); i++) {
      if (input_shapes[i].IsMklTensor()) {
        // If the input tensor is in Mkl layout, convert it.
        Tensor tmp_tensor =
            ConvertMklToTF<T>(context, values[i], input_shapes[i]);
        converted_values.push_back(tmp_tensor);
      } else {
        // If the input tensor is already in TF layout, no conversion is
        // needed.
        converted_values.push_back(values[i]);
      }
    }

    // Call Eigen concat.
    eigen_concat_op_.Compute(context, converted_values);
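    // Every Mkl op also emits a serialized shape on a metadata output that
    // downstream Mkl ops use to interpret the paired data output. The Eigen
    // path produced a plain TF tensor, so a dummy (non-Mkl) shape is
    // serialized here.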
    // Set a dummy Mkl tensor as the output Mkl tensor for this op.
    MklShape mkl_tensor_mkl_shape;
    mkl_tensor_mkl_shape.SetMklTensor(false);
    mkl_tensor_mkl_shape.SetDimensions(4);
    mkl_tensor_mkl_shape.SetTfDimOrder(4);  // Dimensions
    Tensor* mkl_tensor = nullptr;
    TensorShape mkl_tensor_tf_shape;
    mkl_tensor_tf_shape.AddDim(
        SIZE_OF_MKL_SERIAL_DATA(mkl_tensor_mkl_shape.GetDimension()));
    int tf_output_index = 0;
    context->allocate_output(
        GetTensorMetaDataIndex(tf_output_index, context->num_outputs()),
        mkl_tensor_tf_shape, &mkl_tensor);
    mkl_tensor_mkl_shape.SerializeMklShape(
        mkl_tensor->flat<uint8>().data(),
        mkl_tensor->flat<uint8>().size() * sizeof(uint8));
  }

  // Overload taking the input shapes as a list of TensorShapes.
  void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
                        const TensorShapeList& input_shapes) {
    CHECK_EQ(values.size(), input_shapes.size());

    std::vector<Tensor> converted_values;
    for (int i = 0; i < input_shapes.size(); i++) {
      converted_values.push_back(values[i]);
    }

    // Call Eigen concat.
    eigen_concat_op_.Compute(context, converted_values);

    // Set a dummy Mkl tensor as the output Mkl tensor for this op.
    MklShape mkl_tensor_mkl_shape;
    mkl_tensor_mkl_shape.SetMklTensor(false);
    mkl_tensor_mkl_shape.SetDimensions(4);
    Tensor* mkl_tensor = nullptr;
    TensorShape mkl_tensor_tf_shape;
    mkl_tensor_tf_shape.AddDim(
        SIZE_OF_MKL_SERIAL_DATA(mkl_tensor_mkl_shape.GetDimension()));
    int tf_output_index = 0;
    context->allocate_output(
        GetTensorMetaDataIndex(tf_output_index, context->num_outputs()),
        mkl_tensor_tf_shape, &mkl_tensor);
    mkl_tensor_mkl_shape.SerializeMklShape(
        mkl_tensor->flat<uint8>().data(),
        mkl_tensor->flat<uint8>().size() * sizeof(uint8));
  }
};

#else

// --------------------------------------------------------------------------
// Mkl Concat Op
// --------------------------------------------------------------------------

template <typename Device, typename T, AxisArgumentName AxisArgName>
class MklConcatOp : public OpKernel {
 private:
  TensorFormat data_format_;
  EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;

 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit MklConcatOp(OpKernelConstruction* c)
      : OpKernel(c), eigen_concat_op_(c) {}

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      OpInputList input_tensors;
      GetMklInputList(context, "values", &input_tensors);
      const int N = input_tensors.size();

      // Get tensor shapes.
      std::vector<MklDnnShape> input_shapes(N);
      GetMklShapeList(context, "values", &input_shapes);

      const Tensor& concat_dim_tensor = (AxisArgName == NAME_IS_CONCAT_DIM)
                                            ? MklGetInput(context, 0)
                                            : MklGetInput(context, N);
      // Sanity checks.
      OP_REQUIRES(
          context, IsLegacyScalar(concat_dim_tensor.shape()),
          errors::InvalidArgument(
              "Concat dim tensor should be a scalar integer, but got shape ",
              concat_dim_tensor.shape().DebugString()));
      int32 concat_dim =
          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
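      // SubtleMustCopy above snapshots the scalar so that the validation
      // below cannot be invalidated by a concurrent writer to the input
      // buffer.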
      // Check that the ranks of all tensors match and that their shapes
      // match except for concat_dim.
      int i = 0;
      bool invoke_eigen = false;
      bool are_all_mkl_inputs = true, are_all_tf_inputs = true;
      const TensorShape expected_shape = input_shapes[0].IsMklTensor()
                                             ? input_shapes[0].GetTfShape()
                                             : input_tensors[0].shape();
      size_t expected_dims = expected_shape.dims();

      if (concat_dim < 0) concat_dim = expected_dims + concat_dim;

      for (auto& s : input_shapes) {
        if (s == expected_shape) {
          ++i;
          continue;
        }

        TensorShape s_shape =
            s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape();
        size_t s_dims = s_shape.dims();

        OP_REQUIRES(
            context, s_dims == expected_dims,
            errors::InvalidArgument(
                "_MklConcatOp : Ranks of all input tensors should match:"
                " input dimensions = ",
                s_dims, " vs. expected rank = ", expected_dims));

        for (int d = 0; d < expected_dims; ++d) {
          if (d == concat_dim) continue;

          size_t expected_size = expected_shape.dim_size(d);
          size_t s_size = s_shape.dim_size(d);
          OP_REQUIRES(
              context, expected_size == s_size,
              errors::InvalidArgument("_MklConcatOp : Dimensions of inputs "
                                      "should match: shape[0][",
                                      d, "] = ", expected_size, " vs. shape[",
                                      i, "][", d, "] = ", s_size));
        }

        if (s.IsMklTensor())
          are_all_tf_inputs = false;
        else
          are_all_mkl_inputs = false;

        if (s_dims != 4) invoke_eigen = true;
        ++i;
      }

      // Mixed input case: the inputs are not all in one format (TF or MKL).
      // We could potentially optimize this case by converting the TF inputs
      // to Mkl format and avoiding the Eigen version, but currently we fall
      // back to Eigen.
      if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;

      // Call the Eigen library.
      if (invoke_eigen) {
        TensorShapeList tf_input_shapes;
        i = 0;
        for (auto& s : input_shapes) {
          TensorShape s_shape =
              s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape();
          tf_input_shapes.push_back(s_shape);
          ++i;
        }
        CallEigenVersion(context, input_tensors, tf_input_shapes);
        return;
      }

      memory::dims dst_dims;
      if (are_all_mkl_inputs)
        dst_dims = TFShapeToMklDnnDims(input_shapes[0].GetTfShape());
      else
        // When all the inputs are in Tensorflow format, we don't know what
        // the input data format is. In that case, we just make the output
        // format the same as the input format.
        dst_dims = TFShapeToMklDnnDims(input_tensors[0].shape());

      std::vector<memory::primitive_desc> srcs_pd;
      std::vector<MklDnnData<T>> srcs(N, MklDnnData<T>(&cpu_engine));
      int64 dst_concat_dim_size = 0;
      for (int k = 0; k < N; k++) {
        bool is_mkl_tensor = input_shapes[k].IsMklTensor();
        memory::dims src_dims;

        // Same comment as for dst_dims applies to src_dims.
        src_dims = (is_mkl_tensor)
                       ? TFShapeToMklDnnDims(input_shapes[k].GetTfShape())
                       : TFShapeToMklDnnDims(input_tensors[k].shape());

        dst_concat_dim_size += src_dims[concat_dim];
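        // Choose the per-input memory descriptor: an Mkl-layout input
        // already carries one, while a plain TF input gets a dense
        // descriptor built from its dims (see the format note below).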
        auto src_md =
            is_mkl_tensor ? input_shapes[k].GetMklLayout()
                          // It does not matter what data format we use here
                          // (NHWC or NCHW). We just need to ensure that the
                          // output of Concat uses the same data format as
                          // the input.
                          : memory::desc(src_dims, MklDnnType<T>(),
                                         memory::format::nchw);

        srcs[k].SetUsrMem(src_md, &input_tensors[k]);
        auto src_mpd = srcs[k].GetUsrMemPrimDesc();
        srcs_pd.push_back(src_mpd);
      }
      dst_dims[concat_dim] = dst_concat_dim_size;

      MklDnnData<T> dst(&cpu_engine);
      memory::desc dst_md({}, memory::data_undef, memory::format_undef);
      memory::dims dst_dims_in_nchw;
      if (are_all_mkl_inputs) {
        // Since we are passing a specific format for the destination, we
        // need dst_dims to be in MklDnn order (NCHW).
        auto orig_tf_format = input_shapes[0].GetTfDataFormat();
        dst_dims_in_nchw = MklDnnDimsInNCHW(
            dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format));
        // We set the output in the same format as the input to avoid layout
        // conversions. Currently the dst format is simply set to the input
        // format; see if we can make this choice in a better way.
        dst_md = memory::desc(
            dst_dims_in_nchw, MklDnnType<T>(),
            (memory::format)input_shapes[0].GetMklLayout().data.format);
      } else {
        // Again, the format does not matter here. We just need to make it
        // the same as the input format.
        dst_md = memory::desc(dst_dims, MklDnnType<T>(), memory::format::nchw);
      }

      std::vector<primitive::at> inputs;
      for (int k = 0; k < input_tensors.size(); k++)
        inputs.push_back(srcs[k].GetOpMem());

      // If all inputs are in MKL format, then the meaning of concat_dim
      // needs to change. The value of concat_dim is tied to the input
      // Tensorflow data format (NHWC or NCHW), while MklDnn dimensions are
      // in NCHW order. So if the Tensorflow tensors are in NCHW order, the
      // concat_dim semantics are preserved, but if the input tensors are in
      // NHWC order, the semantics need to change. E.g., if we are
      // concatenating over Channel (dimension 3 for NHWC), then since the
      // MklDnn order is NCHW, concat_dim needs to be 1.
      if (are_all_mkl_inputs) concat_dim = input_shapes[0].TfDimIdx(concat_dim);

      auto concat_pd = concat::primitive_desc(dst_md, concat_dim, srcs_pd);

      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      Tensor* dst_tensor = nullptr;
      if (are_all_mkl_inputs) {
        dnn_shape_dst.SetMklTensor(true);
        auto dst_pd = concat_pd.dst_primitive_desc();
        dnn_shape_dst.SetMklLayout(&dst_pd);
        dnn_shape_dst.SetElemType(MklDnnType<T>());
        dnn_shape_dst.SetTfLayout(dst_dims.size(), dst_dims_in_nchw,
                                  input_shapes[0].GetTfDataFormat());
        tf_shape_dst.AddDim((dst_pd.get_size() / sizeof(T)));
      } else {
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = MklDnnDimsToTFShape(dst_dims);
      }
      AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
                                dnn_shape_dst);
      CHECK_NOTNULL(dst_tensor);
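      // If the output is an Mkl tensor, the concat primitive has chosen its
      // layout; rebind dst_md to that layout before wiring up the user
      // memory.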
      dst_md =
          dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout() : dst_md;
      dst.SetUsrMem(dst_md, dst_tensor);

      auto concat_op = concat(concat_pd, inputs, dst.GetOpMem());
      std::vector<primitive> net;
      net.push_back(concat_op);
      stream(stream::kind::eager).submit(net).wait();
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

  void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
                        const TensorShapeList& input_shapes) {
    CHECK_EQ(values.size(), input_shapes.size());

    std::vector<Tensor> converted_values;
    for (int i = 0; i < input_shapes.size(); i++)
      converted_values.push_back(values[i]);

    // Call Eigen concat.
    eigen_concat_op_.Compute(context, converted_values, input_shapes);

    // Set the output Mkl tensor for this op.
    MklDnnShape dnn_shape_output;
    dnn_shape_output.SetMklTensor(false);
    dnn_shape_output.SetDimensions(4);
    Tensor* output_tensor = nullptr;
    TensorShape tf_shape_output;
    tf_shape_output.AddDim(dnn_shape_output.GetSerializeBufferSize());
    context->allocate_output(GetTensorMetaDataIndex(0, context->num_outputs()),
                             tf_shape_output, &output_tensor);
    dnn_shape_output.SerializeMklDnnShape(
        output_tensor->flat<uint8>().data(),
        output_tensor->flat<uint8>().size() * sizeof(uint8));
  }
};

#endif

/* Use optimized concat for float type only. */
#define REGISTER_MKL_CPU(type)                                              \
  REGISTER_KERNEL_BUILDER(Name("_MklConcat")                                \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<type>("T")                    \
                              .HostMemory("concat_dim")                     \
                              .Label(mkl_op_registry::kMklOpLabel),         \
                          MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>) \
  REGISTER_KERNEL_BUILDER(Name("_MklConcatV2")                              \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<type>("T")                    \
                              .TypeConstraint<int32>("Tidx")                \
                              .HostMemory("axis")                           \
                              .Label(mkl_op_registry::kMklOpLabel),         \
                          MklConcatOp<CPUDevice, type, NAME_IS_AXIS>)

TF_CALL_float(REGISTER_MKL_CPU);

#undef REGISTER_MKL_CPU
}  // namespace tensorflow

#endif  // INTEL_MKL