/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
#define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_

#include <limits>
#include <string>
#include <vector>

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"

using mkldnn::prop_kind;
using mkldnn::stream;

using mkldnn::convolution_direct;
using mkldnn::convolution_forward;
#endif

namespace tensorflow {

#ifndef INTEL_MKL_ML

class MklDnnConvUtil {
 protected:
  OpKernelContext* context_;  // We don't own this.
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

 public:
  MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
                 Padding pad, TensorFormat fm)
      : context_(context), strides_(strides), padding_(pad), data_format_(fm) {}

  virtual ~MklDnnConvUtil() { context_ = nullptr; }

  // Calculate Convolution strides.
  virtual inline void GetStridesInMklOrder(memory::dims* strides) {
    // For now we take the stride from the second and third dimensions only
    // (we do not support striding on the batch or depth dimension).
    CHECK_NOTNULL(strides);
    int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    *strides = {stride_rows, stride_cols};
  }

  // Calculate Convolution input size in MKL-DNN order. MKL-DNN
  // requires input in NCHW format. Function does not return anything.
  // But errors arising from sanity checks are returned in context's
  // status.
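  //
  // For example (illustrative values only): an NHWC input of shape
  // [8, 32, 32, 3] yields *input_dims = {8, 3, 32, 32}, i.e. {N, C, H, W}.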
  virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
                                             memory::dims* input_dims) {
#define CHECK_BOUNDS(val, err_msg)                                     \
  do {                                                                 \
    OP_REQUIRES(context_,                                              \
                FastBoundsCheck(val, std::numeric_limits<int>::max()), \
                errors::InvalidArgument(err_msg));                     \
  } while (0)

    CHECK_NOTNULL(input_dims);

    // Input channel
    int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
    CHECK_BOUNDS(input_depth_raw, "Input depth too large");
    int input_depth = static_cast<int>(input_depth_raw);

    // Input rows/height
    int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
    CHECK_BOUNDS(input_rows_raw, "Input rows too large");
    int input_rows = static_cast<int>(input_rows_raw);

    // Input columns/width
    int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
    CHECK_BOUNDS(input_cols_raw, "Input cols too large");
    int input_cols = static_cast<int>(input_cols_raw);

    // Input batch
    int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
    CHECK_BOUNDS(input_batch_raw, "Input batch too large");
    int input_batch = static_cast<int>(input_batch_raw);

#undef CHECK_BOUNDS

    // MKL-DNN always requires input in NCHW format.
    std::vector<int> mkldnn_sizes(4, -1);
    mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
    mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
    mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
    mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;

    *input_dims = mkldnn_sizes;
  }

  // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
  // requires filter in OIHW format. Function does not return anything,
  // but errors arising from sanity checks are returned in context's
  // status. This overload accepts shapes directly (rather than tensor
  // indices) because Convolution Backward Input gets the shape of the
  // input tensor rather than the actual tensor (Convolution forward
  // gets the actual tensor as input).
  //
  // TODO(nhasabni): Add similar function for input and filter in MklShape.
  virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape,
                                              const TensorShape& filter_shape,
                                              memory::dims* filter_dims) {
    CHECK_NOTNULL(filter_dims);

    OP_REQUIRES(context_, filter_shape.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter_shape.DebugString()));

    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(context_,
                  FastBoundsCheck(filter_shape.dim_size(i),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("filter too large"));
    }

    int input_depth = GetTensorDim(input_shape, data_format_, 'C');

    OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", input_depth,
                    " vs ", filter_shape.dim_size(2)));

    // TF filter is always in (rows, cols, in_depth, out_depth) order.
    int filter_rows = static_cast<int>(filter_shape.dim_size(0));
    int filter_cols = static_cast<int>(filter_shape.dim_size(1));
    int in_depth = static_cast<int>(filter_shape.dim_size(2));
    int out_depth = static_cast<int>(filter_shape.dim_size(3));

    // MKL-DNN always needs filter in OIHW format:
    // OIHW = (out_depth, in_depth, rows, cols).
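    // For example (illustrative values only): a TF filter of shape
    // [3, 3, 3, 64] (rows, cols, in_depth, out_depth) becomes
    // {64, 3, 3, 3} here.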
    std::vector<int> mkldnn_sizes(4, -1);
    mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
    mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
    mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
    mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;

    *filter_dims = mkldnn_sizes;
  }

  // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
  // requires filter in OIHW format. Function does not return anything,
  // but errors arising from sanity checks are returned in context's
  // status.
  virtual inline void GetFilterSizeInMklOrder(size_t src_index,
                                              size_t filter_index,
                                              memory::dims* filter_dims) {
    CHECK_NOTNULL(filter_dims);
    GetFilterSizeInMklOrder(GetTfShape(context_, src_index),
                            GetTfShape(context_, filter_index), filter_dims);
  }

  // Calculate bias size for 2D Convolution. Function does not return
  // anything, but sets error in context status.
  virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
                                            memory::dims* bias_dims) {
    const Tensor& bias = MklGetInput(context_, bias_index);
    OP_REQUIRES(context_, bias.dims() == 1,
                errors::InvalidArgument("bias must be 1-dimensional: ",
                                        bias.shape().DebugString()));

    *bias_dims = {static_cast<int>(bias.dim_size(0))};
  }

  // Calculate output and padding size for 2D Convolution.
  //
  // Calculates the output shape of Convolution in MKL-DNN and TensorFlow
  // order. MKL-DNN uses NCHW for output order, but TensorFlow output will
  // be in NHWC or NCHW format depending on data format. Function also
  // calculates left, right, top and bottom pads. Function does not return
  // any status; status is returned via context status.
  //
  // TODO(nhasabni): Add similar function for input and filter in MklShape.
  virtual inline void GetOutputAndPadSizeInMklOrder(
      const TensorShape& input_shape, const TensorShape& filter_shape,
      const memory::dims& strides, memory::dims* output_dims_tf_order,
      memory::dims* output_dims_mkl_order, memory::dims* pad_l,
      memory::dims* pad_r) {
    CHECK_NOTNULL(output_dims_tf_order);
    CHECK_NOTNULL(output_dims_mkl_order);
    CHECK_NOTNULL(pad_l);
    CHECK_NOTNULL(pad_r);

    int input_rows = GetTensorDim(input_shape, data_format_, 'H');
    int input_cols = GetTensorDim(input_shape, data_format_, 'W');

    // The first dimension of the filter is rows/height.
    int filter_rows = filter_shape.dim_size(0);
    // The second dimension of the filter is cols/width.
    int filter_cols = filter_shape.dim_size(1);

    // Stride is a vector of 2 elements: {s_r, s_c}.
    int stride_rows = strides[0];
    int stride_cols = strides[1];

    // Output batch is same as input batch.
    int out_batch = GetTensorDim(input_shape, data_format_, 'N');
    // Output depth is same as the last dimension of the filter.
    int out_depth = filter_shape.dim_size(3);

    int64 out_rows = 0, out_cols = 0;
    int64 pad_top = 0, pad_bottom = 0, pad_left = 0, pad_right = 0;

    OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
                                 input_rows, filter_rows, stride_rows, padding_,
                                 &out_rows, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
                                 input_cols, filter_cols, stride_cols, padding_,
                                 &out_cols, &pad_left, &pad_right));
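    // Worked example (illustrative values only): with SAME padding,
    // input_rows = 32, filter_rows = 3, stride_rows = 2 gives
    // out_rows = ceil(32 / 2) = 16 and total row padding
    // max((16 - 1) * 2 + 3 - 32, 0) = 1, split as pad_top = 0 and
    // pad_bottom = 1 (TensorFlow puts the odd extra pixel on bottom/right).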
    // TensorFlow output is in data_format order (NHWC or NCHW).
    TensorShape out_shape =
        ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth);
    *output_dims_tf_order = TFShapeToMklDnnDims(out_shape);

    // MKL-DNN always needs output in NCHW format.
    std::vector<int> mkldnn_sizes(4, -1);
    mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
    mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
    mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
    mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
    *output_dims_mkl_order = mkldnn_sizes;

    // Now handle padding. MKL-DNN uses asymmetric padding.
    *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
    *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
  }

  // Calculate output and pad size of forward Convolution operator.
  // See comment on GetOutputAndPadSizeInMklOrder above for parameters.
  //
  // Function does not return anything, but sets error in context status.
  inline void GetOutputAndPadSizeInMklOrder(
      size_t src_index, size_t filter_index, const memory::dims& strides,
      memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
      memory::dims* pad_l, memory::dims* pad_r) {
    CHECK_NOTNULL(output_dims_tf_order);
    CHECK_NOTNULL(output_dims_mkl_order);
    CHECK_NOTNULL(pad_l);
    CHECK_NOTNULL(pad_r);

    auto input_tf_shape = GetTfShape(context_, src_index);
    auto filter_tf_shape = GetTfShape(context_, filter_index);

    OP_REQUIRES(context_, input_tf_shape.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input_tf_shape.DebugString()));

    GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides,
                                  output_dims_tf_order, output_dims_mkl_order,
                                  pad_l, pad_r);
  }

  // Wrapper function to calculate input, filter, and output sizes of
  // 2D Convolution in MKL order (NCHW for input and output; OIHW for
  // filter). It also calculates the output shape in TensorFlow order,
  // as well as strides and paddings for 2D Convolution.
  //
  // Function does not return anything, but sets error in context status.
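  //
  // Illustrative usage sketch (hypothetical values; `ctx` is the kernel's
  // OpKernelContext* and input_shape/filter_shape are TensorShapes):
  //
  //   MklDnnConvUtil util(ctx, {1, 2, 2, 1}, Padding::SAME, FORMAT_NHWC);
  //   memory::dims in_dims, filt_dims, strides;
  //   memory::dims out_tf, out_mkl, pad_l, pad_r;
  //   util.GetConvFwdSizesInMklOrder(input_shape, filter_shape, &in_dims,
  //                                  &filt_dims, &strides, &out_tf, &out_mkl,
  //                                  &pad_l, &pad_r);
  //   if (!ctx->status().ok()) return;  // sanity-check failures land here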
  inline void GetConvFwdSizesInMklOrder(
      const TensorShape& input_shape, const TensorShape& filter_shape,
      memory::dims* input_dims, memory::dims* filter_dims,
      memory::dims* strides, memory::dims* output_dims_tf_order,
      memory::dims* output_dims_mkl_order, memory::dims* pad_l,
      memory::dims* pad_r) {
    CHECK_NOTNULL(input_dims);
    CHECK_NOTNULL(filter_dims);
    CHECK_NOTNULL(strides);
    CHECK_NOTNULL(output_dims_tf_order);
    CHECK_NOTNULL(output_dims_mkl_order);
    CHECK_NOTNULL(pad_l);
    CHECK_NOTNULL(pad_r);

    GetInputSizeInMklOrder(input_shape, input_dims);
    if (!context_->status().ok()) return;
    GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims);
    if (!context_->status().ok()) return;
    GetStridesInMklOrder(strides);
    GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides,
                                  output_dims_tf_order, output_dims_mkl_order,
                                  pad_l, pad_r);
    if (!context_->status().ok()) return;
  }
};

/////////////////////////////////////////////////////////////////////
/// Common class that implements Conv2DBackpropFilter and Input
/////////////////////////////////////////////////////////////////////

template <typename Device, class T>
class MklConv2DBackpropCommonOp : public OpKernel {
 public:
  ~MklConv2DBackpropCommonOp() {}
  explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format_str;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);

      // Prepare common tensors for Conv2DBackpropInput and
      // Conv2DBackpropFilter.
      MklDnnData<T> input(&cpu_engine);
      MklDnnData<T> filter(&cpu_engine);
      MklDnnData<T> outbackprop(&cpu_engine);
      MklDnnData<T> output(&cpu_engine);

      // Input tensors
      const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
      const Tensor& input_tensor = MklGetInput(context, kInputIdx);
      const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
      const Tensor& outbprop_tensor = MklGetInput(context, kOutbpropIdx);

      MklDnnShape input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape;
      GetMklShape(context, kInputIdx, &input_mkl_shape);
      GetMklShape(context, kFilterIdx, &filter_mkl_shape);
      GetMklShape(context, kOutbpropIdx, &outbprop_mkl_shape);
      // Allow operator-specific sanity checking of shapes.
      ValidateMklShapes(input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape);

      // Allow operator-specific generation of shapes.
      // E.g., Conv2DBackpropFilter gets filter as filter_sizes, a tensor
      // containing the shape of the filter, so filter.shape() is not
      // a correct way to get the filter shape. These operator-specific
      // calls allow this class to handle this case.
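      // For instance (illustrative values): if filter_sizes holds the int32
      // values {3, 3, 3, 64}, MakeFilterTfShape builds the 4-D TensorShape
      // [3, 3, 3, 64], whereas filter_tensor.shape() would just be [4].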
      TensorShape input_tf_shape = MakeInputTfShape(context, input_tensor);
      TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
      TensorShape outbprop_tf_shape = GetTfShape(context, kOutbpropIdx);

      // Corner cases: output with 0 elements and 0 batch size.
      Tensor* output_tensor = nullptr;
      if (input_tf_shape.num_elements() == 0 ||
          filter_tf_shape.num_elements() == 0 ||
          outbprop_tf_shape.num_elements() == 0) {
        MklDnnShape output_mkl_shape;
        output_mkl_shape.SetMklTensor(false);
        TensorShape output_tf_shape = GetOutputTfShape(
            input_tf_shape, filter_tf_shape, outbprop_tf_shape);
        const int kOutputIdx = 0;
        AllocateOutputSetMklShape(context, kOutputIdx, &output_tensor,
                                  output_tf_shape, output_mkl_shape);
        CHECK_NOTNULL(output_tensor);

        // If the output tensor has more than 0 elements, we need to zero
        // them out.
        for (int64 i = 0; i < output_tf_shape.num_elements(); ++i) {
          output_tensor->flat<T>().data()[i] = 0;
        }

        return;
      }

      // By default, all dims are in MKL order. Only dims in TF order
      // are those with the prefix tf_order.
      memory::dims outbprop_dims, fwd_input_dims, fwd_filter_dims;
      memory::dims padding_l, padding_r, strides, fwd_output_dims;
      memory::dims fwd_output_dims_tf_order;

      // Get forward convolution parameters.
      MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
      conv_utl.GetConvFwdSizesInMklOrder(
          input_tf_shape, filter_tf_shape, &fwd_input_dims, &fwd_filter_dims,
          &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
          &padding_r);
      if (!context->status().ok()) return;

      // Create Convolution forward descriptor since the Convolution backward
      // API needs it. For that, we first need to create input, filter
      // and output memory descriptors.
      auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_);
      // If input is in MKL layout, then simply grab the input layout;
      // otherwise, construct the input TF layout. For the TF layout, although
      // the input shape required is in MKL-DNN order, the layout is
      // TensorFlow's layout (NHWC or NCHW depending on data format).
      auto fwd_input_md =
          input_mkl_shape.IsMklTensor()
              ? input_mkl_shape.GetMklLayout()
              : memory::desc(fwd_input_dims, MklDnnType<T>(), tf_fmt);
      // If filter is in MKL layout, then simply grab the filter layout;
      // otherwise construct the filter in TF layout. For the TF layout,
      // the filter is in HWIO format.
      auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
                               ? filter_mkl_shape.GetMklLayout()
                               : memory::desc(fwd_filter_dims, MklDnnType<T>(),
                                              memory::format::hwio);
      // TensorFlow output of Conv2D is in data_format order.
      auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), tf_fmt);
      auto fwd_desc = convolution_forward::desc(
          prop_kind::forward, convolution_direct, fwd_input_md, fwd_filter_md,
          fwd_out_md, strides, padding_l, padding_r,
          TFPaddingToMklDnnPadding(padding_));
      auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);

      // Create memory for user data. Describe how the inputs and outputs of
      // Convolution look like. Also specify buffers containing actual input
      // and output data.

      // Since this is a common class for both Conv2DBackpropFilter and
      // Conv2DBackpropInput, we skip the SetUsrMem call for the input tensor
      // (for Conv2DBackpropInput) or for the filter tensor (for
      // Conv2DBackpropFilter), depending on which tensor holds sizes
      // (an int32 shape vector) rather than data.
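      // For example (illustrative): for Conv2DBackpropInput,
      // GetInputTensorIndexWithSizes() returns kInputIdx (input 0 is
      // input_sizes, a shape vector rather than data), so only the filter
      // gets a user memory here; Conv2DBackpropFilter is the mirror case.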
      size_t input_with_sizes = GetInputTensorIndexWithSizes();
      if (input_with_sizes != kInputIdx) {
        // Shape of Conv2DBackpropFilter's input is same as Conv2D input.
        input.SetUsrMem(fwd_input_md, &input_tensor);
      } else if (input_with_sizes != kFilterIdx) {
        // Shape of Conv2DBackpropInput's filter is same as Conv2D filter.
        filter.SetUsrMem(fwd_filter_md, &filter_tensor);
      }

      conv_utl.GetInputSizeInMklOrder(outbprop_tf_shape, &outbprop_dims);
      if (!context->status().ok()) return;
      if (outbprop_mkl_shape.IsMklTensor()) {
        // If outbackprop is in MKL layout, then simply grab it.
        auto outbprop_md = outbprop_mkl_shape.GetMklLayout();
        outbackprop.SetUsrMem(outbprop_md, &outbprop_tensor);
      } else {
        // If outbackprop is in TensorFlow layout, then we need to create a
        // memory descriptor for it. Outbackprop shape is data-format order.
        outbackprop.SetUsrMem(outbprop_dims, tf_fmt, &outbprop_tensor);
      }

      // Operator-specific call to get output shape and data_format.
      auto bwd_output_dims = GetOutputDims(fwd_input_dims, fwd_filter_dims);
      auto bwd_output_format = GetOutputFormat(tf_fmt);
      output.SetUsrMem(bwd_output_dims, bwd_output_format);

      // Create memory descriptors for convolution data w/ no specified
      // format.
      input.SetOpMemDesc(fwd_input_dims, memory::format::any);
      filter.SetOpMemDesc(fwd_filter_dims, memory::format::any);
      outbackprop.SetOpMemDesc(outbprop_dims, memory::format::any);
      output.SetOpMemDesc(bwd_output_dims, memory::format::any);

      // Operator-specific call to create and execute primitive.
      CreatePrimitive(context, cpu_engine, fwd_pd, &input, &filter,
                      &outbackprop, &output, &output_tensor, strides,
                      padding_l, padding_r, TFPaddingToMklDnnPadding(padding_),
                      bwd_output_dims, bwd_output_format);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception: ", error_msg));
    }
  }

  /// Pure virtual function to allow operators to check the validity of
  /// input shapes. Function asserts that input shapes are valid.
  virtual void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
                                 const MklDnnShape& filter_mkl_shape,
                                 const MklDnnShape& outbprop_mkl_shape) = 0;

  /// Operator-specific function that returns the index of the input that
  /// represents sizes. For Conv2DBackpropFilter it returns 1, since the
  /// filter input of this operator is the filter shape. For
  /// Conv2DBackpropInput it returns 0 (for input).
  virtual size_t GetInputTensorIndexWithSizes() = 0;

  /// Get TensorFlow shape of input tensor.
  virtual TensorShape MakeInputTfShape(OpKernelContext* context,
                                       const Tensor& input_tensor) = 0;

  /// Get TensorFlow shape of filter tensor.
  virtual TensorShape MakeFilterTfShape(OpKernelContext* context,
                                        const Tensor& filter_tensor) = 0;

  /// Get the TensorFlow shape of the output tensor.
  virtual TensorShape GetOutputTfShape(const TensorShape& input_shape,
                                       const TensorShape& filter_shape,
                                       const TensorShape& outbprop_shape) = 0;

  /// Get shape of output in MKL-DNN order. Computes shape of output from
  /// input shape (fwd_input_dims) and filter shape (fwd_filter_dims).
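  /// For example (illustrative): Conv2DBackpropInput computes the gradient
  /// w.r.t. the input, so it would return fwd_input_dims here, while
  /// Conv2DBackpropFilter would return fwd_filter_dims.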
  virtual const memory::dims& GetOutputDims(
      const memory::dims& fwd_input_dims,
      const memory::dims& fwd_filter_dims) = 0;

  /// Get data_format of output in MKL-DNN order. If the output data format
  /// is the same as the input data format, then it simply returns the value
  /// of the data_format parameter as-is.
  virtual memory::format GetOutputFormat(const memory::format data_format) = 0;

  /// Create and execute the primitive, storing the output in output_tensor.
  virtual void CreatePrimitive(
      OpKernelContext* context, const engine& cpu_engine,
      const convolution_forward::primitive_desc& conv_fwd_pd,
      MklDnnData<T>* input, MklDnnData<T>* filter, MklDnnData<T>* outbackprop,
      MklDnnData<T>* output, Tensor** output_tensor,
      const memory::dims& strides, const memory::dims& padding_l,
      const memory::dims& padding_r, padding_kind padding,
      const memory::dims& bwd_output_dims,
      memory::format bwd_output_format) = 0;

  // Get the data_format {NCHW, NHWC}.
  TensorFormat GetTFDataFormat() { return data_format_; }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;
};
#endif  // INTEL_MKL_ML

/////////////////////////////////////////////////////////////////////
/// Dummy Mkl op that is just used for operators that are intermediate
/// output of node fusion in the graph
/////////////////////////////////////////////////////////////////////

template <typename Device, typename T>
class MklDummyOp : public OpKernel {
 public:
  ~MklDummyOp() {}

  explicit MklDummyOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TF_CHECK_OK(
        errors::Unimplemented("This is a dummy op. "
                              "It should not have been invoked."));
  }
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_