/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"

#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"

using mkldnn::algorithm;
using mkldnn::eltwise_elu;
using mkldnn::eltwise_relu;
using mkldnn::eltwise_tanh;
using mkldnn::prop_kind;
using mkldnn::relu_backward;
using mkldnn::relu_forward;
using mkldnn::stream;
#endif  // !INTEL_MKL_ML

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

struct MklReluHelpers {
  static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
                                     const Tensor& a) {
    OP_REQUIRES(context, a.IsSameSize(g),
                errors::InvalidArgument("g and a must be the same size"));
  }
  static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
                               const Tensor& a) {
    ValidateSameSizeHelper(context, g, a);
    return context->status().ok();
  }
};

#ifdef INTEL_MKL_ML

template <typename Device, typename T>
class MklReluOp : public OpKernel {
 public:
  ~MklReluOp() {}

  explicit MklReluOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    MklReluOpContext mkl_context;

    const Tensor& input = MklGetInput(context, 0);
    GetMklShape(context, 0, &mkl_context.input_shape);
    void* user_i = static_cast<void*>(const_cast<T*>(input.flat<T>().data()));
    bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();

    if (!input_in_mkl_format && !input.dims()) {  // handle the case of a scalar
      const TensorShape& o_shape = input.shape();
      Tensor* out_tensor = nullptr;
      mkl_context.output_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &out_tensor, o_shape,
                                mkl_context.output_shape);
      void* out_o = static_cast<void*>(out_tensor->flat<T>().data());
      (static_cast<T*>(out_o))[0] =
          std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
      return;
    }

    // Generate size, stride for input if input is in MKL format.
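    // In the TensorFlow-format branch below, sizes are copied in reverse
    // order (innermost dimension first) and strides are accumulated from
    // them; both arrays are later handed to dnnLayoutCreate_F32 in
    // MklCreateInputLayouts().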
    if (input_in_mkl_format) {
      mkl_context.in_dims = mkl_context.input_shape.GetDimension();
      mkl_context.in_sizes = new size_t[mkl_context.in_dims];
      mkl_context.in_strides = new size_t[mkl_context.in_dims];
      for (int i = 0; i < mkl_context.in_dims; i++) {
        mkl_context.in_sizes[i] = mkl_context.input_shape.GetSizes()[i];
        mkl_context.in_strides[i] = mkl_context.input_shape.GetStrides()[i];
      }
    } else {
      mkl_context.in_dims = input.dims();
      mkl_context.in_sizes = new size_t[mkl_context.in_dims];
      mkl_context.in_strides = new size_t[mkl_context.in_dims];
      for (int i = 0; i < mkl_context.in_dims; i++) {
        mkl_context.in_sizes[i] = input.dim_size((mkl_context.in_dims - 1) - i);
      }
      mkl_context.in_strides[0] = 1;
      for (int i = 1; i < mkl_context.in_dims; i++) {
        mkl_context.in_strides[i] =
            mkl_context.in_strides[i - 1] * mkl_context.in_sizes[i - 1];
      }
    }

    float negative_slope = 0.0;
    mkl_context.MklCreateInputLayouts(context);
    CHECK_EQ(dnnReLUCreateForward_F32(&mkl_context.prim_relu_fwd, NULL,
                                      mkl_context.lt_input, negative_slope),
             E_SUCCESS);

    Tensor* output = nullptr;

    if (input_in_mkl_format) {
      TensorShape tf_shape;
      mkl_context.output_shape.SetMklTensor(true);
      mkl_context.output_shape.SetMklLayout(mkl_context.prim_relu_fwd,
                                            dnnResourceDst);
      mkl_context.output_shape.SetTfLayout(
          mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.input_shape.GetTfToMklDimMap());
      tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                          mkl_context.output_shape.GetMklLayout())) /
                      sizeof(T));
      AllocateOutputSetMklShape(context, 0, &output, tf_shape,
                                mkl_context.output_shape);
    } else {
      const TensorShape& o_shape = input.shape();
      mkl_context.output_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &output, o_shape,
                                mkl_context.output_shape);
    }

    void* user_o = static_cast<void*>(const_cast<T*>(output->flat<T>().data()));

    mkl_context.relu_res[dnnResourceDst] = user_o;
    mkl_context.relu_res[dnnResourceSrc] = user_i;
    CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_fwd, mkl_context.relu_res),
             E_SUCCESS);
    mkl_context.MklCleanup();
  }

 private:
  typedef struct {
    int in_dims;
    size_t* in_sizes;
    size_t* in_strides;
    MklShape input_shape, output_shape;
    dnnPrimitive_t prim_relu_fwd = nullptr;
    void* relu_res[dnnResourceNumber];
    dnnLayout_t lt_input = nullptr;

    void MklCleanup() {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      if (!input_in_mkl_format) {
        dnnLayoutDelete_F32(lt_input);
        // Sizes/strides were allocated with new[], so release with delete[].
        delete[] in_sizes;
        delete[] in_strides;
      }
      dnnDelete_F32(prim_relu_fwd);
    }

    void MklCreateInputLayouts(OpKernelContext* context) {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      if (!input_in_mkl_format) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
      }
    }
  } MklReluOpContext;
};

template <typename Device, typename T>
class MklReluGradOp : public OpKernel {
 public:
  ~MklReluGradOp() {}

  explicit MklReluGradOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override;

 private:
  typedef struct {
    int in_dims;
    size_t* in_sizes;
    size_t* in_strides;
    MklShape input_shape, grad_shape, output_shape;
    void* relu_res[dnnResourceNumber];
    dnnPrimitive_t prim_relu_bwd;
    dnnLayout_t lt_input, lt_grad;

    void MklPrepareReluGradInputs(OpKernelContext* context,
                                  Tensor* mkl_tmp_input_buf_tensor) {
      const Tensor& g = MklGetInput(context, 0);
      const Tensor& a = MklGetInput(context, 1);
      void* buf_input = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
      void* mkl_buffer_convert = nullptr;

      dnnPrimitive_t cv_input_to_grad = nullptr;

      // If input and grad are not in the same layout,
      // do a conversion between them.
      if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad,
                       &mkl_buffer_convert);
        CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
                 E_SUCCESS);
        CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input,
                                          mkl_buffer_convert),
                 E_SUCCESS);
        relu_res[dnnResourceSrc] = mkl_buffer_convert;
        dnnDelete_F32(cv_input_to_grad);
      } else {
        relu_res[dnnResourceSrc] = buf_input;
      }

      void* buf_grad = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
      relu_res[dnnResourceDiffDst] = buf_grad;
    }

    void MklCreateInputLayouts(OpKernelContext* context) {
      bool grad_is_mkl = grad_shape.IsMklTensor();
      bool input_is_mkl = input_shape.IsMklTensor();
      if (!input_is_mkl) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
      }

      if (!grad_is_mkl) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_grad, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_grad = static_cast<dnnLayout_t>(grad_shape.GetCurLayout());
      }
    }

    void MklCleanup() {
      bool grad_is_mkl = grad_shape.IsMklTensor();
      bool input_is_mkl = input_shape.IsMklTensor();
      dnnDelete_F32(prim_relu_bwd);
      if (!input_is_mkl) {
        dnnLayoutDelete_F32(lt_input);
        // Sizes/strides were allocated with new[], so release with delete[].
        delete[] in_sizes;
        delete[] in_strides;
      }
      if (!grad_is_mkl) {
        dnnLayoutDelete_F32(lt_grad);
      }
    }
  } MklReluGradOpContext;
};

template <typename Device, typename T>
void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
  MklReluGradOpContext mkl_context;
  const Tensor& g = MklGetInput(context, 0);
  const Tensor& a = MklGetInput(context, 1);

  void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
  void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));

  GetMklShape(context, 0, &mkl_context.grad_shape);
  GetMklShape(context, 1, &mkl_context.input_shape);

  bool grad_is_mkl = mkl_context.grad_shape.IsMklTensor();
  bool input_is_mkl = mkl_context.input_shape.IsMklTensor();
  if (!input_is_mkl && !grad_is_mkl &&
      !MklReluHelpers::ValidateSameSize(context, g, a))
    return;
  Tensor* output = nullptr;

  if (!input_is_mkl && !grad_is_mkl && !a.dims()) {
    // Handle the scalar case.
    const TensorShape& g_shape = g.shape();
    mkl_context.output_shape.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 0, &output, g_shape,
                              mkl_context.output_shape);

    void* out_o = static_cast<void*>(output->flat<T>().data());
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * ((static_cast<T*>(user_i))[0] > 0);
    return;
  }

  // Generate size and stride for the input if input/grad is in MKL format.
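  // When either tensor is already in MKL format, its MklShape provides the
  // sizes and strides; otherwise they are rebuilt from g's TensorFlow shape
  // in innermost-first order, mirroring the forward op.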
  if (grad_is_mkl || input_is_mkl) {
    const MklShape* tmp_mkl_shape =
        (grad_is_mkl) ? &mkl_context.grad_shape : &mkl_context.input_shape;

    mkl_context.in_dims = tmp_mkl_shape->GetDimension();
    mkl_context.in_strides = new size_t[mkl_context.in_dims];
    mkl_context.in_sizes = new size_t[mkl_context.in_dims];
    for (int i = 0; i < mkl_context.in_dims; i++) {
      mkl_context.in_sizes[i] = tmp_mkl_shape->GetSizes()[i];
      mkl_context.in_strides[i] = tmp_mkl_shape->GetStrides()[i];
    }
  } else {
    mkl_context.in_dims = g.dims();
    mkl_context.in_strides = new size_t[mkl_context.in_dims];
    mkl_context.in_sizes = new size_t[mkl_context.in_dims];

    for (int i = 0; i < mkl_context.in_dims; i++) {
      mkl_context.in_sizes[i] = g.dim_size((mkl_context.in_dims - 1) - i);
    }
    mkl_context.in_strides[0] = 1;
    for (int i = 1; i < mkl_context.in_dims; i++) {
      mkl_context.in_strides[i] =
          mkl_context.in_strides[i - 1] * mkl_context.in_sizes[i - 1];
    }
  }

  mkl_context.MklCreateInputLayouts(context);
  float negative_slope = 0.0;
  CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL,
                                     mkl_context.lt_grad, mkl_context.lt_grad,
                                     negative_slope),
           E_SUCCESS);
  Tensor mkl_tmp_input_buf_tensor;
  mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_input_buf_tensor);

  if (input_is_mkl ||
      grad_is_mkl) {  // If grad or input is in MKL format, keep output in MKL.
    TensorShape tf_shape;
    mkl_context.output_shape.SetMklTensor(true);
    mkl_context.output_shape.SetMklLayout(mkl_context.prim_relu_bwd,
                                          dnnResourceDiffSrc);
    mkl_context.output_shape.SetTfLayout(
        mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
    // If input_is_mkl or grad_is_mkl, copy strides and sizes from the Mkl
    // shape of the tensor that is in MKL layout.
    if (grad_is_mkl == true) {
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.grad_shape.GetTfToMklDimMap());
    } else {
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.input_shape.GetTfToMklDimMap());
    }

    tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                        mkl_context.output_shape.GetMklLayout())) /
                    sizeof(T));
    AllocateOutputSetMklShape(context, 0, &output, tf_shape,
                              mkl_context.output_shape);
  } else {
    const TensorShape& o_shape = g.shape();
    mkl_context.output_shape.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 0, &output, o_shape,
                              mkl_context.output_shape);
  }

  mkl_context.relu_res[dnnResourceDiffSrc] =
      static_cast<void*>(output->flat<T>().data());

  CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_bwd, mkl_context.relu_res),
           E_SUCCESS);
  mkl_context.MklCleanup();
}

#else  // INTEL_MKL_ML

template <typename Device, typename T, algorithm alg_kind>
class MklReluOpBase : public OpKernel {
 public:
  ~MklReluOpBase() {}

  explicit MklReluOpBase(OpKernelConstruction* context) : OpKernel(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) = 0;

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      const size_t src_index = 0;  // index of src input tensor
      const size_t dst_index = 0;  // index of dst output tensor
      const Tensor& src_tensor = MklGetInput(context, src_index);
      MklDnnShape dnn_shape_src;
      GetMklShape(context, src_index, &dnn_shape_src);

      Tensor* dst_tensor = nullptr;
      if (src_tensor.dims() == 0) {
        Compute_Scalar(context);
        return;
      }

      // Create relu primitive.
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> dst(&cpu_engine);

      // Set DNN primitive - src
      memory::desc src_md({}, memory::data_undef, memory::format_undef);
      if (dnn_shape_src.IsMklTensor()) {
        src_md = dnn_shape_src.GetMklLayout();
      } else {
        auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
        auto src_strides = CalculateTFStrides(src_dims);
        // Create blocked memory descriptor
        src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
      }
      src.SetUsrMem(src_md, &src_tensor);

      T alpha = 0, beta = 0;
      std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
      auto relu_fwd_desc = relu_forward::desc(
          prop_kind::forward_training,
          // Operator memory descriptor is same as user memory descriptor.
          alg_kind, src.GetUsrMemDesc(), alpha, beta);
      relu_fwd_pd.reset(
          new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine));

      // Allocate dst tensor.
      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      if (dnn_shape_src.IsMklTensor()) {
        dnn_shape_dst.SetMklTensor(true);
        auto dst_pd = relu_fwd_pd->dst_primitive_desc();
        dnn_shape_dst.SetMklLayout(&dst_pd);
        dnn_shape_dst.SetElemType(MklDnnType<T>());
        dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
                                  dnn_shape_src.GetSizesAsMklDnnDims(),
                                  dnn_shape_src.GetTfDataFormat());
        tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
      } else {
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = src_tensor.shape();
      }
      AllocateOutputSetMklShape(context, dst_index, &dst_tensor, tf_shape_dst,
                                dnn_shape_dst);

      // Destination memory descriptor is same as source memory descriptor.
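      // Relu/Elu/Tanh are element-wise, so the output keeps the shape and
      // layout of the input and no reorder of the result is required.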
      auto dst_md = src_md;
      dst.SetUsrMem(dst_md, dst_tensor);

      // Execute net.
      std::vector<primitive> net;
      auto relu_fwd =
          relu_forward(*relu_fwd_pd, src.GetOpMem(), dst.GetOpMem());
      net.push_back(relu_fwd);
      stream(stream::kind::eager).submit(net).wait();
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }
};

template <typename Device, typename T, algorithm alg_kind>
class MklReluGradOpBase : public OpKernel {
 public:
  ~MklReluGradOpBase() {}

  explicit MklReluGradOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) = 0;

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> diff_dst(&cpu_engine);
      MklDnnData<T> diff_src(&cpu_engine);

      const size_t diff_dst_index = 0;  // index of diff_dst input tensor
      const size_t src_index = 1;       // index of src input tensor
      const size_t diff_src_index = 0;  // index of diff_src output tensor

      const Tensor& src_tensor = MklGetInput(context, src_index);
      const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
      Tensor* diff_src_tensor = nullptr;

      MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
      GetMklShape(context, src_index, &dnn_shape_src);
      GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

      int src_dims_size = src_tensor.dims();
      if (src_dims_size == 0) {
        Compute_Scalar(context);
        return;
      }

      // Set DNN primitives for src & diff_dst.
      memory::desc src_md({}, memory::data_undef, memory::format_undef);
      memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);

      // To create the eltwise backward primitive, we need all inputs in the
      // same format. In the mixed-input case - where one input is in
      // TensorFlow format and the other is in MKL format - we therefore bring
      // both inputs to a common format before constructing the primitive. For
      // performance reasons we choose the MKL format as the common format in
      // that case and insert a reorder for the input that is in TensorFlow
      // format. If both inputs are in MKL format or both are in TensorFlow
      // format, no reorder is needed.
      if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
        // If both inputs are in TensorFlow format, create a blocked memory
        // descriptor.
        auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
        auto src_strides = CalculateTFStrides(src_dims);
        src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
        diff_dst_md = src_md;
      } else if (dnn_shape_src.IsMklTensor() &&
                 !dnn_shape_diff_dst.IsMklTensor()) {
        // If one input is in MKL format and the other is in TensorFlow format,
        // create descriptors describing the actual case. For the input in MKL
        // format, we just get the MKL layout from MklDnnShape. For the input
        // in TensorFlow format, we create a memory descriptor using the data
        // format.
        src_md = dnn_shape_src.GetMklLayout();

        memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
        auto src_tf_data_format =
            MklDnnDataFormatToTFDataFormat(src_mkl_data_format);
        auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
                                                       src_tf_data_format);
        diff_dst_md =
            memory::desc(diff_dst_dims, MklDnnType<T>(), src_mkl_data_format);
      } else if (!dnn_shape_src.IsMklTensor() &&
                 dnn_shape_diff_dst.IsMklTensor()) {
        // Same comment as above.
        diff_dst_md = dnn_shape_diff_dst.GetMklLayout();

        memory::format diff_dst_mkl_data_format =
            dnn_shape_diff_dst.GetTfDataFormat();
        auto diff_dst_tf_data_format =
            MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
        auto src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
                                                  diff_dst_tf_data_format);
        src_md =
            memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
      } else {
        // If both inputs are in MKL format, use the MKL layouts of the input
        // tensors.
        src_md = dnn_shape_src.GetMklLayout();
        diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
      }

      src.SetUsrMem(src_md, &src_tensor);
      diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);

      // As per the comment above, we tell MKL-DNN that both inputs are in the
      // same format. So we pick a common memory descriptor in MKL format if
      // either input is in MKL format; this descriptor is used for both
      // inputs.
      memory::desc common_md({}, memory::data_undef, memory::format_undef);
      if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
        common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
      } else {
        // Since both inputs are in TensorFlow format and have the same shape,
        // we can get the memory descriptor from either input.
        common_md = src_md;
      }

      T alpha = 0, beta = 0;
      std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
      auto relu_fwd_desc = relu_forward::desc(prop_kind::forward_training,
                                              alg_kind, src_md, alpha, beta);
      relu_fwd_pd.reset(
          new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine));
      auto relu_bwd_desc =
          relu_backward::desc(alg_kind, common_md, common_md, alpha, beta);
      auto relu_bwd_pd = relu_backward::primitive_desc(
          relu_bwd_desc, cpu_engine, *relu_fwd_pd);

      // Allocate diff_src tensor.
      MklDnnShape dnn_shape_diff_src;
      TensorShape tf_shape_diff_src;
      if (dnn_shape_src.IsMklTensor()) {
        dnn_shape_diff_src.SetMklTensor(true);
        auto diff_src_pd = relu_bwd_pd.diff_src_primitive_desc();
        dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
        dnn_shape_diff_src.SetElemType(MklDnnType<T>());
        dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
                                       dnn_shape_src.GetSizesAsMklDnnDims(),
                                       dnn_shape_src.GetTfDataFormat());
        tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      } else {
        dnn_shape_diff_src.SetMklTensor(false);
        tf_shape_diff_src = src_tensor.shape();
      }
      AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                                tf_shape_diff_src, dnn_shape_diff_src);

      // The diff_src memory descriptor is the same as the memory descriptor
      // for both inputs.
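      // Any reorder needed to bring src/diff_dst into this common layout is
      // added in PrepareAndExecuteNet() below; diff_src itself is allocated
      // directly in that layout.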
      diff_src.SetUsrMem(common_md, diff_src_tensor);

      PrepareAndExecuteNet(relu_bwd_pd, &src, &diff_src, &diff_dst);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

  void PrepareAndExecuteNet(const relu_backward::primitive_desc& relu_prim_desc,
                            MklDnnData<T>* src, MklDnnData<T>* diff_src,
                            MklDnnData<T>* diff_dst) {
    std::vector<primitive> net;

    // Check if we need to reorder original input tensors into common_md layout
    // that we set for primitive creation. diff_src_primitive_desc is same as
    // common_md.
    src->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(), &net);
    diff_dst->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(),
                                  &net);

    net.push_back(relu_backward(relu_prim_desc, src->GetOpMem(),
                                diff_dst->GetOpMem(), diff_src->GetOpMem()));
    stream(stream::kind::eager).submit(net).wait();
  }
};

template <typename Device, typename T>
class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
 public:
  ~MklReluOp() {}

  explicit MklReluOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, eltwise_relu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    (static_cast<T*>(out_o))[0] =
        std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
    return;
  }
};

template <typename Device, typename T>
class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
 public:
  ~MklReluGradOp() {}

  explicit MklReluGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, eltwise_relu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
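    // Scalar relu gradient: out = g * (x > 0 ? 1 : 0).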
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * ((static_cast<T*>(user_i))[0] > 0);
    return;
  }
};

template <typename Device, typename T>
class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
 public:
  ~MklEluOp() {}

  explicit MklEluOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, eltwise_elu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    // elu(x) = exp(x) - 1 if x < 0; x otherwise.
    T feature = (static_cast<T*>(user_i))[0];
    if (feature < 0)
      (static_cast<T*>(out_o))[0] = std::exp(feature) - 1;
    else
      (static_cast<T*>(out_o))[0] = feature;
    return;
  }
};

template <typename Device, typename T>
class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
 public:
  ~MklEluGradOp() {}

  explicit MklEluGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, eltwise_elu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    // Gradient of elu(x) is 1 if x > 0; elu(x) + 1 otherwise.
    T feature = (static_cast<T*>(user_i))[0];
    if (feature > 0) {
      (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0];
    } else {
      T elu = std::exp(feature) - 1;
      (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0] * (elu + 1);
    }
  }
};

template <typename Device, typename T>
class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
 public:
  ~MklTanhOp() {}

  explicit MklTanhOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, eltwise_tanh>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    // tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))
    T feature = (static_cast<T*>(user_i))[0];
    T e1 = std::exp(feature);
    T e2 = std::exp(-feature);
    (static_cast<T*>(out_o))[0] = (e1 - e2) / (e1 + e2);
    return;
  }
};

template <typename Device, typename T>
class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
 public:
  ~MklTanhGradOp() {}

  explicit MklTanhGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, eltwise_tanh>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    // Gradient of tanh(x) is 1 - tanh(x)^2.
    T feature = (static_cast<T*>(user_i))[0];
    T e1 = std::exp(feature);
    T e2 = std::exp(-feature);
    T tanh = (e1 - e2) / (e1 + e2);
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * (1 - tanh * tanh);
  }
};

#endif  // INTEL_MKL_ML

// Register dnn kernels for supported operations and supported types.
#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)             \
  REGISTER_KERNEL_BUILDER(Name("_MklRelu")                          \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklReluOp<CPUDevice, type>);              \
  REGISTER_KERNEL_BUILDER(Name("_MklReluGrad")                      \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklReluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);

#ifndef INTEL_MKL_ML

// Register dnn kernels for supported operations and supported types.
#define REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES(type)              \
  REGISTER_KERNEL_BUILDER(Name("_MklElu")                           \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklEluOp<CPUDevice, type>);               \
  REGISTER_KERNEL_BUILDER(Name("_MklEluGrad")                       \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklEluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);

#define REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES(type)             \
  REGISTER_KERNEL_BUILDER(Name("_MklTanh")                          \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklTanhOp<CPUDevice, type>);              \
  REGISTER_KERNEL_BUILDER(Name("_MklTanhGrad")                      \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklTanhGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);

#endif  // !INTEL_MKL_ML

}  // namespace tensorflow

#endif  // INTEL_MKL