/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include <algorithm>
#include <vector>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/tensor_format.h"

#include "tensorflow/core/kernels/mkl_tfconv_op.h"
#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"

using mkldnn::stream;
#endif

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

///////////////////////////////////////////////////////////
// Op kernel
// Checks and ensures that the 2 inputs are compatible for mkl binary ops.
// Here's the basic logic:
//
// if both inputs are in TF format:
//   pass the inputs through to the output
// else if both inputs are in mkl format:
//   if both have the same shape:
//     pass the inputs through to the output
//   else:
//     convert both to TF
// else if one is TF and one is MKL:
//   if broadcast is needed:
//     convert the MKL format input to TF format
//   else:
//     convert the TF format input to MKL format
///////////////////////////////////////////////////////////

#ifdef INTEL_MKL_ML
template <typename Device, typename T>
class MklInputConversionOp : public OpKernel {
 public:
  explicit MklInputConversionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
    has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
  }

 private:
  void Compute(OpKernelContext* context) override {
    // Check if input tensors are in MKL format.
    const Tensor& input_tensor_0 = MklGetInput(context, 0);
    MklShape input_shape_0;
    GetMklShape(context, 0, &input_shape_0);

    const Tensor& input_tensor_1 = MklGetInput(context, 1);
    MklShape input_shape_1;
    GetMklShape(context, 1, &input_shape_1);

    bool tf_shapes_are_same = MklCompareShapes(&context->input(0).shape(),
                                               &context->input(1).shape());

    VLOG(1) << "MklInputConversionOp: Input shapes are "
            << (tf_shapes_are_same ? "*same*" : "*different*") << ": "
            << context->input(0).shape().DebugString() << " and "
            << context->input(1).shape().DebugString();

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // if both inputs are in TF format, just copy input tensors to output.
    if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      VLOG(1) << "MklInputConversionOp: No conversion needed, "
              << "copying TF inputs to output";

      ForwardTfTensorInToOut(context, 0, 0);
      ForwardTfTensorInToOut(context, 1, 1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in MKL format
    if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      // If both have the same shape, pass them through
      if (tf_shapes_are_same) {
        VLOG(1) << "MklInputConversionOp: No conversion needed, "
                << "copying MKL inputs with identical shapes to output";

        ForwardMklTensorInToOut(context, 0, 0);
        ForwardMklTensorInToOut(context, 1, 1);
        return;
      }

      // Sanity check: TF shapes differed, so the MKL shapes must differ too.
      bool mkl_shapes_are_same =
          MklCompareShapes(&input_shape_0, &input_shape_1);
      if (mkl_shapes_are_same) {
        CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
                        "different but MKL shapes are same";
      }

      // Both have different shapes, so broadcast will be necessary.
      // Convert to TF and pass both tensors through (we can't do broadcast
      // with MKL tensors)
      VLOG(1) << "MklInputConversionOp: Broadcast needed, "
              << "converted MKL inputs to TF format";

      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_, 0);
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_, 1);
      SetDummyMklShapeOutput(context, 0);
      SetDummyMklShapeOutput(context, 1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // One input is MKL and one is TF. If no broadcast is needed, convert
    // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
    VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";

    const Tensor* mkl_tensor;
    const MklShape* mkl_shape;
    const Tensor* tf_tensor;
    MklShape* tf_mkl_shape;
    uint32 mkl_tensor_index;
    uint32 tf_tensor_index;
    if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_0;
      mkl_shape = &input_shape_0;
      mkl_tensor_index = 0;
      tf_tensor = &input_tensor_1;
      tf_mkl_shape = &input_shape_1;
      tf_tensor_index = 1;
    } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_1;
      mkl_shape = &input_shape_1;
      mkl_tensor_index = 1;
      tf_tensor = &input_tensor_0;
      tf_mkl_shape = &input_shape_0;
      tf_tensor_index = 0;
    } else {
      CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
                      "shapes for MKL "
                   << "element-wise op";
    }

    // Broadcast is needed if the shapes are not the same
    bool broadcast_needed;

    size_t in0_size = 1;
    for (size_t i = 0; i < mkl_shape->GetDimension(); ++i)
      in0_size *= mkl_shape->tf_dim_size(i);

    size_t in1_size = 1;
    // NOTE: TensorShape::dims() returns int; use a signed index to avoid
    // a signed/unsigned comparison.
    for (int i = 0; i < tf_tensor->shape().dims(); ++i)
      in1_size *= tf_tensor->shape().dim_size(i);

    broadcast_needed = (in0_size != in1_size);

    if (!broadcast_needed) {
      // Both shapes are same, convert the TF input to MKL
      VLOG(1) << "MklInputConversionOp: No broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
              << " to MKL format";

      // Create MklShape
      Tensor* tensor_out;
      MklShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(true);
      mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
                                       mkl_shape->GetSizes(),
                                       mkl_shape->GetStrides());
      mkl_output_mkl_shape.SetTfDimOrder(mkl_shape->GetDimension());

      // ** Temporarily borrow the layout from the MKL input **
      mkl_output_mkl_shape.SetMklLayout(mkl_shape->GetCurLayout());

      // Create output tensor
      AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
                                mkl_tensor->shape(), mkl_output_mkl_shape);

      // Since the shapes are the same, use information from the other tensor
      tf_mkl_shape->SetTfLayout(mkl_shape->GetDimension(),
                                mkl_shape->GetSizes(), mkl_shape->GetStrides());
      // Convert the data format
      tf_mkl_shape->GetConvertedFlatData(
          mkl_shape->GetCurLayout(),
          const_cast<T*>(tf_tensor->flat<T>().data()),
          const_cast<T*>(tensor_out->flat<T>().data()));

      // ** Release the borrowed layout to avoid double deletion
      //    in the destructor call **
      mkl_output_mkl_shape.SetMklLayout(nullptr);

      // -- The tensor in MKL format passes through --
      ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
    } else {
      // Broadcast is needed, so convert the MKL input to TF
      VLOG(1) << "MklInputConversionOp: Broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
              << " to TF format";
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           mkl_tensor_index);
      SetDummyMklShapeOutput(context, mkl_tensor_index);

      // The tensor in TF format passes through
      ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
    }

    VLOG(1) << "MklInputConversionOp: Shapes (output): "
            << context->mutable_output(0)->shape().DebugString() << " and "
            << context->mutable_output(1)->shape().DebugString();

    VLOG(1) << "MklInputConversion completed successfully.";
  }

 private:
  /// Data format of the operation
  string data_format_str;

  /// Data type of the operation
  DataType op_data_type;

  /// CPUIDInfo
  bool has_avx512f_ = false;
};

#else

template <typename Device, typename T>
class MklInputConversionOp : public OpKernel {
 public:
  explicit MklInputConversionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
    has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
  }

 private:
  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor_0 = MklGetInput(context, 0);
    MklDnnShape input_shape_0;
    GetMklShape(context, 0, &input_shape_0);

    const Tensor& input_tensor_1 = MklGetInput(context, 1);
    MklDnnShape input_shape_1;
    GetMklShape(context, 1, &input_shape_1);

    bool tf_shapes_are_same =
        context->input(0).shape() == context->input(1).shape();

    VLOG(1) << "MklInputConversionOp: Input shapes are "
            << (tf_shapes_are_same ? "*same*" : "*different*") << ": "
            << context->input(0).shape().DebugString() << " and "
            << context->input(1).shape().DebugString();

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // if both inputs are in TF format, just copy input tensors to output.
    if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      VLOG(1) << "MklInputConversionOp: No conversion needed, "
              << "copying TF inputs to output";

      ForwardTfTensorInToOut(context, 0, 0);
      ForwardTfTensorInToOut(context, 1, 1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in MKL format
    if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      if (tf_shapes_are_same) {
        auto input0_md = input_shape_0.GetMklLayout();
        auto input1_md = input_shape_1.GetMklLayout();

        // If both have the same shape and same format, pass them through
        if (input0_md.data.format == input1_md.data.format) {
          VLOG(1) << "MklInputConversionOp: No conversion needed, "
                  << "copying MKL inputs with identical shapes to output";

          ForwardMklTensorInToOut(context, 0, 0);
          ForwardMklTensorInToOut(context, 1, 1);
          return;
        } else {
          VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
                     "different, "
                  << "need to convert to same format";

          // Convert input0, and keep input1 unchanged
          // Create MklDnnShape for output mkl tensor based on input0
          Tensor* tensor_out;
          MklDnnShape mkl_output_mkl_shape;
          mkl_output_mkl_shape.SetMklTensor(true);
          mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
          mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(),
                                           input_shape_0.GetSizesAsMklDnnDims(),
                                           input_shape_0.GetTfDataFormat());

          // Get MKL layout from input1 as destination layout
          mkl_output_mkl_shape.SetMklLayout(&input1_md);

          // Create output Mkl tensor for index 0
          AllocateOutputSetMklShape(context, 0, &tensor_out,
                                    input_tensor_0.shape(),
                                    mkl_output_mkl_shape);

          // Create MklDnnData object for input0 tensor
          auto cpu_engine = engine(engine::cpu, 0);
          MklDnnData<T> input(&cpu_engine);
          input.SetUsrMem(input0_md, &input_tensor_0);

          // Create reorder from input0's layout to input1's layout
          std::vector<primitive> net;
          CHECK_EQ(input.CheckReorderToOpMem(
                       memory::primitive_desc(input1_md, cpu_engine),
                       tensor_out, &net),
                   true);
          stream(stream::kind::eager).submit(net).wait();

          // Input1 will be passed through
          ForwardMklTensorInToOut(context, 1, 1);
          return;
        }
      }

      // Sanity check: TF shapes differed, so the MKL shapes must differ too.
      bool mkl_shapes_are_same = input_shape_0 == input_shape_1;
      if (mkl_shapes_are_same) {
        CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
                        "different but MKL shapes are same";
      }

      // Both have different shapes, so broadcast will be necessary.
      // Convert to TF and pass both tensors through (we can't do broadcast
      // with MKL tensors)
      VLOG(1) << "MklInputConversionOp: Broadcast needed, "
              << "converted MKL inputs to TF format";

      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_, 0);
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_, 1);
      SetDummyMklShapeOutput(context, 0);
      SetDummyMklShapeOutput(context, 1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // One input is MKL and one is TF. If no broadcast is needed, convert
    // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
    VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";

    const Tensor* mkl_tensor;
    const MklDnnShape* mkl_shape;
    const Tensor* tf_tensor;
    MklDnnShape* tf_mkl_shape;
    uint32 mkl_tensor_index;
    uint32 tf_tensor_index;
    if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_0;
      mkl_shape = &input_shape_0;
      mkl_tensor_index = 0;
      tf_tensor = &input_tensor_1;
      tf_mkl_shape = &input_shape_1;
      tf_tensor_index = 1;
    } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_1;
      mkl_shape = &input_shape_1;
      mkl_tensor_index = 1;
      tf_tensor = &input_tensor_0;
      tf_mkl_shape = &input_shape_0;
      tf_tensor_index = 0;
    } else {
      CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
                      "shapes for MKL "
                   << "element-wise op";
    }

    // Broadcast is needed if the shapes are not the same
    bool broadcast_needed;

    size_t in0_size = 1;
    for (size_t i = 0; i < mkl_shape->GetDimension(); ++i)
      in0_size *= mkl_shape->TfDimSize(i);

    size_t in1_size = 1;
    // NOTE: TensorShape::dims() returns int; use a signed index to avoid
    // a signed/unsigned comparison.
    for (int i = 0; i < tf_tensor->shape().dims(); ++i)
      in1_size *= tf_tensor->shape().dim_size(i);

    broadcast_needed = (in0_size != in1_size);

    if (!broadcast_needed) {
      // Both shapes are same, convert the TF input to MKL
      VLOG(1) << "MklInputConversionOp: No broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
              << " to MKL format";

      // Create MklDnnShape for output Mkl tensor.
      Tensor* tensor_out;
      MklDnnShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(true);
      mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
      mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
                                       mkl_shape->GetSizesAsMklDnnDims(),
                                       mkl_shape->GetTfDataFormat());
      // ** Temporarily borrow the layout from the MKL input **
      auto output_mkl_md = mkl_shape->GetMklLayout();
      mkl_output_mkl_shape.SetMklLayout(&output_mkl_md);

      // Create output Mkl tensor
      AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
                                mkl_tensor->shape(), mkl_output_mkl_shape);

      // Create MklDnnData object for input tensor. Input tensor is in
      // Tensorflow layout.
      auto cpu_engine = engine(engine::cpu, 0);
      MklDnnData<T> tf_input(&cpu_engine);
      auto input_tf_md = mkl_output_mkl_shape.GetTfLayout();
      tf_input.SetUsrMem(input_tf_md, tf_tensor);

      // Create reorder between tensorflow layout and Mkl layout.
      std::vector<primitive> net;
      CHECK_EQ(tf_input.CheckReorderToOpMem(
                   memory::primitive_desc(output_mkl_md, cpu_engine),
                   tensor_out, &net),
               true);
      stream(stream::kind::eager).submit(net).wait();

      // -- The tensor in MKL format passes through --
      ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
    } else {
      // Broadcast is needed, so convert the MKL input to TF
      VLOG(1) << "MklInputConversionOp: Broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
              << " to TF format";
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           mkl_tensor_index);
      SetDummyMklShapeOutput(context, mkl_tensor_index);

      // The tensor in TF format passes through
      ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
    }

    VLOG(1) << "MklInputConversionOp: Shapes (output): "
            << context->mutable_output(0)->shape().DebugString() << " and "
            << context->mutable_output(1)->shape().DebugString();

    VLOG(1) << "MklInputConversion completed successfully.";
  }

 private:
  /// Data format of the operation
  string data_format_str;

  /// Data type of the operation
  DataType op_data_type;

  /// CPUIDInfo
  bool has_avx512f_ = false;
};

#endif

///////////////////////////////////////////////////////////
// Register kernel
///////////////////////////////////////////////////////////

#define REGISTER_CPU(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("_MklInputConversion")          \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<T>("T")          \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklInputConversionOp<CPUDevice, T>);

// TODO(nhasabni): We cannot support all number types since MklDnn does
// not support types.
// TF_CALL_NUMBER_TYPES(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
#undef REGISTER_CPU
}  // namespace tensorflow
#endif  // INTEL_MKL