/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// LRN = Local Response Normalization
// See docs in ../ops/nn_ops.cc. This OpKernel uses the MKL-DNN library: it
// creates MKL layouts and primitives and uses MKL-DNN primitives to compute
// the local response normalization.

#ifdef INTEL_MKL

#define EIGEN_USE_THREADS
#include <vector>
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
#endif

using mkldnn::lrn_across_channels;
using mkldnn::lrn_backward;
using mkldnn::lrn_forward;
using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {

namespace {
// Create a depth-by-depth band matrix with 1s along a swath of size
// (2 * depth_radius + 1) around the diagonal.
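// For example, with depth = 5 and depth_radius = 1 the matrix is
//
//   1 1 0 0 0
//   1 1 1 0 0
//   0 1 1 1 0
//   0 0 1 1 1
//   0 0 0 1 1
//
// Contracting the squared input with this matrix sums, for each depth
// position, the squares within its (2 * depth_radius + 1)-wide window.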
template <typename T>
void GetBandMatrix(int depth, int depth_radius,
                   Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
  result->setZero();
  for (int row = 0; row < depth; ++row) {
    const int begin = std::max<int>(0, row - depth_radius);
    const int end = std::min<int>(depth, row + depth_radius + 1);
    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
    result->slice(start, sizes).setConstant(T(1));
  }
}

}  // namespace

template <typename T>
class MklLRNOp : public OpKernel {
 public:
  ~MklLRNOp() {}

  explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<size_t>(depth_radius64);

    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
    workspace_enabled_ = false;
    OP_REQUIRES_OK(context,
                   context->GetAttr("workspace_enabled", &workspace_enabled_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      SanityCheckInputs(context);
      if (!context->status().ok()) return;

      auto cpu_engine = engine(engine::cpu, 0);
      const Tensor& src_tensor = MklGetInput(context, kIdxInput);
      MklDnnShape src_dnn_shape;
      GetMklShape(context, kIdxInput, &src_dnn_shape);

      // MKL-DNN has a notion of kernel_size rather than depth_radius.
      int kernel_size = 2 * depth_radius_ + 1;
      float new_alpha = alpha_ * kernel_size;

      // If the input tensor is not an MKL tensor, or if the last dimension
      // is not the channel dimension, just use Eigen: MKL-DNN only supports
      // normalization over the channel dimension.
      if (!src_dnn_shape.IsMklTensor()) {
        MklDefaultToEigen(context, src_tensor);
        return;
      } else if (!src_dnn_shape.IsMklChannelDim(src_dnn_shape.GetDimension() -
                                                1)) {
        Tensor converted_tensor =
            ConvertMklToTF<T>(context, src_tensor, src_dnn_shape);
        MklDefaultToEigen(context, converted_tensor);
        return;
      }
      // At this point we can assume that the src is an MKL tensor,
      // so we can enable the workspace.
      workspace_enabled_ = true;

      MklDnnData<T> src_dnn_data(&cpu_engine);
      MklDnnData<T> dst_dnn_data(&cpu_engine);
      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);

      TensorShape tf_output_shape = src_tensor.shape();

      memory::desc src_md = src_dnn_shape.GetCurLayout();
      memory::dims input_dims = src_dnn_shape.GetSizesAsMklDnnDims();

      // Create memory for the user input.
      // Since TensorFlow always performs normalization over the last
      // dimension, and MKL-DNN performs normalization over Channel, we tell
      // MKL-DNN that the input is in NHWC layout with Channel being the last
      // dimension.
      src_dnn_data.SetUsrMem(src_md, &src_tensor);
      src_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);

      // output_dnn_data and the workspace both have the same shape as the
      // input.
      dst_dnn_data.SetUsrMem(src_md);
      dst_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);

      // Create the LRN forward primitive descriptor.
      // TensorFlow's normalization semantics are across channels.
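      // Note on parameters: in MKL-DNN's across-channel LRN formulation alpha
      // is scaled by 1 / kernel_size, while TensorFlow's alpha multiplies the
      // raw sum of squares, so new_alpha (alpha_ * kernel_size, computed
      // above) is passed instead of alpha_ to keep the semantics identical.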
      // MKL-DNN also supports normalization within a channel, but that mode
      // is not used here.
      auto lrn_desc = lrn_forward::desc(prop_kind::forward, lrn_across_channels,
                                        src_dnn_data.GetUsrMemDesc(),
                                        kernel_size, new_alpha, beta_, bias_);
      auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine);

      // Allocate the output_dnn_data tensor.
      Tensor* output_tensor = nullptr;
      memory::format input_format = src_dnn_shape.GetTfDataFormat();
      AllocateOutputTensor(context, lrn_prim_desc, input_dims, input_format,
                           &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      CHECK_NOTNULL(output_tensor);
      dst_dnn_data.SetUsrMemDataHandle(output_tensor);

      // Handle the workspace required by MKL-DNN.
      AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
      OP_REQUIRES_OK(context, context->status());

      PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data, &dst_dnn_data,
                           &workspace_dnn_data);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  void PrepareAndExecuteNet(const lrn_forward::primitive_desc& lrn_fwd_desc,
                            MklDnnData<T>* src_dnn_data,
                            MklDnnData<T>* dst_dnn_data,
                            MklDnnData<uint8>* wksp_dnn_data = nullptr) {
    // Check whether the input needs to be reordered.
    src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());

    // Create the LRN forward primitive and add it to the net.
    std::vector<primitive> net;
    if (wksp_dnn_data != nullptr) {
      net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
                                wksp_dnn_data->GetOpMem(),
                                dst_dnn_data->GetOpMem()));
    } else {
      net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
                                dst_dnn_data->GetOpMem()));
    }
    stream(stream::kind::eager).submit(net).wait();
  }

  void AllocateOutputTensor(
      OpKernelContext* context,
      const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd = lrn_fwd_prim_desc.dst_primitive_desc();

    MklDnnShape output_mkl_shape;
    // We only handle the case in which the input and output are in MKL
    // format; any other case is handled by the Eigen fallback.
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);
    TensorShape output_tf_shape;
    // Only allocate enough space for the elements we need.
    size_t num_bytes = dst_pd.get_size();
    CHECK_EQ(num_bytes % sizeof(T), 0);
    output_tf_shape.AddDim(num_bytes / sizeof(T));
    AllocateOutputSetMklShape(context, kIdxOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  // Fallback implementation - taken from lrn_op.cc.
  // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
  // copy.
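  // The fallback computes, element-wise,
  //   output = input * (bias + alpha * sum(input^2 over the depth window))^(-beta)
  // with the window spanning 2 * depth_radius + 1 values along the last
  // dimension; the beta == 1 and beta == 0.5 branches below are cheaper
  // special cases of the same expression.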
  void MklDefaultToEigen(OpKernelContext* context, const Tensor& input) {
    const int batch = static_cast<int>(input.dim_size(0));
    const int rows = static_cast<int>(input.dim_size(1));
    const int cols = static_cast<int>(input.dim_size(2));
    const int depth = static_cast<int>(input.dim_size(3));
    const int nodes = cols * rows;

    auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
    // Multiplying the input with the band matrix has the effect of reducing
    // the correct patch along the depth.
    Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
    GetBandMatrix<T>(depth, depth_radius_, &multiplier);

    Tensor* output_dnn_data = nullptr;
    MklDnnShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(false);
    mkl_output_mkl_shape.SetDimensions(4);
    AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
                              input.shape(), mkl_output_mkl_shape);
    CHECK_NOTNULL(output_dnn_data);

    Tensor* workspace_tensor = nullptr;
    MklDnnShape workspace_mkl_shape;
    workspace_mkl_shape.SetMklTensor(false);
    TensorShape workspace_tf_shape;
    workspace_tf_shape.AddDim(0);
    AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                              workspace_tf_shape, workspace_mkl_shape);
    CHECK_NOTNULL(workspace_tensor);

    auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
    Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
    auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
    if (beta_ == T(1)) {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * tmp.inverse();
    } else if (beta_ == T(0.5)) {
      out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
    } else {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * (tmp.log() * -beta_).exp();
    }
  }

  void AllocateWorkspaceTensor(
      OpKernelContext* context,
      const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
      MklDnnData<uint8>* dnn_data_wksp) {
    CHECK_NOTNULL(dnn_data_wksp);
    Tensor* workspace_tensor = nullptr;
    memory::primitive_desc workspace_pd =
        lrn_fwd_prim_desc.workspace_primitive_desc();
    size_t workspace_bytes = workspace_pd.get_size();
    MklDnnShape workspace_mkl_shape;
    // The workspace tensor is a uint8 tensor that has exactly the number of
    // bytes necessary.
    workspace_mkl_shape.SetMklTensor(false);
    TensorShape workspace_tf_shape;
    workspace_tf_shape.AddDim(workspace_bytes);
    AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                              workspace_tf_shape, workspace_mkl_shape);
    CHECK_NOTNULL(workspace_tensor);
    dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
  }

  void SanityCheckInputs(OpKernelContext* context) {
    const Tensor& src_tensor = MklGetInput(context, kIdxInput);
    MklDnnShape src_dnn_shape;
    GetMklShape(context, kIdxInput, &src_dnn_shape);
    if (src_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("input must be 4-dimensional"));
      OP_REQUIRES(context,
                  FastBoundsCheck(src_tensor.NumElements(),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("argument to LRN too large"));
    } else {
      OP_REQUIRES(context, src_tensor.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional"));
      OP_REQUIRES(context,
                  FastBoundsCheck(src_tensor.NumElements(),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("argument to LRN too large"));
    }
  }
  const int kIdxInput = 0, kIdxOutput = 0, kIdxWorkspace = 1;

  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
  bool workspace_enabled_;
  int depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
};

template <typename T>
class MklLRNGradOp : public OpKernel {
 public:
  explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
    workspace_enabled_ = false;
    OP_REQUIRES_OK(context,
                   context->GetAttr("workspace_enabled", &workspace_enabled_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      SanityCheckInputs(context);
      if (!context->status().ok()) return;

      auto cpu_engine = engine(engine::cpu, 0);
      MklDnnData<T> input_grad_dnn_data(&cpu_engine);
      MklDnnData<T> orig_input_dnn_data(&cpu_engine);
      MklDnnData<T> orig_output_dnn_data(&cpu_engine);
      MklDnnData<T> output_dnn_data(&cpu_engine);

      MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
          orig_output_dnn_shape;
      GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
      GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
      GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);

      // We only use MKL-DNN if all of the necessary inputs are present
      // in MKL-DNN format and Channel is the last dimension.
      bool can_use_mkldnn = workspace_enabled_ &&
                            input_grad_dnn_shape.IsMklTensor() &&
                            orig_input_dnn_shape.IsMklTensor() &&
                            orig_output_dnn_shape.IsMklTensor() &&
                            input_grad_dnn_shape.IsMklChannelDim(
                                input_grad_dnn_shape.GetDimension() - 1) &&
                            orig_input_dnn_shape.IsMklChannelDim(
                                orig_input_dnn_shape.GetDimension() - 1) &&
                            orig_output_dnn_shape.IsMklChannelDim(
                                orig_output_dnn_shape.GetDimension() - 1);

      if (!can_use_mkldnn) {
        // Fall back to Eigen.
        MklDefaultToEigen(context);
        return;
      }
      // At this point we have the all clear to use MKL-DNN constructs.
      // Naming: diff_dst is input_gradient_tensor; src is orig_input_tensor.
      const Tensor& input_grad_tensor = MklGetInput(context, kIdxGradient);
      const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);

      // Get the input sizes in the NCHW order that MKL-DNN requires.
      // LRN does not have a data_format attribute; its inputs are NHWC by
      // default.
      memory::desc original_output_md = orig_output_dnn_shape.GetCurLayout();
      memory::desc target_diff_dst_md = ConfigureInputGradient(
          input_grad_tensor, input_grad_dnn_shape, &input_grad_dnn_data);

      memory::desc orig_input_md = orig_input_dnn_shape.GetCurLayout();
      memory::dims orig_input_dims =
          orig_input_dnn_shape.GetSizesAsMklDnnDims();
      orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
      orig_input_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);

      // output_dnn_data has the same shape as the original input.
      output_dnn_data.SetUsrMem(orig_input_md);
      output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);

      // MKL-DNN has a notion of kernel_size rather than depth_radius.
      int kernel_size = 2 * depth_radius_ + 1;
      float new_alpha = alpha_ * kernel_size;

      // Create the LRN backward primitive descriptor. It also requires the
      // LRN forward primitive descriptor.
      auto lrn_fwd_desc = lrn_forward::desc(
          prop_kind::forward, lrn_across_channels, orig_input_md, kernel_size,
          new_alpha, beta_, bias_);
      auto lrn_fwd_prim_desc =
          lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine);
      auto lrn_bwd_desc = lrn_backward::desc(
          lrn_across_channels, original_output_md, target_diff_dst_md,
          kernel_size, new_alpha, beta_, bias_);
      auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(
          lrn_bwd_desc, cpu_engine, lrn_fwd_prim_desc);

      Tensor* output_tensor = nullptr;
      memory::format orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
      AllocateOutputTensor(context, lrn_bwd_prim_desc, orig_input_dims,
                           orig_input_format, &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      CHECK_NOTNULL(output_tensor);
      output_dnn_data.SetUsrMemDataHandle(output_tensor);

      // Create the LRN backward primitive and add it to the net.
      // At this point the workspace is enabled, so we don't need to check;
      // pass the input workspace to the LRN backward primitive.
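      // The workspace is an opaque uint8 buffer produced by the LRN forward
      // primitive; the backward primitive consumes the intermediate values
      // saved there rather than recomputing them.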
      const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
      ConfigureWorkspace(workspace_tensor,
                         lrn_fwd_prim_desc.workspace_primitive_desc(),
                         &workspace_dnn_data);

      PrepareAndExecuteNet(
          lrn_bwd_prim_desc, lrn_fwd_prim_desc, &orig_input_dnn_data,
          &input_grad_dnn_data, &output_dnn_data,
          memory::primitive_desc(target_diff_dst_md, cpu_engine),
          &workspace_dnn_data);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

  void AllocateOutputTensor(
      OpKernelContext* context,
      const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd =
        lrn_bkwd_prim_desc.diff_src_primitive_desc();
    MklDnnShape output_mkl_shape;

    // We assume that all outputs at this point are MKL tensors.
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    size_t num_bytes = dst_pd.get_size();
    CHECK_EQ(num_bytes % sizeof(T), 0);
    output_tf_shape.AddDim(num_bytes / sizeof(T));
    AllocateOutputSetMklShape(context, kIdxOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
                                      const MklDnnShape& input_grad_dnn_shape,
                                      MklDnnData<T>* input_grad_dnn_data) {
    CHECK_NOTNULL(input_grad_dnn_data);
    // This shouldn't be necessary at this point, but just in case.
    CHECK_EQ(input_grad_dnn_shape.IsMklTensor(), true);

    memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
    memory::dims orig_input_dims = input_grad_dnn_shape.GetSizesAsMklDnnDims();
    input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
    input_grad_dnn_data->SetOpMemDesc(orig_input_dims, memory::format::nhwc);
    return input_grad_md;
  }

  void PrepareAndExecuteNet(
      const lrn_backward::primitive_desc& lrn_bkwd_desc,
      const lrn_forward::primitive_desc& lrn_fwd_desc,
      MklDnnData<T>* src_dnn_data, MklDnnData<T>* input_gradient_diff_dst,
      MklDnnData<T>* output_diff_src,
      const memory::primitive_desc& target_diff_dst_pd,
      const MklDnnData<uint8>* workspace_dnn_data = nullptr) {
    // Check whether the diff_dst input needs to be reordered.
    input_gradient_diff_dst->CheckReorderToOpMem(
        lrn_bkwd_desc.diff_dst_primitive_desc());

    // Check whether the original input needs to be reordered.
    src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());
    // Create the LRN backward primitive and add it to the net.
    std::vector<primitive> net;
    if (nullptr == workspace_dnn_data) {
      net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
                                 input_gradient_diff_dst->GetOpMem(),
                                 output_diff_src->GetOpMem()));
    } else {
      net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
                                 input_gradient_diff_dst->GetOpMem(),
                                 workspace_dnn_data->GetOpMem(),
                                 output_diff_src->GetOpMem()));
    }
    stream(stream::kind::eager).submit(net).wait();
  }

  void ConfigureWorkspace(const Tensor& workspace_tensor,
                          memory::primitive_desc workspace_pd,
                          MklDnnData<uint8>* workspace_dnn_data) {
    CHECK_NOTNULL(workspace_dnn_data);

    workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
  }

  // Fallback implementation - taken from lrn_op.cc.
  // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
  // copy.
  void MklDefaultToEigen(OpKernelContext* context) {
    Tensor input_gradient_tensor;
    Tensor orig_input_tensor;
    Tensor orig_output_tensor;

    MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
        orig_output_dnn_shape;
    GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
    GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
    GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);

    if (input_grad_dnn_shape.IsMklTensor()) {
      input_gradient_tensor = ConvertMklToTF<T>(
          context, MklGetInput(context, kIdxGradient), input_grad_dnn_shape);
    } else {
      input_gradient_tensor = MklGetInput(context, kIdxGradient);
    }

    if (orig_input_dnn_shape.IsMklTensor()) {
      orig_input_tensor = ConvertMklToTF<T>(
          context, MklGetInput(context, kIdxOrigInput), orig_input_dnn_shape);
    } else {
      orig_input_tensor = MklGetInput(context, kIdxOrigInput);
    }

    if (orig_output_dnn_shape.IsMklTensor()) {
      orig_output_tensor = ConvertMklToTF<T>(
          context, MklGetInput(context, kIdxOrigOutput), orig_output_dnn_shape);
    } else {
      orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
    }

    const int64 batch = static_cast<int64>(input_gradient_tensor.dim_size(0));
    const int64 rows = static_cast<int64>(input_gradient_tensor.dim_size(1));
    const int64 cols = static_cast<int64>(input_gradient_tensor.dim_size(2));
    const int64 depth = static_cast<int64>(input_gradient_tensor.dim_size(3));
    const auto nodes = cols * rows;

    auto grads_shaped =
        input_gradient_tensor.shaped<T, 2>({nodes * batch, depth});

    auto in_shaped = orig_input_tensor.shaped<T, 2>({nodes * batch, depth});
    auto activations = orig_output_tensor.shaped<T, 2>({nodes * batch, depth});

    Tensor* output_dnn_data;
    MklDnnShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(false);
    mkl_output_mkl_shape.SetDimensions(4);
    AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
                              input_gradient_tensor.shape(),
                              mkl_output_mkl_shape);

    auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
    out_shaped.setZero();
    auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
                  depth](int64 begin, int64 end) {
      for (int64 i = begin; i < end; ++i) {
        for (int64 j = 0; j < depth; ++j) {
          int64 depth_begin = std::max<int64>(0, j - depth_radius_);
          int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);

          T norm(0);
          for (int64 k = depth_begin; k < depth_end; ++k) {
            norm += in_shaped(i, k) * in_shaped(i, k);
          }
          norm = alpha_ * norm + bias_;
          DCHECK_GT(norm, T(1e-6));
          for (int64 k = depth_begin; k < depth_end; ++k) {
            T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
                    activations(i, j) / norm;
            if (k == j) {
              dyi += Eigen::numext::pow(norm, -beta_);
            }
            dyi *= grads_shaped(i, j);
            const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) += dyi;
          }
        }
      }
    };
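    // Shard the per-pixel loop above across the intra-op threadpool; the
    // depth * depth argument below is the estimated cost per unit of work
    // (one pixel), reflecting the nested loops over the depth window.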
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
          depth * depth, shard);
  }

  void SanityCheckInputs(OpKernelContext* context) {
    const Tensor& input_gradient_tensor = MklGetInput(context, kIdxGradient);
    const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
    const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
    const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
    MklDnnShape in_grads_dnn_shape, in_image_dnn_shape, out_image_dnn_shape,
        workspace_dnn_shape;
    GetMklShape(context, kIdxGradient, &in_grads_dnn_shape);
    GetMklShape(context, kIdxOrigInput, &in_image_dnn_shape);
    GetMklShape(context, kIdxOrigOutput, &out_image_dnn_shape);
    GetMklShape(context, kIdxWorkspace, &workspace_dnn_shape);
    if (in_grads_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, in_grads_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("input gradient must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(
          context, input_gradient_tensor.dims() == 4,
          errors::InvalidArgument("input gradient must be 4-dimensional"));
    }

    if (in_image_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, in_image_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("input images must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(context, orig_input_tensor.dims() == 4,
                  errors::InvalidArgument("input images must be "
                                          "4-dimensional"));
    }

    if (out_image_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, out_image_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("output image must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(
          context, orig_output_tensor.dims() == 4,
          errors::InvalidArgument("output image must be 4-dimensional"));
    }

    if (workspace_enabled_) {
      if (workspace_dnn_shape.IsMklTensor()) {
        OP_REQUIRES(
            context, workspace_dnn_shape.IsMklTensor() == false,
            errors::InvalidArgument("Workspace should not be an MKL tensor."));
      } else {
        OP_REQUIRES(context, workspace_tensor.dims() == 1,
                    errors::InvalidArgument("Workspace must be 1-dimensional"));
      }
    }
  }

  // Input("input_grads: T")
  // Input("input_image: T")
  // Input("output_image: T")
  // Input("workspace: uint8")
  const int kIdxGradient = 0, kIdxOrigInput = 1, kIdxOrigOutput = 2,
            kIdxWorkspace = 3, kIdxOutput = 0;

  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
  bool workspace_enabled_;
  int depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
};

#define REGISTER_MKL_LRN_CPU(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("_MklLRN")                           \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklLRNOp<T>);                             \
  REGISTER_KERNEL_BUILDER(Name("_MklLRNGrad")                       \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklLRNGradOp<T>);

TF_CALL_float(REGISTER_MKL_LRN_CPU);

}  // namespace tensorflow

#endif  // INTEL_MKL