1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // TODO(intel): Improve error handling in this file; instead of CHECK failing 17 // all over the place, we should log an error and execute the original graph. 18 #ifdef INTEL_MKL 19 20 #include <algorithm> 21 #include <functional> 22 #include <memory> 23 #include <queue> 24 #include <set> 25 #include <stack> 26 #include <tuple> 27 #include <unordered_set> 28 #include <utility> 29 #include <vector> 30 31 #include "tensorflow/core/common_runtime/function.h" 32 #include "tensorflow/core/common_runtime/optimization_registry.h" 33 #include "tensorflow/core/framework/node_def_util.h" 34 #include "tensorflow/core/framework/tensor.pb.h" 35 #include "tensorflow/core/graph/algorithm.h" 36 #include "tensorflow/core/graph/graph.h" 37 #include "tensorflow/core/graph/node_builder.h" 38 #include "tensorflow/core/lib/core/status.h" 39 #include "tensorflow/core/lib/gtl/array_slice.h" 40 #include "tensorflow/core/lib/gtl/map_util.h" 41 #include "tensorflow/core/lib/hash/hash.h" 42 #include "tensorflow/core/platform/logging.h" 43 #include "tensorflow/core/util/tensor_format.h" 44 #include "tensorflow/core/util/util.h" 45 46 #include "tensorflow/core/graph/mkl_graph_util.h" 47 #include "tensorflow/core/graph/mkl_layout_pass.h" 48 49 namespace tensorflow { 50 51 // This pass implements rewriting of 
// graph to support the following scenarios:
// (A) Merging nodes in the graph
// (B) Rewriting a node in the graph to a new node
//     Rewrite happens under the following scenario:
//     - Propagating Mkl layout as an additional output tensor
//        (we will loosely call a tensor that carries Mkl layout as Mkl tensor
//         henceforth.) from every Mkl supported NN layer.
//
// Example of A : Merging nodes in the graph
// -----------------------------------------
// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
//
//           O = Conv2D(A, B)
//           P = BiasAdd(O, C)
//
// We merge them into Conv2DWithBias as:
//           P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
//
// The meaning of A_m, B_m and C_m is explained in B.1.
//
// Merge rules:
//  - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
//    goes to BiasAdd.
//  - Also, the intersection of attributes of both the nodes must have same
//    values.
//  - Both the nodes must have been assigned to same device (if any).
//
// Example of B.1 : Rewriting nodes to Mkl nodes
// ---------------------------------------------
// Consider a Relu node. Current definition of Relu node looks like:
//
//           O = Relu(A)
//
// Relu has 1 input (A), and 1 output (O).
//
// This rewrite pass will generate a new graph node for Relu (new node is
// called MklRelu) as:
//
//          O, O_m = MklRelu(A, A_m)
//
// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
// same as input A of Relu; output O is same as output O of Relu. O_m is the
// additional output tensor that will be set by MklRelu, and it represents
// Mkl tensor corresponding to O -- in other words, O_m is some kind of
// metadata for O. A_m is the additional input of MklRelu, and it represents
// metadata for A - as O_m is metadata for O, A_m is metadata for A. MklRelu
// receives this metadata from the previous node in the graph.
//
// When a previous node in the graph is an Mkl node, A_m will represent a valid
// Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
// a dummy Mkl tensor.
//
// Rewriting rules:
//  - Selection of a node for rewriting happens by registering the op type of
//    the node with the rewriting pass. If the op type is not registered, then
//    no node of this op type will be rewritten.
//  - Number of inputs after rewriting:
//      Since for every input Tensorflow tensor, the rewritten node gets Mkl
//      tensor(s), rewritten node gets 2*N inputs, where N is the number of
//      inputs for the original node.
//  - Number of outputs after rewriting:
//      Since for every output Tensorflow tensor, the rewritten node generates
//      Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
//      number of outputs of the original node.
//  - Ordering of Tensorflow tensors and Mkl tensors:
//      Since every rewritten node generates twice the number of inputs and
//      outputs, one could imagine various orderings among Tensorflow tensors
//      and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
//      inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
//      in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
//      order. With 2*N tensors one could get (2*N)! permutations.
//
//      So the question is: which order do we follow? We support 2 types of
//      orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
//      follows an intuitive order where an Mkl tensor follows the
//      corresponding Tensorflow tensor immediately. In the context of the
//      above example, it will be: A, A_m, B, B_m. Note that the ordering rule
//      applies to both the inputs and outputs. Contiguous ordering means
//      all the Tensorflow tensors are contiguous followed by all the Mkl
//      tensors. We use contiguous ordering as default.
//
// Graph rewrite algorithm:
//      Algorithm: Graph Rewrite
//      Input: Graph G, Names of the nodes to rewrite and their new names
//      Output: Modified Graph G' if the nodes are modified, G otherwise.
//      Start:
//        N = Topological_Sort(G) // N is a set of nodes in toposort order.
//        foreach node n in N
//        do
//          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input?
//          then
//            E = set of <incoming edge and its src_output slot> of n
//            E' = {}   // a new set of edges for rewritten node
//            foreach <e,s> in E
//            do
//              E' U {<e,s>}  // First copy edge which generates Tensorflow
//                            // tensor as it is
//              m = Source node of edge e
//              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
//              then
//                E' U {<m,s+1>}    // If yes, then m will generate an Mkl
//                                  // tensor as an additional output.
//              else
//                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
//                                                 // Mkl tensor.
//                E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
//              fi
//            done
//            n' = Build_New_Node(G,new_name,E')
//            Mark_Rewritten(n')  // Mark the new node as being rewritten.
//          fi
//        done
//
//      Explanation:
//        For graph rewrite, we visit nodes of the input graph in the
//        topological sort order. With this ordering, we visit nodes in the
//        top-to-bottom fashion. We need this order because while visiting a
//        node we want all of its input nodes to have been visited and
//        rewritten if applicable. This is because if we need to rewrite a
//        given node then all of its input nodes need to be fixed (in other
//        words they cannot be deleted later.)
//
//        While visiting a node, we first check if the op type of the node is
//        an Mkl op. If it is, then we rewrite that node after constructing
//        new inputs to the node. If the op type of the node is not Mkl op,
//        then we do not rewrite that node.
//
// Handling workspace propagation for certain ops:
//
//        Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
//        passing of a workspace from their respective forward ops. Workspace
//        tensors provide memory for storing results of intermediate operations
//        which are helpful in backward propagation. TensorFlow does not have
//        a notion of a workspace and as a result does not allow producing
//        additional outputs from these forward ops. For these ops, we need
//        to add 2 extra edges between forward ops and their corresponding
//        backward ops - the first extra edge carries a workspace tensor and
//        the second one carries an Mkl tensor for the workspace tensor.
//
//        Example:
//
//        Typical graph for MaxPool and its gradient looks like:
//
//        A = MaxPool(T)
//        B = MaxPoolGrad(X, A, Y)
//
//        We will transform this graph to propagate the workspace as:
//        (with the contiguous ordering)
//
//        A, W, A_m, W_m = MklMaxPool(T, T_m)
//        B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
//
//        Here W is the workspace tensor. Transformed tensor names with the
//        suffix _m are Mkl tensors, and this transformation has been done
//        using the algorithm discussed earlier. The transformation for
//        workspace propagation only adds extra outputs (W, W_m) for a forward
//        op and connects them to the corresponding backward ops.
//
//        Terms:
//
//        Forward op name = name of the op in the forward pass
//          where a workspace tensor originates (MaxPool in this example)
//        Backward op name = name of the op in the backward pass that receives
//          a workspace tensor from the forward op (MaxPoolGrad in the example)
//        Slot = Position of the output or input slot that will be
//               used by the workspace tensor (1 for MklMaxPool as W is the 2nd
//               output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
//
//        Question:
//
//        How do we associate a backward op to a forward op? There can be more
//        than one op with the exact same name.
//
//        In this example, we associate MaxPoolGrad with MaxPool. But there
//        could be more than one MaxPool ops. To solve this problem, we look
//        for a _direct_ edge between a forward op and a backward op (tensor A
//        is flowing along this edge in the example).
//
//        How do we transform forward and backward ops when there is no direct
//        edge between them? In such a case, we generate dummy tensors for
//        workspace tensors. For the example, transformation of MaxPool will
//        be exactly same as it would be when there is a direct edge between
//        the forward and the backward op --- it is just that MaxPool won't
//        generate any workspace tensor. For MaxPoolGrad, the transformation
//        will also be same, but instead of connecting W and W_m with the
//        outputs of MaxPool, we will produce dummy tensors for them, and we
//        will set workspace_enabled attribute to false.
//
class MklLayoutRewritePass : public GraphOptimizationPass {
 public:
  // Populates the shared op-name table (csinfo_) and then registers, in this
  // order: node rewrite rules (rinfo_), workspace-edge rules (wsinfo_),
  // two-node merge rules (minfo_) and multi-node fusion rules (finfo_).
  MklLayoutRewritePass() {
    // NOTE: names are kept in roughly (not strictly) alphabetical order.
    //
    // The "__MklDummy*" names below are used as the result of a merge rule in
    // minfo_ and as the source op of a rewrite rule in rinfo_, which maps
    // them to the corresponding "_Mkl*" op name.
    csinfo_.addn = "AddN";
    csinfo_.avg_pool = "AvgPool";
    csinfo_.avg_pool_grad = "AvgPoolGrad";
    csinfo_.avg_pool3d = "AvgPool3D";
    csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
    csinfo_.bias_add = "BiasAdd";
    csinfo_.bias_add_grad = "BiasAddGrad";
    csinfo_.concat = "Concat";
    csinfo_.concatv2 = "ConcatV2";
    csinfo_.conv2d = "Conv2D";
    csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
    csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
    csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
    csinfo_.conv2d_grad_filter_with_bias =
        "__MklDummyConv2DBackpropFilterWithBias";
    csinfo_.conv3d = "Conv3D";
    csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
    csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
    csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
    csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
    csinfo_.depthwise_conv2d_grad_filter =
        "DepthwiseConv2dNativeBackpropFilter";
    csinfo_.fused_batch_norm = "FusedBatchNorm";
    csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
    csinfo_.fused_conv2d = "_FusedConv2D";
    csinfo_.identity = "Identity";
    csinfo_.leakyrelu = "LeakyRelu";
    csinfo_.leakyrelu_grad = "LeakyReluGrad";
    csinfo_.lrn = "LRN";
    csinfo_.lrn_grad = "LRNGrad";
    csinfo_.matmul = "MatMul";
    csinfo_.max_pool = "MaxPool";
    csinfo_.max_pool_grad = "MaxPoolGrad";
    csinfo_.max_pool3d = "MaxPool3D";
    csinfo_.max_pool3d_grad = "MaxPool3DGrad";
    csinfo_.mkl_conv2d = "_MklConv2D";
    csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
    csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
    csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
    csinfo_.mkl_conv2d_grad_filter_with_bias =
        "_MklConv2DBackpropFilterWithBias";
    csinfo_.mkl_depthwise_conv2d_grad_input =
        "_MklDepthwiseConv2dNativeBackpropInput";
    csinfo_.mkl_depthwise_conv2d_grad_filter =
        "_MklDepthwiseConv2dNativeBackpropFilter";
    csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
    csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
    csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
    csinfo_.pad = "Pad";
    csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
    csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
    csinfo_.quantized_avg_pool = "QuantizedAvgPool";
    csinfo_.quantized_concatv2 = "QuantizedConcatV2";
    csinfo_.quantized_conv2d = "QuantizedConv2D";
    csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
    csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
    csinfo_.quantized_conv2d_with_bias_and_requantize =
        "QuantizedConv2DWithBiasAndRequantize";
    csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
    csinfo_.quantized_conv2d_and_relu_and_requantize =
        "QuantizedConv2DAndReluAndRequantize";
    csinfo_.quantized_conv2d_with_bias_and_relu =
        "QuantizedConv2DWithBiasAndRelu";
    csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
        "QuantizedConv2DWithBiasAndReluAndRequantize";
    csinfo_.quantized_max_pool = "QuantizedMaxPool";
    csinfo_.quantized_conv2d_with_bias_sum_and_relu =
        "QuantizedConv2DWithBiasSumAndRelu";
    csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
        "QuantizedConv2DWithBiasSumAndReluAndRequantize";
    csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
        "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
    csinfo_.relu = "Relu";
    csinfo_.relu_grad = "ReluGrad";
    csinfo_.relu6 = "Relu6";
    csinfo_.relu6_grad = "Relu6Grad";
    csinfo_.requantize = "Requantize";
    csinfo_.tanh = "Tanh";
    csinfo_.tanh_grad = "TanhGrad";
    csinfo_.reshape = "Reshape";
    csinfo_.slice = "Slice";
    csinfo_.softmax = "Softmax";
    csinfo_.split = "Split";
    csinfo_.transpose = "Transpose";
    // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
    // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
    // MklInputConversion op is added before it.
    csinfo_.add = "Add";
    csinfo_.maximum = "Maximum";
    csinfo_.mul = "Mul";
    csinfo_.squared_difference = "SquaredDifference";
    csinfo_.sub = "Sub";
    // End - element-wise ops. See note above.

    // Rewrite rules. Each entry is a RewriteInfo:
    //   {original op name, new op name, attribute-copy function, rewrite rule}
    // NOTE: names are kept in roughly (not strictly) alphabetical order.
    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
                      CopyAttrsAddN, AddNRewrite});
    rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.avg_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.avg_pool3d,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.avg_pool3d_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.concat,
                      mkl_op_registry::GetMklOpName(csinfo_.concat),
                      CopyAttrsConcat, AlwaysRewrite});
    rinfo_.push_back({csinfo_.concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.concatv2),
                      CopyAttrsConcatV2, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d),
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
                      CopyAttrsConv, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
                      csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
                      AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
                      CopyAttrsConv, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv3d,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d),
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv3d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
                      CopyAttrsConv, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv3d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
                      CopyAttrsConv, AlwaysRewrite});
    rinfo_.push_back({csinfo_.depthwise_conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
                      CopyAttrsConv2DDepthwiseCheckConstFilter, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.depthwise_conv2d_grad_input,
         mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
         CopyAttrsConv2DDepthwise, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.depthwise_conv2d_grad_filter,
         mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
         CopyAttrsConv2DDepthwise, AlwaysRewrite});
    rinfo_.push_back({csinfo_.fused_batch_norm,
                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                      CopyAttrsFusedBatchNorm, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
         CopyAttrsFusedBatchNorm, AlwaysRewrite});
    rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
                      CopyAttrsFusedConv2D, FusedConv2DRewrite});
    rinfo_.push_back({csinfo_.identity,
                      mkl_op_registry::GetMklOpName(csinfo_.identity),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                      CopyAttrsLRN, LrnRewrite});
    rinfo_.push_back({csinfo_.lrn_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
                      CopyAttrsLRN, LrnGradRewrite});
    rinfo_.push_back({csinfo_.leakyrelu,
                      mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
                      CopyAttrsLeakyRelu, LeakyReluRewrite});
    rinfo_.push_back({csinfo_.leakyrelu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
                      CopyAttrsLeakyRelu, LeakyReluRewrite});
    rinfo_.push_back({csinfo_.max_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool),
                      CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
    rinfo_.push_back({csinfo_.max_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
                      CopyAttrsPooling, MaxpoolGradRewrite});
    rinfo_.push_back({csinfo_.max_pool3d,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
                      CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
    rinfo_.push_back({csinfo_.max_pool3d_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.maximum,
                      mkl_op_registry::GetMklOpName(csinfo_.maximum),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.pad_with_conv2d, csinfo_.mkl_pad_with_conv2d,
                      CopyAttrsPadWithConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
                      csinfo_.mkl_pad_with_fused_conv2d,
                      CopyAttrsPadWithFusedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
                      CopyAttrsQuantizedPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
                      CopyAttrsConcatV2, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
         CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_and_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_and_relu,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
         CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_and_relu_and_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_and_relu),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_max_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
                      CopyAttrsQuantizedPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_sum_and_relu),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.relu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.relu6,
                      mkl_op_registry::GetMklOpName(csinfo_.relu6),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.relu6_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.requantize,
                      mkl_op_registry::GetMklOpName(csinfo_.requantize),
                      CopyAttrsRequantize, AlwaysRewrite});
    // NOTE(review): the Tanh/TanhGrad rewrite rules below are disabled; the
    // reason is not visible in this section of the file.
    /*
    rinfo_.push_back({csinfo_.tanh,
                      mkl_op_registry::GetMklOpName(csinfo_.tanh),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.tanh_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
                      CopyAttrsDataType, AlwaysRewrite});
    */
    rinfo_.push_back({csinfo_.reshape,
                      mkl_op_registry::GetMklOpName(csinfo_.reshape),
                      CopyAttrsReshape, AlwaysRewrite});
    rinfo_.push_back({csinfo_.slice,
                      mkl_op_registry::GetMklOpName(csinfo_.slice),
                      CopyAttrsSlice, AlwaysRewrite});
    rinfo_.push_back({csinfo_.softmax,
                      mkl_op_registry::GetMklOpName(csinfo_.softmax),
                      CopyAttrsDataType, AlwaysRewrite});

    rinfo_.push_back({csinfo_.squared_difference,
                      mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
                      CopyAttrsDataType, AlwaysRewrite});

    // Add info about which ops to add workspace edge to and the slots.
    // Each entry is a WorkSpaceInfo:
    //   {fwd_op, bwd_op, fwd_slot, bwd_slot, ws_fwd_slot, ws_bwd_slot}
    wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
    wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
    wsinfo_.push_back(
        {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});

    // Add a rule for merging nodes. Each entry is a MergeInfo:
    //   {op1, op2, name of the merged node, function locating the partner}
    minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
                      csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});

    minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
                      csinfo_.conv2d_grad_filter_with_bias,
                      GetConv2DBackpropFilterOrBiasAddGrad});
    // Merge Pad and Conv2d, only if the pad op is "Pad"
    // Doesn't merge if pad op is "PadV2" or "MirrorPad"
    minfo_.push_back(
        {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});

    minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
                      csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});

    // The fusion patterns in "finfo_" that show up first will get applied
    // first, for example, graph "A->B->C->D" and finfo_ is {A->B->C to ABC,
    // A->B->C->D to ABCD}, since the first gets applied first, the final
    // graph will be ABC->D.

    //
    // Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
    // (NHWC) + Transpose (NHWC -> NCHW)" into "Conv2D (NCHW)". Such patterns
    // occur frequently in Keras.
    // Note: we use the term "merge" to combine (exactly) 2 nodes into one,
    // while "fusion" is for 3+ nodes situation.
    //

    // Transpose + Conv2d + Transpose:
    // Permutations matched by the two Transpose checkers, expressed with the
    // NCHW / NHWC dimension-index enums declared in this class.
    std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
                                          NCHW::dim::W, NCHW::dim::C};
    std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
                                          NHWC::dim::H, NHWC::dim::W};
    // Bind the extra arguments so each checker takes only the node, matching
    // the element type of FusionInfo::node_checkers.
    auto CheckForTransposeToNHWC =
        std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
    auto CheckForConv2dOp =
        std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
    auto CheckForTransposeToNCHW =
        std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
    // The fuse function is bound with the literal "NCHW" as its 4th argument.
    auto FuseConv2D =
        std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3, "NCHW");
    finfo_.push_back(
        {"transpose-elimination for Conv2D",
         {CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
         // CheckForMklOp
         FuseConv2D,
         CopyAttrsConv});
  }

  // Standard interface to run pass
  Status Run(const GraphOptimizationPassOptions& options);

  // Helper function which does most of heavy lifting for rewriting
  // Mkl nodes to propagate Mkl tensor as additional output
  //
  // Extracts common functionality between Run public interface and
  // test interface.
  //
  // @return true, if and only if graph is mutated; false otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);

  /// Structure to specify the name of an original node, its new name after
  /// rewrite, the function to be used to copy attributes for the op, and the
  /// rule (if any) which must hold for rewriting the node
  typedef struct {
    string name;      // Original name of op of the node in the graph
    string new_name;  // New name of the op of the node in the graph
    // A function handler to copy attributes from an old node to a new node.
    // NOTE(review): the meaning of the bool argument is not visible in this
    // section of the file.
    std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
    // A rule under which to rewrite this node
    std::function<bool(const Node*)> rewrite_rule;
  } RewriteInfo;

  /// Structure to specify a forward op, a backward op, and the slot numbers
  /// in the forward and backward ops where we will add a workspace edge.
  typedef struct {
    string fwd_op;    // Name of a forward op in the graph
    string bwd_op;    // Name of a backward op in the graph
    int fwd_slot;     // Output slot in the forward op node where actual
                      // output tensor resides
    int bwd_slot;     // Input slot in the backward op node where actual
                      // input tensor resides
    int ws_fwd_slot;  // Output slot in the forward op node where workspace
                      // edge is added
    int ws_bwd_slot;  // Input slot in the backward op node where workspace
                      // edge is added
  } WorkSpaceInfo;

  /// Structure to specify information used in node merge of 2 operators
  typedef struct {
    string op1;       // Node string for one operator.
    string op2;       // Node string for second operator.
    string new_node;  // Name of the node after merge
    // Function that enables user of the node merger to specify how to find
    // second operator given the first operator.
    std::function<Node*(const Node*)> get_node_to_be_merged;
  } MergeInfo;

  // Structure to specify information used in node fusion of 3+ operators
  typedef struct {
    std::string pattern_name;  // Name to describe this pattern, such as
                               // "Transpose_Mklop_Transpose".
    std::vector<std::function<bool(const Node*)> >
        node_checkers;  // Extra restriction checker for these ops
    // Function that performs the fusion; it receives the graph, a vector of
    // nodes, and an attribute-copy function.
    std::function<Status(
        std::unique_ptr<Graph>*, std::vector<Node*>&,
        std::function<void(const Node*, NodeBuilder* nb, bool)>)>
        fuse_func;
    // Attribute-copy function stored alongside fuse_func for this pattern.
    std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
  } FusionInfo;

  //
  // Dimension indices for 2D tensor.
653 // 654 struct NCHW { 655 enum dim { N = 0, C = 1, H = 2, W = 3 }; 656 }; 657 658 struct NHWC { 659 enum dim { N = 0, H = 1, W = 2, C = 3 }; 660 }; 661 662 // 663 // dimension indices for 3D tensor. 664 // 665 struct NCDHW { 666 enum dim { N = 0, C = 1, D = 2, H = 3, W = 4 }; 667 }; 668 669 struct NDHWC { 670 enum dim { N = 0, D = 1, H = 2, W = 3, C = 4 }; 671 }; 672 673 /// Structure to store all constant strings 674 /// NOTE: names are alphabetically sorted. 675 typedef struct { 676 string addn; 677 string add; 678 string avg_pool; 679 string avg_pool_grad; 680 string avg_pool3d; 681 string avg_pool3d_grad; 682 string bias_add; 683 string bias_add_grad; 684 string concat; 685 string concatv2; 686 string conv2d; 687 string conv2d_with_bias; 688 string conv2d_grad_input; 689 string conv2d_grad_filter; 690 string conv2d_grad_filter_with_bias; 691 string conv3d; 692 string conv3d_grad_input; 693 string conv3d_grad_filter; 694 string depthwise_conv2d; 695 string depthwise_conv2d_grad_input; 696 string depthwise_conv2d_grad_filter; 697 string fused_batch_norm; 698 string fused_batch_norm_grad; 699 string fused_conv2d; 700 string identity; 701 string leakyrelu; 702 string leakyrelu_grad; 703 string lrn; 704 string lrn_grad; 705 string matmul; 706 string max_pool; 707 string max_pool_grad; 708 string max_pool3d; 709 string max_pool3d_grad; 710 string maximum; 711 string mkl_conv2d; 712 string mkl_conv2d_grad_input; 713 string mkl_conv2d_grad_filter; 714 string mkl_conv2d_grad_filter_with_bias; 715 string mkl_conv2d_with_bias; 716 string mkl_depthwise_conv2d_grad_input; 717 string mkl_depthwise_conv2d_grad_filter; 718 string mkl_fused_conv2d; 719 string mkl_pad_with_conv2d; 720 string mkl_pad_with_fused_conv2d; 721 string mul; 722 string pad; 723 string pad_with_conv2d; 724 string pad_with_fused_conv2d; 725 string quantized_avg_pool; 726 string quantized_conv2d; 727 string quantized_conv2d_with_requantize; 728 string quantized_conv2d_with_bias; 729 string 
quantized_conv2d_with_bias_and_requantize; 730 string quantized_conv2d_and_relu; 731 string quantized_conv2d_and_relu_and_requantize; 732 string quantized_conv2d_with_bias_and_relu; 733 string quantized_conv2d_with_bias_and_relu_and_requantize; 734 string quantized_concatv2; 735 string quantized_max_pool; 736 string quantized_conv2d_with_bias_sum_and_relu; 737 string quantized_conv2d_with_bias_sum_and_relu_and_requantize; 738 string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize; 739 string relu; 740 string relu_grad; 741 string relu6; 742 string relu6_grad; 743 string requantize; 744 string tanh; 745 string tanh_grad; 746 string transpose; 747 string reshape; 748 string slice; 749 string softmax; 750 string split; 751 string squared_difference; 752 string sub; 753 } ConstStringsInfo; 754 755 private: 756 /// Maintain info about nodes to rewrite 757 std::vector<RewriteInfo> rinfo_; 758 759 /// Maintain info about nodes to add workspace edge 760 std::vector<WorkSpaceInfo> wsinfo_; 761 762 /// Maintain info about nodes to be merged 763 std::vector<MergeInfo> minfo_; 764 765 /// Maintain info about nodes to be fused 766 std::vector<FusionInfo> finfo_; 767 768 /// Maintain structure of constant strings 769 static ConstStringsInfo csinfo_; 770 771 private: 772 // Is OpDef::ArgDef a list type? It could be N * T or list(type). 773 // Refer to opdef.proto for details of list type. 774 inline bool ArgIsList(const OpDef::ArgDef& arg) const { 775 return !arg.type_list_attr().empty() || !arg.number_attr().empty(); 776 } 777 778 // Get length of a list in 'n' if 'arg' is of list type. Refer to 779 // description of ArgIsList for definition of list type. 780 inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) { 781 CHECK_EQ(ArgIsList(arg), true); 782 int N = 0; 783 const string attr_name = !arg.type_list_attr().empty() 784 ? 
arg.type_list_attr() 785 : arg.number_attr(); 786 if (!arg.type_list_attr().empty()) { 787 std::vector<DataType> value; 788 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value)); 789 N = value.size(); 790 } else { 791 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N)); 792 } 793 return N; 794 } 795 796 // Can op represented by node 'n' run on DEVICE_CPU? 797 // Op can run on CPU with MKL if the runtime assigned device or the 798 // user requested device contains device CPU, or both are empty. 799 bool CanOpRunOnCPUDevice(const Node* n) { 800 bool result = true; 801 string reason; 802 803 // Substring that should be checked for in device name for CPU device. 804 const char* const kCPUDeviceSubStr = "CPU"; 805 806 // If Op has been specifically assigned to a non-CPU device, then No. 807 if (!n->assigned_device_name().empty() && 808 !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) { 809 result = false; 810 reason = "Op has been assigned a runtime device that is not CPU."; 811 } 812 813 // If user has specifically assigned this op to a non-CPU device, then No. 814 if (!n->def().device().empty() && 815 !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) { 816 result = false; 817 reason = "User has assigned a device that is not CPU."; 818 } 819 820 if (result == false) { 821 VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node " 822 << n->type_string() << ", reason: " << reason; 823 } 824 825 // Otherwise Yes. 826 return result; 827 } 828 829 // Return a node that can be merged with input node 'n' 830 // 831 // @return pointer to the node if we can find such a 832 // node. Otherwise, it returns nullptr. 833 Node* CheckForNodeMerge(const Node* n) const; 834 835 // Merge node 'm' with node 'n'. 836 // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with 837 // Conv2DBackpropFilter. 838 // 839 // Input nodes m and n may be deleted if the call to 840 // this function is successful. 
Attempt to use the pointers 841 // after the call to function may result in undefined behaviors. 842 // 843 // @input g - input graph, m - graph node, n - graph node to be merged with m 844 // @return Status::OK(), if merging is successful and supported. 845 // Returns appropriate Status error code otherwise. 846 // Graph is updated in case nodes are merged. Otherwise, it is 847 // not updated. 848 Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n); 849 850 // Helper function to merge different nodes 851 Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n); 852 Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n); 853 Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g, 854 Node* m, Node* n); 855 856 // Find BiasAdd or Conv2D node that can be merged with input node 'm'. 857 // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be 858 // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd 859 // node that can be merged with 'm'. 860 static Node* GetConv2DOrBiasAdd(const Node* m) { 861 CHECK_NOTNULL(m); 862 Node* n = nullptr; 863 864 DataType T_m; 865 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m)); 866 867 // Don't try to merge if datatype is not DT_FLOAT 868 if (T_m != DT_FLOAT) return n; 869 870 if (m->type_string() == csinfo_.bias_add) { 871 // If a is BiasAdd, then Conv2D is 0th input of BiasAdd. 872 TF_CHECK_OK(m->input_node(0, &n)); 873 } else { 874 CHECK_EQ(m->type_string(), csinfo_.conv2d); 875 // Go over all output edges and search for BiasAdd Node. 876 // 0th input of BiasAdd is Conv2D. 877 for (const Edge* e : m->out_edges()) { 878 if (!e->IsControlEdge() && 879 e->dst()->type_string() == csinfo_.bias_add && 880 e->dst_input() == 0) { 881 n = e->dst(); 882 break; 883 } 884 } 885 } 886 887 if (n == nullptr) { 888 VLOG(1) << "MklLayoutRewritePass: Could not find matching " 889 << "Conv2D and BiasAdd node for merging. 
Input node: " 890 << m->DebugString(); 891 } 892 893 return n; 894 } 895 896 // Find Pad or Conv2D node that can be merged with input node 'm'. 897 // If input 'm' is Pad, then check if there exists Conv2D node that can be 898 // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad 899 // node that can be merged with 'm'. 900 static Node* GetPadOrConv2D(const Node* m) { 901 DCHECK(m); 902 Node* n = nullptr; 903 904 DataType T_m; 905 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m)); 906 907 // Don't try to merge if datatype is not DT_FLOAT 908 if (T_m != DT_FLOAT) return n; 909 910 const Node* conv_node; 911 if (m->type_string() == csinfo_.pad) { 912 // If m is Pad, then Conv2D is the output of Pad. 913 for (const Edge* e : m->out_edges()) { 914 if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) { 915 n = e->dst(); 916 conv_node = n; 917 break; 918 } 919 } 920 } else { 921 DCHECK_EQ(m->type_string(), csinfo_.conv2d); 922 // If m is conv2D, Go over all input edges 923 // and search for Pad Node. 924 for (const Edge* e : m->in_edges()) { 925 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) { 926 n = e->src(); 927 conv_node = m; 928 break; 929 } 930 } 931 } 932 // Check if only VALID type of padding is used 933 // or not. 934 if (n != nullptr) { 935 string padding; 936 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding)); 937 if (padding != "VALID") 938 // Then do not merge. 939 // Only VALID type of padding in conv op can be 940 // merged with Pad op. 941 n = nullptr; 942 } else { 943 VLOG(1) << "MklLayoutRewritePass: Could not find matching " 944 << "Pad and Conv2D node for merging. Input node: " 945 << m->DebugString(); 946 } 947 948 return n; 949 } 950 951 // Find Pad or _FusedConv2D node that can be merged with input node 'm'. 952 // If input 'm' is Pad, then check if there exists _FusedConv2D node that can 953 // be merged with 'm'. 
If input 'm' is _FusedConv2D, then check if there 954 // exists Pad node that can be merged with 'm'. 955 static Node* GetPadOrFusedConv2D(const Node* m) { 956 DCHECK(m); 957 Node* n = nullptr; 958 959 const Node* conv_node; 960 if (m->type_string() == csinfo_.pad) { 961 // If m is Pad, then _FusedConv2D is the output of Pad. 962 for (const Edge* e : m->out_edges()) { 963 if (!e->IsControlEdge() && 964 e->dst()->type_string() == csinfo_.fused_conv2d) { 965 n = e->dst(); 966 conv_node = n; 967 break; 968 } 969 } 970 } else { 971 DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d); 972 // If m is _FusedConv2D, Go over all input edges 973 // and search for Pad node. 974 for (const Edge* e : m->in_edges()) { 975 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) { 976 n = e->src(); 977 conv_node = m; 978 break; 979 } 980 } 981 } 982 // Check if only VALID type of padding is used or not. 983 if (n != nullptr) { 984 string padding; 985 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding)); 986 if (padding != "VALID") { 987 // Then do not merge. 988 n = nullptr; 989 VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D " 990 << "nodes but cannot merge them. Only conv ops with padding " 991 << "type VALID can be merged with Pad op Input node: " 992 << m->DebugString(); 993 } 994 } else { 995 VLOG(1) << "MklLayoutRewritePass: Could not find matching " 996 << "Pad and _FusedConv2D node for merging. Input node: " 997 << m->DebugString(); 998 } 999 1000 return n; 1001 } 1002 1003 // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input 1004 // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists 1005 // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad, 1006 // then check if there exists Conv2DBackpropFilter node that can be merged 1007 // with 'm'. 
1008 // 1009 // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad 1010 // would look like: 1011 // 1012 // _ = Conv2DBackpropFilter(F, _, G) 1013 // _ = BiasAddGrad(G) 1014 // 1015 // So 1st input of BiasAddGrad connects with 3rd input of 1016 // Conv2DBackpropFilter and vice versa. 1017 static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) { 1018 CHECK_NOTNULL(m); 1019 Node* n = nullptr; 1020 1021 DataType T_m; 1022 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m)); 1023 1024 // Don't try to merge if datatype is not DT_FLOAT 1025 if (T_m != DT_FLOAT) return n; 1026 1027 if (m->type_string() == csinfo_.bias_add_grad) { 1028 // Get 1st input 'g' of BiasAddGrad. 1029 Node* g = nullptr; 1030 TF_CHECK_OK(m->input_node(0, &g)); 1031 // Now traverse all outgoing edges from g that have destination node as 1032 // Conv2DBackpropFilter. 1033 for (const Edge* e : g->out_edges()) { 1034 if (!e->IsControlEdge() && 1035 e->dst()->type_string() == csinfo_.conv2d_grad_filter && 1036 e->dst_input() == 2 /* 3rd input of BackpropFilter */) { 1037 n = e->dst(); 1038 break; 1039 } 1040 } 1041 } else { 1042 CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter); 1043 // Get 3rd input 'g' of Conv2DBackpropFilter. 1044 Node* g = nullptr; 1045 TF_CHECK_OK(m->input_node(2, &g)); 1046 // Now traverse all outgoing edges from g that have destination node as 1047 // BiasAddGrad. 1048 for (const Edge* e : g->out_edges()) { 1049 if (!e->IsControlEdge() && 1050 e->dst()->type_string() == csinfo_.bias_add_grad && 1051 e->dst_input() == 0 /* 1st input of BiasAddGrad */) { 1052 n = e->dst(); 1053 break; 1054 } 1055 } 1056 } 1057 1058 if (n == nullptr) { 1059 VLOG(1) << "MklLayoutRewritePass: Could not find matching " 1060 << "Conv2DBackpropFilter and BiasAddGrad node for merging. " 1061 << "Input node: " << m->DebugString(); 1062 } 1063 return n; 1064 } 1065 1066 // Return a node that can be fused with input node 'n' 1067 // 1068 // @return tuple. 
If we can find such nodes, the first 1069 // element of the tuple is a true. Otherwise, it's false. 1070 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo> 1071 CheckForNodeFusion(Node* n) const; 1072 1073 // Fuse nodes in the vector "nodes" 1074 Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes, 1075 const MklLayoutRewritePass::FusionInfo fi); 1076 1077 // Fuse tranpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into 1078 // mklop("NCHW"). 1079 // Here "mklop" can be any MKL-DNN supported op, such as Conv2D. 1080 static Status FuseTransposeMklOpTranspose( 1081 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes, 1082 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs, 1083 string data_format); 1084 1085 static bool CheckForTranspose(const Node* node, std::vector<int> perm) { 1086 // Check if node's type is "Transpose" 1087 if (node->type_string() != "Transpose") return false; 1088 1089 // If "Transpose" has multiple output data edges, also don't fuse it. 1090 if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false; 1091 1092 // Check if has out control edge. If true, this is a training graph. 1093 // Currently we focus on inference and do no fusion in training. 1094 // Note: this constraint will eventually be removed, if we enabled this 1095 // fusion for training 1096 // in the future. 1097 for (const Edge* e : node->out_edges()) { 1098 if (e->IsControlEdge()) { 1099 return false; 1100 } 1101 } 1102 1103 // If "Transpose" has input control edges, don't fuse on it. 1104 for (const Edge* e : node->in_edges()) { 1105 if (e->IsControlEdge()) { 1106 return false; 1107 } 1108 } 1109 1110 // We compared the tensor containing the permutation order ("perm_node") 1111 // with our desired order ("perm"). If they're exactly match, this check 1112 // succeed and returns true. 
1113 for (const Edge* e : node->in_edges()) { 1114 if (!e->IsControlEdge()) { 1115 const Node* perm_node = e->src(); 1116 1117 const int kPermTensorIndex = 1; 1118 if (perm_node->type_string() == "Const" && 1119 e->dst_input() == kPermTensorIndex) { 1120 // we find the "perm" node, now try to retrieve its value. 1121 const TensorProto* proto = nullptr; 1122 TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto)); 1123 1124 DataType type; 1125 GetNodeAttr(perm_node->def(), "dtype", &type); 1126 1127 // Here we directly access to the "tensor_content", rather than 1128 // "int_val". This is because we find "int_val" is 1129 // not set properly under some circumstances. 1130 if (type == DT_INT32) { 1131 const int type_size = 4; 1132 const int* tensor_content = 1133 reinterpret_cast<const int*>(proto->tensor_content().c_str()); 1134 const int tensor_content_size = 1135 proto->tensor_content().size() / type_size; 1136 1137 std::vector<int> perm_value(tensor_content, 1138 tensor_content + tensor_content_size); 1139 1140 return perm_value == perm; 1141 } else if (type == DT_INT64) { 1142 const int type_size = 8; 1143 const long* tensor_content = 1144 reinterpret_cast<const long*>(proto->tensor_content().c_str()); 1145 const int tensor_content_size = 1146 proto->tensor_content().size() / type_size; 1147 1148 std::vector<long> perm_value(tensor_content, 1149 tensor_content + tensor_content_size); 1150 std::vector<long> long_perm(perm.cbegin(), perm.cend()); 1151 1152 return perm_value == long_perm; 1153 } 1154 return false; 1155 } 1156 } 1157 } 1158 return false; 1159 } 1160 1161 static bool CheckForMklOp(const Node* node, string name = "") { 1162 if (node == nullptr) return false; 1163 1164 if (!name.empty() && node->type_string() != name) { 1165 return false; 1166 } 1167 1168 // if mklop has multiple outputs, don't fuse it. 
1169 if (node->num_outputs() > 1) return false; 1170 1171 if (node->out_edges().size() > 1) return false; 1172 1173 DataType T; 1174 TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T)); 1175 return mkl_op_registry::IsMklOp( 1176 mkl_op_registry::GetMklOpName(node->type_string()), T); 1177 } 1178 1179 // Check if the node 'n' has any applicable rewrite rule 1180 // We check for 2 scenarios for rewrite. 1181 // 1182 // @return RewriteInfo* for the applicable rewrite rule 1183 const RewriteInfo* CheckForNodeRewrite(const Node* n) const; 1184 const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const; 1185 1186 // Default rewrite rule to be used in scenario 1 for rewrite. 1187 // @return - true (since we want to always rewrite) 1188 static bool AlwaysRewrite(const Node* n) { return true; } 1189 1190 // Check if we are performing pooling on depth or batch. If it is, then we 1191 // do not rewrite MaxPool node to Mkl version. 1192 // @return - true (if it is not a depth/batch wise pooling case); 1193 // false otherwise. 1194 static bool NonDepthBatchWisePoolRewrite(const Node* n) { 1195 CHECK_NOTNULL(n); 1196 1197 string data_format_str; 1198 TensorFormat data_format; 1199 std::vector<int32> ksize, strides; 1200 CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); 1201 CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); 1202 CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); 1203 CHECK_EQ(FormatFromString(data_format_str, &data_format), true); 1204 1205 // Condition that specifies non-batch-wise and non-depth-wise pooling. 1206 if (GetTensorDim(ksize, data_format, 'N') == 1 && 1207 GetTensorDim(strides, data_format, 'N') == 1 && 1208 GetTensorDim(ksize, data_format, 'C') == 1 && 1209 GetTensorDim(strides, data_format, 'C') == 1) { 1210 return true; 1211 } 1212 1213 return false; 1214 } 1215 1216 // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized 1217 // path. The unoptimized path is slow. 
Thus we dont rewrite the node 1218 // and use default Eigen. But for depth_radius=2, MKL DNN optimized 1219 // path is taken, i.e., eigen node is rewritten by MKl DNN node. 1220 static bool LrnRewrite(const Node* n) { 1221 CHECK_NOTNULL(n); 1222 1223 int depth_radius; 1224 CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true); 1225 1226 // if the depth_radius of LRN is not 2, don't rewrite the node by MKL DNN 1227 // and use eigen node instead 1228 if (depth_radius == 2) { 1229 return true; 1230 } 1231 VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which" 1232 << "case is not optimized by Intel MKL, thus using Eigen op" 1233 << "for LRN "; 1234 1235 return false; 1236 } 1237 1238 static bool LrnGradRewrite(const Node* n) { 1239 CHECK_NOTNULL(n); 1240 bool do_rewrite = false; 1241 1242 for (const Edge* e : n->in_edges()) { 1243 // Rewrite only if there is corresponding LRN, i.e workspace is available 1244 if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 && 1245 e->src()->type_string() == 1246 mkl_op_registry::GetMklOpName(csinfo_.lrn) && 1247 e->src_output() == 0) { 1248 do_rewrite = true; 1249 break; 1250 } 1251 } 1252 return do_rewrite; 1253 } 1254 1255 // MKL-DNN's LeakyRelu(feature) = feature (if feature > 0), or 1256 // feature * alpha (otherwise), 1257 // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha). 1258 // These two algorithms are not consistent when alpha > 1, 1259 // so we only rewrite LeakyRelu to MKL OP when alpha <= 1. 1260 static bool LeakyReluRewrite(const Node* n) { 1261 DCHECK(n); 1262 1263 float alpha; 1264 bool has_attr = GetNodeAttr(n->def(), "alpha", &alpha).ok(); 1265 DCHECK(has_attr); 1266 1267 // If the alpha of LeakyRelu is less than 1, rewrite the node. 1268 // Otherwise eigen node is used instead. 
1269 if (alpha <= 1) { 1270 return true; 1271 } 1272 VLOG(1) << "LeakyReluRewrite: The model sets alpha is greater than 1 " 1273 << "which case is not optimized by Intel MKL, thus using Eigen op" 1274 << "for LeakyRelu "; 1275 1276 return false; 1277 } 1278 1279 static bool MaxpoolGradRewrite(const Node* n) { 1280 CHECK_NOTNULL(n); 1281 bool do_rewrite = false; 1282 for (const Edge* e : n->in_edges()) { 1283 // Rewrite only if there is corresponding Maxpool, i.e workspace is 1284 // available 1285 if (e->dst()->type_string() == csinfo_.max_pool_grad && 1286 e->dst_input() == 1 && 1287 e->src()->type_string() == 1288 mkl_op_registry::GetMklOpName(csinfo_.max_pool) && 1289 e->src_output() == 0) { 1290 do_rewrite = true; 1291 break; 1292 } 1293 } 1294 return do_rewrite; 1295 } 1296 1297 static bool AddNRewrite(const Node* n) { 1298 CHECK_NOTNULL(n); 1299 1300 int num; 1301 CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true); 1302 1303 // Condition that specifies non-batch-wise and non-depth-wise pooling. 1304 if (num == 2) { 1305 return true; 1306 } 1307 1308 return false; 1309 } 1310 1311 static bool FusedConv2DRewrite(const Node* n) { 1312 // MKL DNN currently doesn't support all fusions that grappler fuses 1313 // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if 1314 // it includes those we support. 1315 DataType T; 1316 if (!GetNodeAttr(n->def(), "T", &T).ok() || 1317 !mkl_op_registry::IsMklOp(csinfo_.mkl_fused_conv2d, T)) { 1318 return false; 1319 } 1320 1321 std::vector<string> fused_ops; 1322 TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops)); 1323 return (fused_ops == std::vector<string>{"BiasAdd"} || 1324 fused_ops == std::vector<string>{"Relu"} || 1325 fused_ops == std::vector<string>{"BiasAdd", "Relu"}); 1326 } 1327 1328 // Rewrites input node to a new node specified by its matching rewrite info. 1329 // 1330 // Method first searches matching rewrite info for input node and then 1331 // uses that info to rewrite. 
//
// Input node may be deleted in case of rewrite. Attempt to use the node
// after the call can result in undefined behaviors.
//
// @input  g - input graph, n - Node to be rewritten,
//         ri - matching rewriteinfo
// @return Status::OK(), if the input node is rewritten;
//         Returns appropriate Status error code otherwise.
//         Graph is updated in case the input node is rewritten.
//         Otherwise, it is not updated.
Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);

// Get nodes that will feed a list of TF tensors to the new
// node that we are constructing.
//
// @input inputs - inputs to old node that we are using for constructing
//                 new inputs,
// @input input_idx - the index in the 'inputs' vector pointing to the
//                    current input that we have processed so far
// @output input_idx - index will be incremented by the number of nodes
//                     from 'inputs' that are processed
// @input list_length - The expected length of list of TF tensors
// @output output_nodes - the list of new nodes creating TF tensors
//
// @return None
void GetNodesProducingTFTensorList(
    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
    int* input_idx, int list_length,
    std::vector<NodeBuilder::NodeOut>* output_nodes);

// Get nodes that will feed a list of Mkl tensors to the new
// node that we are constructing.
//
// @input g - input graph,
// @input orig_node - Original node that we are rewriting
// @input inputs - inputs to old node that we are using for constructing
//                 new inputs,
// @input input_idx - the index in the 'inputs' vector pointing to the
//                    current input that we have processed so far
// @output input_idx - index will be incremented by the number of nodes
//                     from 'inputs' that are processed
// @input list_length - The expected length of list of Mkl tensors
// @output output_nodes - the list of new nodes creating Mkl tensors
//
// @return None
void GetNodesProducingMklTensorList(
    std::unique_ptr<Graph>* g, Node* orig_node,
    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
    int* input_idx, int list_length,
    std::vector<NodeBuilder::NodeOut>* output_nodes);

// Get a node that will feed an Mkl tensor to the new
// node that we are constructing. The output node could be (1) 'n'
// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
// if 'n' is not an Mkl layer.
//
// @input g - input graph,
// @input orig_node - Original node that we are rewriting,
// @input n - Node based on which we are creating Mkl node,
// @input n_output_slot - the output slot of node 'n'
//            which is feeding to the node that we are constructing
// @output mkl_node - the new node that will feed Mkl tensor
// @output mkl_node_output_slot - the slot number of mkl_node that
//                                will feed the tensor
// @return None
void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
                               Node* n, int n_output_slot, Node** mkl_node,
                               int* mkl_node_output_slot);

// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
// set up in contiguous fashion.
// 'workspace_tensors' carry graph nodes
// producing workspace edges if 'are_workspace_tensors_available' is true.
// Otherwise, 'workspace_tensors' is empty vector.
//
// For details, refer to 'Ordering of inputs after rewriting' section in the
// documentation above.
//
// Returns Status::OK() if setting up inputs is successful, otherwise
// returns appropriate status code.
// NOTE(review): the declaration below returns 'int', not Status, so the
// sentence above does not match the signature; the meaning of the integer
// is defined by the implementation - TODO confirm and update this comment.
int SetUpContiguousInputs(
    std::unique_ptr<Graph>* g,
    const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
    NodeBuilder* nb, Node* old_node,
    std::vector<NodeBuilder::NodeOut>* workspace_tensors,
    bool are_workspace_tensors_available);

// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'orig_node'.
//
// For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
// section in the documentation above.
//
// Returns Status::OK() if setting up inputs is successful, otherwise
// returns appropriate status code.
Status SetUpInputs(std::unique_ptr<Graph>* g,
                   const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
                   NodeBuilder* nb, Node* orig_node);

// Add workspace edge on the input or output side of Node 'orig_node' by using
// NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
// adding workspace edge then do not add it. Workspace Tensorflow and Mkl
// tensors, if they need to be added, will be set into 'ws_tensors'.
// If we set workspace tensors, then '*are_ws_tensors_added' should be true.
void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
                              NodeBuilder* nb,
                              std::vector<NodeBuilder::NodeOut>* ws_tensors,
                              bool* are_ws_tensors_added);

// Helper function used by FixMklMetaDataEdges.
// Fixes the metadata edge
// pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
// 'g'. Returns true if fixup was done; otherwise, it returns false.
bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
                                const Edge* e_metadata);

// Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
// connected? If not, then fix them. This is needed because a graph may have
// some input Mkl metadata edges incorrectly setup after node merge and
// rewrite passes. This could happen because GetReversePostOrder function may
// not provide topologically sorted order if a graph contains cycles. The
// function returns true if at least one Mkl metadata edge for node 'n' was
// fixed. Otherwise, it returns false.
//
// Example:
//
// X = MklConv2D(_, _, _)
// Y = MklConv2DWithBias(_, _, _, _, _, _)
// Z = MklAdd(X, Y, DummyMklTensor, Y:1)
//
// For a graph such as shown above, note that 3rd argument of MklAdd contains
// DummyMklTensor. Actually, it should be getting the Mkl metadata from
// MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
// (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
// (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
// metadata edges only - it does not rewrite nodes nor does it modify the Mkl
// data edges (1st and 2nd arguments of MklAdd).
bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);

// Functions specific to operators to copy attributes
// We need operator-specific function to copy attributes because the framework
// does not provide any generic function for it.
// NOTE: names are alphabetically sorted.
// Every CopyAttrs* function below receives the original node, the builder
// of the node replacing it, and a 'change_format' flag that defaults to
// false. The two CopyAttrsFromPadAnd* functions receive two original nodes
// instead of one.
// NOTE(review): the exact effect of 'change_format' is defined by the
// implementations, which are outside this declaration list - TODO confirm.
static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb,
                          bool change_format = false);
static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb,
                                 bool change_format = false);
static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb,
                            bool change_format = false);
static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb,
                              bool change_format = false);
static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
                          bool change_format = false);
static void CopyAttrsConv2DDepthwise(const Node* orig_node, NodeBuilder* nb,
                                     bool change_format = false);
static void CopyAttrsConv2DDepthwiseCheckConstFilter(
    const Node* orig_node, NodeBuilder* nb, bool change_format = false);
static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
                                          NodeBuilder* nb,
                                          bool change_format = false);
static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb,
                              bool change_format = false);
static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb,
                                    bool change_format = false);
static void CopyAttrsLeakyRelu(const Node* orig_node, NodeBuilder* nb,
                               bool change_format = false);
static void CopyAttrsFusedConv2D(const Node* orig_node, NodeBuilder* nb,
                                 bool change_format = false);
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
                         bool change_format = false);
static void CopyAttrsPadWithConv2D(const Node* orig_node, NodeBuilder* nb,
                                   bool change_format = false);
static void CopyAttrsPadWithFusedConv2D(const Node* orig_node,
                                        NodeBuilder* nb,
                                        bool change_format = false);
static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
                                      const Node* orig_node2, NodeBuilder* nb,
                                      bool change_format = false);
static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
                                           const Node* orig_node2,
                                           NodeBuilder* nb,
                                           bool change_format = false);
static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
                             bool change_format = false);
static void CopyAttrsQuantizedPooling(const Node* orig_node, NodeBuilder* nb,
                                      bool change_format = false);
static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
                                     bool change_format = false);
static void CopyAttrsQuantizedConcat(const Node* orig_node, NodeBuilder* nb,
                                     bool change_format = false);
static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb,
                             bool change_format = false);
static void CopyAttrsRequantize(const Node* orig_node, NodeBuilder* nb,
                                bool change_format = false);
static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb,
                           bool change_format = false);
static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb,
                           bool change_format = false);
// Variant that additionally receives the 'strides' and 'dilations' vectors.
static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
                                const std::vector<int32>& strides,
                                const std::vector<int32>& dilations,
                                bool change_format = false);

// Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
// using node for original node 'orig_node' and return it in '*out'.
// TODO(nhasabni) We should move this to mkl_util.h
void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
                           Node* orig_node);
// Same parameter list as GetDummyMklTensorNode.
// NOTE(review): presumably creates the dummy node used for workspace
// tensors - confirm against the implementation.
void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
                                 Node* orig_node);
};

// Storage for the class-wide table of op-name strings.
MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;

// We register Mkl rewrite pass for phase 1 in post partitioning group.
// We register it here so that we get a complete picture of all users of Mkl
// nodes. Do not change the ordering of the Mkl passes.
1549 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = 1550 OptimizationPassRegistry::POST_PARTITIONING; 1551 #ifdef ENABLE_MKL 1552 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); 1553 #endif // ENABLE_MKL 1554 1555 ////////////////////////////////////////////////////////////////////////// 1556 // Helper functions for creating new node 1557 ////////////////////////////////////////////////////////////////////////// 1558 1559 static void FillInputs(const Node* n, 1560 gtl::InlinedVector<Node*, 4>* control_edges, 1561 gtl::InlinedVector<std::pair<Node*, int>, 4>* in) { 1562 control_edges->clear(); 1563 for (const Edge* e : n->in_edges()) { 1564 if (e->IsControlEdge()) { 1565 control_edges->push_back(e->src()); 1566 } else { 1567 (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output()); 1568 } 1569 } 1570 std::sort(control_edges->begin(), control_edges->end()); 1571 if (n->op_def().is_commutative()) { 1572 // For commutative inputs, we sort the input by the input Node* 1573 // to get a canonical ordering (so that add(a,b) and add(b, a) will 1574 // hash to the same value if is_commutative is true for 'add'). 1575 std::sort(in->begin(), in->end()); 1576 } 1577 } 1578 1579 void MklLayoutRewritePass::GetNodesProducingTFTensorList( 1580 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, 1581 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { 1582 CHECK_LT(*input_idx, inputs.size()); 1583 CHECK_GT(list_length, 0); 1584 CHECK_NOTNULL(output_nodes); 1585 output_nodes->reserve(list_length); 1586 1587 while (list_length != 0) { 1588 CHECK_GT(list_length, 0); 1589 CHECK_LT(*input_idx, inputs.size()); 1590 Node* n = inputs[*input_idx].first; 1591 int slot = inputs[*input_idx].second; 1592 // If input node 'n' is just producing a single tensor at 1593 // output slot 'slot' then we just add that single node. 
1594 output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); 1595 (*input_idx)++; 1596 list_length--; 1597 } 1598 } 1599 1600 // TODO(nhasabni) We should move this to mkl_util.h. 1601 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, 1602 Node** out, Node* orig_node) { 1603 // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent 1604 // dummy Mkl tensor. 8 = 2*size_t. 1605 const DataType dt = DataTypeToEnum<uint8>::v(); 1606 TensorProto proto; 1607 proto.set_dtype(dt); 1608 uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; 1609 proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8)); 1610 TensorShape dummy_shape({8}); 1611 dummy_shape.AsProto(proto.mutable_tensor_shape()); 1612 TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") 1613 .Attr("value", proto) 1614 .Attr("dtype", dt) 1615 .Device(orig_node->def().device()) // We place this node on 1616 // the same device as the 1617 // device of the original 1618 // node. 1619 .Finalize(&**g, out)); 1620 CHECK_NOTNULL(*out); // Make sure we got a valid object before using it 1621 1622 // If number of inputs to the original node is > 0, then we add 1623 // control dependency between 1st input (index 0) of the original node and 1624 // the dummy Mkl node. This is needed because control-flow ops such as Enter, 1625 // Merge, etc, require frame_name of the dummy Mkl node to be same as the 1626 // rewritten node. Adding control edge between 1st input of the original node 1627 // and the dummy Mkl node ensures that the dummy node is in the same frame 1628 // as the original node. Choosing 1st input is not necessary - any input of 1629 // the original node is fine because all the inputs of a node are always in 1630 // the same frame. 
1631 if (orig_node->num_inputs() > 0) { 1632 Node* orig_input0 = nullptr; 1633 TF_CHECK_OK( 1634 orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); 1635 // Allow duplicate while adding control edge as it would fail (return 1636 // NULL) if we try to add duplicate edge. 1637 CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); 1638 } 1639 1640 (*out)->set_assigned_device_name(orig_node->assigned_device_name()); 1641 } 1642 1643 void MklLayoutRewritePass::GetNodesProducingMklTensorList( 1644 std::unique_ptr<Graph>* g, Node* orig_node, 1645 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, 1646 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { 1647 CHECK_LT(*input_idx, inputs.size()); 1648 CHECK_GT(list_length, 0); 1649 CHECK_NOTNULL(output_nodes); 1650 output_nodes->reserve(list_length); 1651 1652 while (list_length != 0) { 1653 CHECK_GT(list_length, 0); 1654 CHECK_LT(*input_idx, inputs.size()); 1655 Node* n = inputs[*input_idx].first; 1656 int slot = inputs[*input_idx].second; 1657 // If 'n' is producing a single tensor, then create a single Mkl tensor 1658 // node. 1659 Node* mkl_node = nullptr; 1660 int mkl_node_output_slot = 0; 1661 GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, 1662 &mkl_node_output_slot); 1663 output_nodes->push_back( 1664 NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); 1665 (*input_idx)++; 1666 list_length--; 1667 } 1668 } 1669 1670 // Get an input node that will feed Mkl tensor to the new 1671 // node that we are constructing. An input node could be (1) 'n' 1672 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor 1673 // if 'n' is not an Mkl layer. 
1674 void MklLayoutRewritePass::GetNodeProducingMklTensor( 1675 std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot, 1676 Node** mkl_node, int* mkl_node_output_slot) { 1677 CHECK_NOTNULL(n); 1678 CHECK_NOTNULL(mkl_node); 1679 CHECK_NOTNULL(mkl_node_output_slot); 1680 1681 // If this is an MKL op, then it will create extra output for MKL layout. 1682 DataType T; 1683 if (GetNodeAttr(n->def(), "T", &T).ok() && 1684 mkl_op_registry::IsMklOp(n->type_string(), T)) { 1685 // If this is an MKL op, then it will generate an edge that will receive 1686 // Mkl tensor from a node. 1687 // output slot number for Mkl tensor would be N+slot number of TensorFlow 1688 // tensor, where N is total number of TensorFlow tensors. 1689 *mkl_node = n; 1690 *mkl_node_output_slot = 1691 GetTensorMetaDataIndex(n_output_slot, n->num_outputs()); 1692 } else { 1693 // If we have not visited the node and rewritten it, then we need 1694 // to create a dummy node that will feed a dummy Mkl tensor to this node. 1695 // DummyMklTensor node has no input and generates only 1 output 1696 // (dummy Mkl tensor) as output slot number 0. 1697 GetDummyMklTensorNode(g, mkl_node, orig_node); 1698 CHECK_NOTNULL(*mkl_node); 1699 *mkl_node_output_slot = 0; 1700 } 1701 } 1702 1703 int MklLayoutRewritePass::SetUpContiguousInputs( 1704 std::unique_ptr<Graph>* g, 1705 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, 1706 NodeBuilder* nb, Node* old_node, 1707 std::vector<NodeBuilder::NodeOut>* workspace_tensors, 1708 bool are_workspace_tensors_available) { 1709 CHECK_NOTNULL(workspace_tensors); 1710 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); 1711 1712 // TODO(nhasabni): Temporary solution to connect filter input of 1713 // BackpropInput with the converted filter from Conv2D. 1714 bool do_connect_conv2d_backprop_input_filter = false; 1715 Node* conv2d_node = nullptr; 1716 // Filter node is 2nd input (slot index 1) of Conv2D. 
1717 int kConv2DFilterInputSlotIdx = 1; 1718 int kConv2DBackpropInputFilterInputSlotIdx = 1; 1719 int kConv2DFilterOutputSlotIdx = 1; 1720 if (old_node->type_string() == csinfo_.conv2d_grad_input) { 1721 // We need to find Conv2D node from Conv2DBackpropInput. 1722 // For that let's first find filter node that is 2nd input (slot 1) 1723 // of BackpropInput. 1724 Node* filter_node = nullptr; 1725 TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, 1726 &filter_node)); 1727 CHECK_NOTNULL(filter_node); 1728 1729 // Now check which nodes receive from filter_node. Filter feeds as 1730 // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and 1731 // _MklFusedConv2D. 1732 for (const Edge* e : filter_node->out_edges()) { 1733 if ((e->dst()->type_string() == csinfo_.mkl_conv2d || 1734 e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d || 1735 e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d || 1736 e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias || 1737 e->dst()->type_string() == csinfo_.mkl_fused_conv2d) && 1738 e->dst_input() == kConv2DFilterInputSlotIdx 1739 /* filter is 2nd input of Conv2D and _MklConv2D. */) { 1740 if (conv2d_node != nullptr) { 1741 VLOG(1) << "MklLayoutRewritePass: unusual case of same filter" 1742 << " feeding multiple Conv2D nodes: " 1743 << filter_node->DebugString(); 1744 // We will not connect filter input of Conv2DBackpropInput 1745 // to be safe here. 1746 do_connect_conv2d_backprop_input_filter = false; 1747 break; 1748 } else { 1749 conv2d_node = e->dst(); 1750 do_connect_conv2d_backprop_input_filter = true; 1751 } 1752 } 1753 } 1754 } 1755 1756 // Number of input slots to original op 1757 // Input slots are represented by .Input() calls in REGISTER_OP. 1758 int old_node_input_slots = old_node->op_def().input_arg_size(); 1759 // Actual number of inputs can be greater than or equal to number 1760 // of Input slots because inputs of type list could be unfolded. 
1761 CHECK_GE(old_node_inputs.size(), old_node_input_slots); 1762 int nn_slot_idx = 0; // slot index for inputs of new node 1763 1764 // Let's copy all inputs (TF tensors) of original node to new node. 1765 int iidx = 0; 1766 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { 1767 // An input slot could be a single tensor or a list. We need 1768 // to handle this case accordingly. 1769 CHECK_LT(iidx, old_node_inputs.size()); 1770 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); 1771 if (ArgIsList(arg)) { 1772 std::vector<NodeBuilder::NodeOut> new_node_inputs; 1773 int N = GetTensorListLength(arg, old_node); 1774 GetNodesProducingTFTensorList(old_node_inputs, &iidx, N, 1775 &new_node_inputs); 1776 nb->Input(new_node_inputs); 1777 nn_slot_idx++; 1778 } else { 1779 // Special case for connecting filter input of Conv2DBackpropInput 1780 if (do_connect_conv2d_backprop_input_filter && 1781 iidx == kConv2DBackpropInputFilterInputSlotIdx) { 1782 nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx); 1783 } else { 1784 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); 1785 } 1786 iidx++; 1787 nn_slot_idx++; 1788 } 1789 } 1790 1791 // If workspace tensors are available for this op and we are using 1792 // contiguous ordering then we need to add Tensorflow tensor for 1793 // workspace here because Tensorflow tensor for workspace is the 1794 // last tensor in the list of Tensorflow tensors. 1795 if (are_workspace_tensors_available) { 1796 CHECK_EQ(workspace_tensors->size(), 2); 1797 // Tensorflow tensor 1798 nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index); 1799 nn_slot_idx++; 1800 } 1801 1802 // Let's now setup all Mkl inputs to a new node. 1803 // Number of Mkl inputs must be same as number of TF inputs. 1804 iidx = 0; 1805 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { 1806 // An input slot could be a single tensor or a list. 
We need 1807 // to handle this case accordingly. 1808 CHECK_LT(iidx, old_node_inputs.size()); 1809 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); 1810 if (ArgIsList(arg)) { 1811 std::vector<NodeBuilder::NodeOut> new_node_inputs; 1812 int N = GetTensorListLength(arg, old_node); 1813 GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, 1814 &new_node_inputs); 1815 nb->Input(new_node_inputs); 1816 nn_slot_idx++; 1817 } else { 1818 Node* mkl_node = nullptr; 1819 int mkl_node_output_slot = 0; 1820 // Special case for connecting filter input of Conv2DBackpropInput 1821 if (do_connect_conv2d_backprop_input_filter && 1822 iidx == kConv2DBackpropInputFilterInputSlotIdx) { 1823 GetNodeProducingMklTensor(g, old_node, conv2d_node, 1824 kConv2DFilterOutputSlotIdx, &mkl_node, 1825 &mkl_node_output_slot); 1826 } else { 1827 GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, 1828 old_node_inputs[iidx].second, &mkl_node, 1829 &mkl_node_output_slot); 1830 } 1831 nb->Input(mkl_node, mkl_node_output_slot); 1832 iidx++; 1833 nn_slot_idx++; 1834 } 1835 } 1836 1837 // If workspace tensors are available for this op and we are using 1838 // contiguous ordering then we need to add Mkl tensor for 1839 // workspace here because Mkl tensor for workspace is the 1840 // last tensor in the list of Mkl tensors. 1841 if (are_workspace_tensors_available) { 1842 CHECK_EQ(workspace_tensors->size(), 2); 1843 // Mkl tensor 1844 nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index); 1845 nn_slot_idx++; 1846 } 1847 1848 return nn_slot_idx; 1849 } 1850 1851 Status MklLayoutRewritePass::SetUpInputs( 1852 std::unique_ptr<Graph>* g, 1853 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, 1854 NodeBuilder* nb, Node* old_node) { 1855 // Let's check if we need to add workspace tensors for this node. 1856 // We add workspace edge only for MaxPool, LRN and BatchNorm. 
1857 std::vector<NodeBuilder::NodeOut> workspace_tensors; 1858 bool are_workspace_tensors_available = false; 1859 1860 // Avoid workspace check for QuantizedConv2D and the fused 1861 // Ops as they don't have attribute: "T". 1862 std::vector<string> quant_ops{ 1863 "QuantizedConv2D", 1864 "QuantizedConv2DWithBias", 1865 "QuantizedConv2DAndRelu", 1866 "QuantizedConv2DWithBiasAndRelu", 1867 "QuantizedConv2DWithBiasSumAndRelu", 1868 "QuantizedConv2DAndRequantize", 1869 "QuantizedConv2DWithBiasAndRequantize", 1870 "QuantizedConv2DAndReluAndRequantize", 1871 "QuantizedConv2DWithBiasAndReluAndRequantize", 1872 "QuantizedConv2DWithBiasSumAndReluAndRequantize", 1873 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"}; 1874 bool should_check_workspace = 1875 std::find(std::begin(quant_ops), std::end(quant_ops), 1876 old_node->type_string()) == std::end(quant_ops); 1877 if (should_check_workspace) 1878 AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors, 1879 &are_workspace_tensors_available); 1880 1881 int new_node_input_slots = 0; 1882 if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { 1883 // TODO(nhasabni): implement this function just for same of completion. 1884 // We do not use interleaved ordering right now. 
1885 return Status( 1886 error::Code::UNIMPLEMENTED, 1887 "Interleaved ordering of tensors is currently not supported."); 1888 } else { 1889 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); 1890 new_node_input_slots = SetUpContiguousInputs( 1891 g, old_node_inputs, nb, old_node, &workspace_tensors, 1892 are_workspace_tensors_available); 1893 } 1894 1895 // Sanity check 1896 int old_node_input_slots = old_node->op_def().input_arg_size(); 1897 if (!are_workspace_tensors_available) { 1898 // If we are not adding workspace tensors for this op, then the total 1899 // number of input slots to the new node _must_ be 2 times the number 1900 // of input slots to the original node: N original Tensorflow tensors and 1901 // N for Mkl tensors corresponding to each Tensorflow tensors. 1902 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2); 1903 } else { 1904 // If we are adding workspace tensors for this op, then the total 1905 // The total number of input slots to new node _must_ be 2 times the number 1906 // of input slots to the original node: N original Tensorflow tensors and 1907 // N for Mkl tensors corresponding to each Tensorflow tensors plus 2 1908 // (for workspace Tensorflow tensor and workspace Mkl tensor). 1909 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2); 1910 } 1911 1912 return Status::OK(); 1913 } 1914 1915 ////////////////////////////////////////////////////////////////////////// 1916 // Helper functions related to workspace pass 1917 ////////////////////////////////////////////////////////////////////////// 1918 1919 // TODO(nhasabni) We should move this to mkl_util.h. 1920 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( 1921 std::unique_ptr<Graph>* g, Node** out, Node* orig_node) { 1922 // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent 1923 // workspace tensor. 
1924 GetDummyMklTensorNode(g, out, orig_node); 1925 } 1926 1927 void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( 1928 std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb, 1929 std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) { 1930 bool workspace_edge_added = false; // Default initializer 1931 CHECK_NOTNULL(are_ws_tensors_added); 1932 *are_ws_tensors_added = false; // Default initializer 1933 1934 DataType T; 1935 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1936 for (auto ws : wsinfo_) { 1937 if (orig_node->type_string() == ws.fwd_op && 1938 mkl_op_registry::IsMklOp( 1939 mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { 1940 // If this op is a fwd op, then we need to check if there is an 1941 // edge from this node's fwd_slot to bwdop's bwd_slot. If there is 1942 // an edge, then we just add an attribute on this node for setting 1943 // workspace_passed to true. We don't add actual workspace edge 1944 // in this node. Actual workspace edge gets added in the backward 1945 // op for this node. 1946 for (const Edge* e : orig_node->out_edges()) { 1947 if (e->src_output() == ws.fwd_slot && 1948 e->dst()->type_string() == ws.bwd_op && 1949 e->dst_input() == ws.bwd_slot) { 1950 nb->Attr("workspace_enabled", true); 1951 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " 1952 << orig_node->type_string(); 1953 workspace_edge_added = true; 1954 // We found the edge that we were looking for, so break. 1955 break; 1956 } 1957 } 1958 1959 if (!workspace_edge_added) { 1960 // If we are here, then we did not find backward operator for this 1961 // node. 1962 nb->Attr("workspace_enabled", false); 1963 } 1964 } else if (orig_node->type_string() == ws.bwd_op && 1965 mkl_op_registry::IsMklOp( 1966 mkl_op_registry::GetMklOpName(orig_node->type_string()), 1967 T)) { 1968 // If this op is a bwd op, then we need to add workspace edge and 1969 // it's Mkl tensor edge between its corresponding fwd op and this 1970 // op. 
Corresponding fwd op is specified in 'fwd_op' field of 1971 // workspace info. fwd_slot and bwd_slot in workspace info specify 1972 // an edge between which slots connect forward and backward op. 1973 // Once all these criteria match, we add a workspace edge between 1974 // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is 1975 // determined by interleaved/contiguous ordering. Function 1976 // DataIndexToMetaDataIndex tells us the location of Mkl tensor 1977 // from the location of the Tensorflow tensor. 1978 for (const Edge* e : orig_node->in_edges()) { 1979 if (e->src_output() == ws.fwd_slot && 1980 // We would have rewritten the forward op, so we need to use 1981 // GetMklOpName call to get its Mkl name. 1982 e->src()->type_string() == 1983 mkl_op_registry::GetMklOpName(ws.fwd_op) && 1984 e->dst_input() == ws.bwd_slot) { 1985 nb->Attr("workspace_enabled", true); 1986 CHECK_NOTNULL(ws_tensors); 1987 // Add workspace edge between fwd op and bwd op. 1988 ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot)); 1989 // Add Mkl tensor edge for workspace edge between fwd op and bwd op. 1990 ws_tensors->push_back(NodeBuilder::NodeOut( 1991 e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot, 1992 e->src()->num_outputs()))); 1993 *are_ws_tensors_added = true; 1994 // In terms of input ordering, we add these calls to add Input 1995 // here because workspace edge (and its Mkl tensor) is the last 1996 // edge in the fwdop and bwdop. So all inputs before workspace 1997 // tensor have been added by SetUpInputs function. 1998 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " 1999 << orig_node->type_string(); 2000 workspace_edge_added = true; 2001 // We found the edge that we were looking for, so break. 2002 break; 2003 } 2004 } 2005 2006 // If we are here means we did not find fwd op that feeds to this 2007 // bwd op. 
So in this case, we need to generate dummy tensors for 2008 // workspace input and Mkl tensor for workspace, and set 2009 // workspace_enabled to false. 2010 if (!workspace_edge_added) { 2011 nb->Attr("workspace_enabled", false); 2012 Node* dmt_ws = nullptr; // Dummy tensor for workspace 2013 Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace 2014 GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node); 2015 GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node); 2016 CHECK_NOTNULL(dmt_ws); 2017 CHECK_NOTNULL(dmt_mkl_ws); 2018 CHECK_NOTNULL(ws_tensors); 2019 // We add dummy tensor as workspace tensor. 2020 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0)); 2021 // We add dummy tensor as Mkl tensor for workspace tensor. 2022 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0)); 2023 *are_ws_tensors_added = true; 2024 VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for " 2025 << orig_node->type_string(); 2026 } 2027 } else { 2028 // If this node does not match any workspace info, then we do not 2029 // do anything special for workspace propagation for it. 2030 } 2031 } 2032 } 2033 2034 ////////////////////////////////////////////////////////////////////////// 2035 // Op-specific functions to copy attributes from old node to new node 2036 ////////////////////////////////////////////////////////////////////////// 2037 2038 void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node, 2039 NodeBuilder* nb, 2040 bool change_format) { 2041 DataType T; 2042 string padding; 2043 std::vector<int32> strides; 2044 std::vector<int32> dilations; 2045 2046 // Get all attributes from old node. 
2047 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2048 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 2049 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); 2050 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 2051 2052 Node* filter_node = nullptr; 2053 orig_node->input_node(1, &filter_node); 2054 2055 // Add attributes to new node. 2056 nb->Attr("T", T); 2057 nb->Attr("padding", padding); 2058 nb->Attr("is_filter_const", filter_node->IsConstant()); 2059 2060 // Add attributes related to `data_format`. 2061 CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format); 2062 } 2063 2064 void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb, 2065 bool change_format) { 2066 DataType T; 2067 string padding; 2068 std::vector<int32> strides; 2069 std::vector<int32> dilations; 2070 2071 // Get all attributes from old node. 2072 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2073 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 2074 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); 2075 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 2076 2077 // Add attributes to new node. 2078 nb->Attr("T", T); 2079 nb->Attr("padding", padding); 2080 2081 // Add attributes related to `data_format`. 2082 CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format); 2083 } 2084 2085 // Used in rinfo when replacing __MklDummyPadWithConv2D by _MklPadWithConv2D 2086 void MklLayoutRewritePass::CopyAttrsPadWithConv2D(const Node* orig_node, 2087 NodeBuilder* nb, 2088 bool change_format) { 2089 DataType Tpaddings; 2090 DataType T; 2091 string data_format; 2092 string padding; 2093 std::vector<int32> strides; 2094 std::vector<int32> dilations; 2095 bool use_cudnn_on_gpu; 2096 2097 // Get all attributes from old node. 
2098 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2099 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 2100 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); 2101 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 2102 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 2103 TF_CHECK_OK( 2104 GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); 2105 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tpaddings", &Tpaddings)); 2106 2107 Node* filter_node = nullptr; 2108 orig_node->input_node(1, &filter_node); 2109 2110 // Add attributes to new node. 2111 nb->Attr("T", T); 2112 nb->Attr("strides", strides); 2113 nb->Attr("dilations", dilations); 2114 nb->Attr("padding", padding); 2115 nb->Attr("is_filter_const", filter_node->IsConstant()); 2116 nb->Attr("data_format", data_format); 2117 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); 2118 nb->Attr("Tpaddings", Tpaddings); 2119 } 2120 2121 void MklLayoutRewritePass::CopyAttrsPadWithFusedConv2D(const Node* orig_node, 2122 NodeBuilder* nb, 2123 bool change_format) { 2124 DataType Tpaddings; 2125 2126 CopyAttrsFusedConv2D(orig_node, nb, change_format); 2127 2128 // Get attributes from old node. 2129 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tpaddings", &Tpaddings)); 2130 // Check if filter is a constant. 2131 Node* filter_node = nullptr; 2132 orig_node->input_node(1, &filter_node); 2133 2134 // Add attributes to new node. 
2135 nb->Attr("Tpaddings", Tpaddings); 2136 nb->Attr("is_filter_const", filter_node->IsConstant()); 2137 } 2138 2139 // Used with MergePadWithConv2D 2140 void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1, 2141 const Node* orig_node2, 2142 NodeBuilder* nb, 2143 bool change_format) { 2144 DataType Tpaddings; 2145 DataType T; 2146 string data_format; 2147 string padding; 2148 std::vector<int32> strides; 2149 std::vector<int32> dilations; 2150 bool use_cudnn_on_gpu; 2151 2152 // Get all attributes from old node 1. 2153 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T)); 2154 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides)); 2155 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations)); 2156 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding)); 2157 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format)); 2158 TF_CHECK_OK( 2159 GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); 2160 // Get all attributes from old node 2. 2161 TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings)); 2162 2163 // Add attributes to new node. 2164 nb->Attr("T", T); 2165 nb->Attr("strides", strides); 2166 nb->Attr("dilations", dilations); 2167 nb->Attr("padding", padding); 2168 nb->Attr("data_format", data_format); 2169 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); 2170 nb->Attr("Tpaddings", Tpaddings); 2171 } 2172 2173 void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D( 2174 const Node* fused_conv2d, const Node* pad, NodeBuilder* nb, 2175 bool change_format) { 2176 DataType T; 2177 int num_args; 2178 string data_format; 2179 string padding; 2180 std::vector<int32> strides; 2181 std::vector<int32> dilations; 2182 float epsilon; 2183 std::vector<string> fused_ops; 2184 DataType Tpaddings; 2185 2186 // Get all attributes from old node. 
2187 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T)); 2188 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args)); 2189 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides)); 2190 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding)); 2191 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format)); 2192 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations)); 2193 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops)); 2194 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon)); 2195 TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings)); 2196 2197 // Add attributes to new node. 2198 nb->Attr("T", T); 2199 nb->Attr("num_args", num_args); 2200 nb->Attr("strides", strides); 2201 nb->Attr("padding", padding); 2202 nb->Attr("data_format", data_format); 2203 nb->Attr("dilations", dilations); 2204 nb->Attr("epsilon", epsilon); 2205 nb->Attr("Tpaddings", Tpaddings); 2206 nb->Attr("fused_ops", fused_ops); 2207 } 2208 2209 void MklLayoutRewritePass::CopyAttrsConv2DDepthwise(const Node* orig_node, 2210 NodeBuilder* nb, 2211 bool change_format) { 2212 DataType T; 2213 string data_format; 2214 string padding; 2215 std::vector<int32> strides; 2216 std::vector<int32> dilations; 2217 2218 // Get all attributes from old node. 2219 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2220 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 2221 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); 2222 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 2223 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 2224 2225 // Add attributes to new node. 
2226 nb->Attr("T", T); 2227 nb->Attr("strides", strides); 2228 nb->Attr("dilations", dilations); 2229 nb->Attr("padding", padding); 2230 nb->Attr("data_format", data_format); 2231 } 2232 2233 void MklLayoutRewritePass::CopyAttrsConv2DDepthwiseCheckConstFilter( 2234 const Node* orig_node, NodeBuilder* nb, bool change_format) { 2235 DataType T; 2236 string data_format; 2237 string padding; 2238 std::vector<int32> strides; 2239 std::vector<int32> dilations; 2240 2241 // Get all attributes from old node. 2242 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2243 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 2244 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); 2245 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 2246 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 2247 2248 Node* filter_node = nullptr; 2249 orig_node->input_node(1, &filter_node); 2250 2251 // Add attributes to new node. 2252 nb->Attr("T", T); 2253 nb->Attr("strides", strides); 2254 nb->Attr("dilations", dilations); 2255 nb->Attr("padding", padding); 2256 nb->Attr("is_filter_const", filter_node->IsConstant()); 2257 nb->Attr("data_format", data_format); 2258 } 2259 2260 void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb, 2261 bool change_format) { 2262 DataType T; 2263 int N; 2264 2265 // Get all attributes from old node. 2266 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2267 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); 2268 2269 // Add attributes to new node. 2270 nb->Attr("T", T); 2271 nb->Attr("N", N); 2272 } 2273 2274 void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node, 2275 NodeBuilder* nb, 2276 bool change_format) { 2277 DataType T; 2278 string data_format; 2279 std::vector<int32> strides; 2280 2281 // Get all attributes from old node. 
2282 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2283 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 2284 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 2285 2286 // Add attributes to new node. 2287 nb->Attr("T", T); 2288 nb->Attr("strides", strides); 2289 nb->Attr("data_format", data_format); 2290 } 2291 2292 void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb, 2293 bool change_format) { 2294 DataType T; 2295 int depth_radius; 2296 float bias; 2297 float alpha; 2298 float beta; 2299 2300 // Get all attributes from old node. 2301 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2302 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius)); 2303 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias)); 2304 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha)); 2305 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta)); 2306 2307 // Add attributes to new node. 2308 nb->Attr("T", T); 2309 nb->Attr("depth_radius", depth_radius); 2310 nb->Attr("bias", bias); 2311 nb->Attr("alpha", alpha); 2312 nb->Attr("beta", beta); 2313 } 2314 2315 void MklLayoutRewritePass::CopyAttrsLeakyRelu(const Node* orig_node, 2316 NodeBuilder* nb, 2317 bool change_format) { 2318 DataType T; 2319 float alpha; 2320 2321 // Get all attributes from old node. 2322 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2323 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha)); 2324 2325 // Add attributes to new node. 2326 nb->Attr("T", T); 2327 nb->Attr("alpha", alpha); 2328 } 2329 2330 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node, 2331 NodeBuilder* nb, 2332 bool change_format) { 2333 DataType T; 2334 string data_format; 2335 string padding; 2336 std::vector<int32> ksize, strides; 2337 2338 // Get all attributes from old node. 
2339 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2340 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize)); 2341 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 2342 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 2343 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 2344 2345 // Add attributes to new node. 2346 nb->Attr("T", T); 2347 nb->Attr("ksize", ksize); 2348 nb->Attr("strides", strides); 2349 nb->Attr("padding", padding); 2350 nb->Attr("data_format", data_format); 2351 } 2352 2353 void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, 2354 NodeBuilder* nb, 2355 bool change_format) { 2356 DataType T; 2357 2358 // Get all attributes from old node. 2359 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2360 2361 // Add attributes to new node. 2362 nb->Attr("T", T); 2363 } 2364 2365 void MklLayoutRewritePass::CopyAttrsQuantizedPooling(const Node* orig_node, 2366 NodeBuilder* nb, 2367 bool change_format) { 2368 DataType T; 2369 string padding; 2370 std::vector<int32> ksize, strides; 2371 2372 // Get all attributes from old node. 2373 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2374 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize)); 2375 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 2376 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 2377 2378 // Add attributes to new node. 2379 nb->Attr("T", T); 2380 nb->Attr("ksize", ksize); 2381 nb->Attr("strides", strides); 2382 nb->Attr("padding", padding); 2383 } 2384 2385 void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node, 2386 NodeBuilder* nb, 2387 bool change_format) { 2388 DataType Tinput, Tfilter, out_type; 2389 string padding; 2390 string data_format("NHWC"); 2391 std::vector<int32> strides, dilations, padding_list; 2392 bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list"); 2393 2394 // Get all attributes from old node. 
2395 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput)); 2396 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter)); 2397 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type)); 2398 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 2399 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 2400 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); 2401 if (has_padding_list) { 2402 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list)); 2403 } 2404 2405 Node* filter_node = nullptr; 2406 orig_node->input_node(1, &filter_node); 2407 2408 // Add attributes to new node. 2409 nb->Attr("Tinput", Tinput); 2410 nb->Attr("Tfilter", Tfilter); 2411 nb->Attr("out_type", out_type); 2412 nb->Attr("padding", padding); 2413 nb->Attr("is_filter_const", filter_node->IsConstant()); 2414 nb->Attr("strides", strides); 2415 nb->Attr("dilations", dilations); 2416 nb->Attr("T", out_type); // added "T" for facilitating MklToTf conversion. 2417 nb->Attr("data_format", data_format); 2418 if (has_padding_list) { 2419 nb->Attr("padding_list", padding_list); 2420 } 2421 2422 // Requantization attr Tbias. 2423 DataType Tbias; 2424 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias); 2425 if (bias_status.ToString() == "OK") nb->Attr("Tbias", Tbias); 2426 } 2427 2428 void MklLayoutRewritePass::CopyAttrsRequantize(const Node* orig_node, 2429 NodeBuilder* nb, 2430 bool change_format) { 2431 DataType Tinput, out_type; 2432 2433 // Get all attributes from old node. 2434 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput)); 2435 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type)); 2436 2437 // Add attributes to new node. 
2438 nb->Attr("Tinput", Tinput); 2439 nb->Attr("out_type", out_type); 2440 } 2441 2442 void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, 2443 NodeBuilder* nb, 2444 bool change_format) { 2445 DataType T; 2446 DataType Tshape; 2447 2448 // Get all attributes from old node. 2449 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2450 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); 2451 2452 // Add attributes to new node. 2453 nb->Attr("T", T); 2454 nb->Attr("Tshape", Tshape); 2455 } 2456 2457 void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node, 2458 NodeBuilder* nb, bool change_format) { 2459 DataType T; 2460 DataType Index; 2461 2462 // Get all attributes from old node. 2463 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2464 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index)); 2465 2466 // Add attributes to new node. 2467 nb->Attr("T", T); 2468 nb->Attr("Index", Index); 2469 } 2470 2471 void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, 2472 NodeBuilder* nb, bool change_format) { 2473 DataType T; 2474 string data_format; 2475 int num_split; 2476 2477 // Get all attributes from old node. 2478 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2479 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split)); 2480 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 2481 2482 // Add attributes to new node. 
2483 nb->Attr("T", T); 2484 nb->Attr("num_split", num_split); 2485 nb->Attr("data_format", data_format); 2486 } 2487 2488 void MklLayoutRewritePass::CopyFormatAttrsConv( 2489 const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides, 2490 const std::vector<int32>& dilations, bool change_format) { 2491 string data_format; 2492 2493 if (!change_format) { 2494 nb->Attr("strides", strides); 2495 nb->Attr("dilations", dilations); 2496 2497 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 2498 nb->Attr("data_format", data_format); 2499 } else { 2500 std::vector<int32> new_strides; 2501 std::vector<int32> new_dilations; 2502 if (strides.size() == 5) { 2503 // `strides` and `dilations` also need to be changed according to 2504 // `data_format`. In this case, from `NDHWC` to `NCDHW`. 2505 new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C], 2506 strides[NDHWC::dim::D], strides[NDHWC::dim::H], 2507 strides[NDHWC::dim::W]}; 2508 2509 new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C], 2510 dilations[NDHWC::dim::D], dilations[NDHWC::dim::H], 2511 dilations[NDHWC::dim::W]}; 2512 } else { 2513 // `strides` and `dilations` also need to be changed according to 2514 // `data_format`. In this case, from `NHWC` to `NCHW`. 2515 2516 new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C], 2517 strides[NHWC::dim::H], strides[NHWC::dim::W]}; 2518 2519 new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C], 2520 dilations[NHWC::dim::H], dilations[NHWC::dim::W]}; 2521 } 2522 nb->Attr("strides", new_strides); 2523 nb->Attr("dilations", new_dilations); 2524 } 2525 } 2526 2527 void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node, 2528 NodeBuilder* nb, 2529 bool change_format) { 2530 DataType T; 2531 int N; 2532 2533 // Get all attributes from old node. 
2534 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2535 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); 2536 2537 // Add attributes to new node. 2538 nb->Attr("T", T); 2539 nb->Attr("N", N); 2540 } 2541 2542 void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node, 2543 NodeBuilder* nb, 2544 bool change_format) { 2545 DataType T; 2546 int N; 2547 DataType tidx; 2548 2549 // Get all attributes from old node. 2550 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2551 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); 2552 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx)); 2553 2554 // Add attributes to new node. 2555 nb->Attr("T", T); 2556 nb->Attr("N", N); 2557 nb->Attr("Tidx", tidx); 2558 } 2559 2560 void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node, 2561 NodeBuilder* nb, 2562 bool change_format) { 2563 DataType T; 2564 float epsilon; 2565 string data_format; 2566 bool is_training; 2567 2568 // Get all attributes from old node. 2569 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2570 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon)); 2571 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 2572 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training)); 2573 2574 // Add attributes to new node. 2575 nb->Attr("T", T); 2576 nb->Attr("epsilon", epsilon); 2577 nb->Attr("data_format", data_format); 2578 nb->Attr("is_training", is_training); 2579 } 2580 2581 void MklLayoutRewritePass::CopyAttrsFusedConv2D(const Node* orig_node, 2582 NodeBuilder* nb, 2583 bool change_format) { 2584 DataType T; 2585 int num_args; 2586 float epsilon; 2587 string data_format; 2588 string padding; 2589 std::vector<int32> strides; 2590 std::vector<int32> dilations; 2591 std::vector<string> fused_ops; 2592 2593 // Get all attributes from old node. 
2594 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 2595 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_args", &num_args)); 2596 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 2597 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 2598 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 2599 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); 2600 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "fused_ops", &fused_ops)); 2601 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon)); 2602 2603 Node* filter_node = nullptr; 2604 orig_node->input_node(1, &filter_node); 2605 2606 // Add attributes to new node. 2607 nb->Attr("T", T); 2608 nb->Attr("num_args", num_args); 2609 nb->Attr("strides", strides); 2610 nb->Attr("padding", padding); 2611 nb->Attr("is_filter_const", filter_node->IsConstant()); 2612 nb->Attr("data_format", data_format); 2613 nb->Attr("dilations", dilations); 2614 nb->Attr("fused_ops", fused_ops); 2615 nb->Attr("epsilon", epsilon); 2616 } 2617 2618 ////////////////////////////////////////////////////////////////////////// 2619 // Helper functions related to node merge pass 2620 ////////////////////////////////////////////////////////////////////////// 2621 2622 Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { 2623 // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite 2624 // once we support BiasAddGrad as Mkl layer. 2625 2626 // Search for all matching mergeinfo. 2627 // We allow more than one match for extensibility. 2628 std::vector<const MergeInfo*> matching_mi; 2629 for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) { 2630 if (a->type_string() == mi->op1 || a->type_string() == mi->op2) { 2631 matching_mi.push_back(&*mi); 2632 } 2633 } 2634 2635 for (const MergeInfo* mi : matching_mi) { 2636 // Get the operand with which 'a' can be merged. 
2637 Node* b = nullptr; 2638 if ((b = mi->get_node_to_be_merged(a)) == nullptr) { 2639 continue; 2640 } 2641 2642 // Get the control edges and input of node 2643 const int N_in = a->num_inputs(); 2644 gtl::InlinedVector<Node*, 4> a_control_edges; 2645 gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in); 2646 FillInputs(a, &a_control_edges, &a_in); 2647 2648 const int B_in = b->num_inputs(); 2649 gtl::InlinedVector<Node*, 4> b_control_edges; 2650 gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in); 2651 FillInputs(b, &b_control_edges, &b_in); 2652 2653 // Shouldn't merge if a and b have different control edges. 2654 if (a_control_edges != b_control_edges) { 2655 continue; 2656 } else { 2657 // We found a match. 2658 return b; 2659 } 2660 } 2661 2662 return nullptr; 2663 } 2664 2665 Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, 2666 Node* m, Node* n) { 2667 CHECK_EQ(((m->type_string() == csinfo_.bias_add && 2668 n->type_string() == csinfo_.conv2d)) || 2669 ((n->type_string() == csinfo_.bias_add && 2670 m->type_string() == csinfo_.conv2d)), 2671 true); 2672 2673 // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd, 2674 // BiasAdd is successor node, and Conv2D predecessor node. 2675 Node* pred = m->type_string() == csinfo_.bias_add ? n : m; 2676 Node* succ = m->type_string() == csinfo_.bias_add ? m : n; 2677 2678 // 1. Get all attributes from input nodes. 
2679 DataType T_pred, T_succ; 2680 string padding; 2681 std::vector<int32> strides; 2682 std::vector<int32> dilations; 2683 string data_format_pred, data_format_succ; 2684 bool use_cudnn_on_gpu; 2685 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred)); 2686 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ)); 2687 TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding)); 2688 TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides)); 2689 TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations)); 2690 TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred)); 2691 TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ)); 2692 TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); 2693 // We check to ensure that data formats of both succ and pred are same. 2694 // We expect them to be same, so we can enforce this as assert. 2695 // But assert can be too strict, so we enforce this as a check. 2696 // If the check fails, then we do not merge two nodes. 2697 // We also do same check for devices. 2698 if (data_format_pred != data_format_succ || T_pred != T_succ || 2699 pred->assigned_device_name() != succ->assigned_device_name() || 2700 pred->def().device() != succ->def().device()) { 2701 return Status(error::Code::INVALID_ARGUMENT, 2702 "data_format or T attribute or devices of Conv2D and " 2703 "BiasAdd do not match. 
Will skip node merge optimization"); 2704 } 2705 2706 const int succ_num = succ->num_inputs(); 2707 gtl::InlinedVector<Node*, 4> succ_control_edges; 2708 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num); 2709 FillInputs(succ, &succ_control_edges, &succ_in); 2710 2711 const int pred_num = pred->num_inputs(); 2712 gtl::InlinedVector<Node*, 4> pred_control_edges; 2713 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num); 2714 FillInputs(pred, &pred_control_edges, &pred_in); 2715 2716 // We need to ensure that Conv2D only feeds to BiasAdd (some other operator is 2717 // not expecting output of Conv2D). If this is not the case, then we cannot 2718 // merge Conv2D with BiasAdd. 2719 const int kFirstOutputSlot = 0; 2720 for (const Edge* e : pred->out_edges()) { 2721 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) { 2722 return Status(error::Code::INVALID_ARGUMENT, 2723 "Conv2D does not feed to BiasAdd, or " 2724 "it feeds BiasAdd but has multiple outputs. " 2725 "Will skip node merge optimization"); 2726 } 2727 } 2728 2729 // 2. Get inputs from both the nodes. 2730 // Find the 2 inputs from the conv and the bias from the add Bias. 2731 // Get operand 0, 1 of conv2D. 2732 CHECK_EQ(pred->in_edges().size(), 2); // Conv2D must have 2 inputs. 2733 // Get operand 1 of add_bias 2734 // BiasAdd must have 2 inputs: Conv, bias 2735 CHECK_EQ(succ->in_edges().size(), 2); 2736 2737 // We will use the node name of BiasAdd as the name of new node 2738 // Build new node. We use same name as original node, but change the op 2739 // name. 2740 NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias); 2741 nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D 2742 // pred_in[1] will be 2nd Tensorflow tensor for Conv2D. 2743 nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D 2744 // In1 of BiasAdd is same as output of Conv2D. 
2745 nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd 2746 2747 // Copy attributes from Conv2D to Conv2DWithBias. 2748 CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb); 2749 2750 // Copy the device assigned to old node to new node. 2751 nb.Device(succ->def().device()); 2752 2753 // Create node. 2754 Node* new_node; 2755 TF_CHECK_OK(nb.Finalize(&**g, &new_node)); 2756 CHECK_NOTNULL(new_node); 2757 2758 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' 2759 // node are already copied in BuildNode. We handle control edges now. 2760 for (const Edge* e : pred->in_edges()) { 2761 if (e->IsControlEdge()) { 2762 // Allow duplicate while adding control edge as it would fail (return 2763 // NULL) if we try to add duplicate edge. 2764 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); 2765 } 2766 } 2767 for (const Edge* e : succ->in_edges()) { 2768 if (e->IsControlEdge()) { 2769 // Allow duplicate while adding control edge as it would fail (return 2770 // NULL) if we try to add duplicate edge. 2771 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); 2772 } 2773 } 2774 2775 // Incoming edges are fixed, we will fix the outgoing edges now. 2776 // First, we will fix outgoing control edges from 'pred' node. 2777 for (const Edge* e : pred->out_edges()) { 2778 if (e->IsControlEdge()) { 2779 // Allow duplicate while adding control edge as it would fail (return 2780 // NULL) if we try to add duplicate edge. 2781 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); 2782 } 2783 } 2784 2785 // Second, we will fix outgoing control and data edges from 'succ' node. 2786 for (const Edge* e : succ->out_edges()) { 2787 if (e->IsControlEdge()) { 2788 // Allow duplicate while adding control edge as it would fail (return 2789 // NULL) if we try to add duplicate edge. 
2790 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); 2791 } else { 2792 // BiasAdd has only 1 output (at slot 0) and merged node also has only 1 2793 // output (at slot 0). 2794 const int kConv2DWithBiasOutputSlot = 0; 2795 CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(), 2796 e->dst_input())); 2797 } 2798 } 2799 2800 // Copy device assigned to old node to new node. 2801 // It's ok to use pred or succ as we have enforced a check that 2802 // both have same device assigned. 2803 new_node->set_assigned_device_name(pred->assigned_device_name()); 2804 2805 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString() 2806 << ", and node: " << succ->DebugString() 2807 << ", into node:" << new_node->DebugString(); 2808 2809 (*g)->RemoveNode(succ); 2810 (*g)->RemoveNode(pred); 2811 2812 return Status::OK(); 2813 } 2814 2815 Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g, 2816 Node* m, Node* n) { 2817 DCHECK((m->type_string() == csinfo_.pad && 2818 (n->type_string() == csinfo_.conv2d || 2819 n->type_string() == csinfo_.fused_conv2d)) || 2820 (n->type_string() == csinfo_.pad && 2821 (m->type_string() == csinfo_.conv2d || 2822 m->type_string() == csinfo_.fused_conv2d))); 2823 2824 bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d || 2825 m->type_string() == csinfo_.fused_conv2d; 2826 // Conv2D is successor node, and Pad predecessor node. 2827 Node* pred = m->type_string() == csinfo_.pad ? m : n; 2828 Node* succ = m->type_string() == csinfo_.pad ? n : m; 2829 2830 // 1. Get all attributes from input nodes. 
2831 DataType T_pred, T_succ; 2832 string padding; 2833 std::vector<int32> strides; 2834 std::vector<int32> dilations; 2835 string data_format_pred, data_format_succ; 2836 2837 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred)); 2838 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ)); 2839 TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding)); 2840 TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides)); 2841 TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations)); 2842 // Check if the devices of both succ and pred are the same. 2843 // Assert is not used because it can be too strict. 2844 // Don't need to check for data formats because it is not available in Pad. 2845 if (T_pred != T_succ || 2846 pred->assigned_device_name() != succ->assigned_device_name() || 2847 pred->def().device() != succ->def().device()) { 2848 return Status(error::Code::INVALID_ARGUMENT, 2849 "T attribute or devices of Conv2D and " 2850 "Pad do not match. Will skip node merge optimization"); 2851 } 2852 2853 const int succ_num = succ->num_inputs(); 2854 gtl::InlinedVector<Node*, 4> succ_control_edges; 2855 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num); 2856 FillInputs(succ, &succ_control_edges, &succ_in); 2857 2858 const int pred_num = pred->num_inputs(); 2859 gtl::InlinedVector<Node*, 4> pred_control_edges; 2860 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num); 2861 FillInputs(pred, &pred_control_edges, &pred_in); 2862 2863 // We need to ensure that Pad only feeds to Conv2D (some other operator is 2864 // not expecting output of Pad). If this is not the case, then we cannot 2865 // merge Conv2D with Pad. 2866 const int kFirstOutputSlot = 0; 2867 for (const Edge* e : pred->out_edges()) { 2868 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) { 2869 return Status(error::Code::INVALID_ARGUMENT, 2870 "Pad does not feed to Conv2D, or " 2871 "it feeds Conv2D but has multiple outputs. 
" 2872 "Will skip node merge optimization"); 2873 } 2874 } 2875 2876 // 2. Get inputs from both the nodes. 2877 2878 // Pad must have 2 data inputs: "input" and paddings. 2879 int PadDataInputEdges = 0; 2880 for (const Edge* e : pred->in_edges()) { 2881 if (!e->IsControlEdge()) { 2882 PadDataInputEdges++; 2883 } 2884 } 2885 DCHECK_EQ(PadDataInputEdges, 2); 2886 2887 // Conv2D must have 2 data inputs: Pad output and Filter 2888 // FusedConv2D have 3 data inputs: Pad output, Filter and Args; 2889 int ConvDataInputEdges = 0; 2890 for (const Edge* e : succ->in_edges()) { 2891 if (!e->IsControlEdge()) { 2892 ConvDataInputEdges++; 2893 } 2894 } 2895 2896 DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2); 2897 2898 // We will use the node name of Conv2D as the name of new node 2899 // Build new node. We use same name as original node, but change the op 2900 // name. 2901 2902 NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d 2903 : csinfo_.pad_with_conv2d); 2904 nb.Input(pred_in[0].first, pred_in[0].second); // In1 (input data) of Pad 2905 // pred_in[1] will be 2nd Tensorflow tensor for Conv2D. 2906 nb.Input(succ_in[1].first, succ_in[1].second); // In2 (filter) of conv2d 2907 // In1 of Conv2D is same as output of Pad. 2908 // Thus, only need to add In2 of Conv2D 2909 2910 if (is_fused_conv2d) { 2911 // FusedConv2D has one additional input, args 2912 std::vector<NodeBuilder::NodeOut> args; 2913 args.emplace_back(succ_in[2].first, succ_in[2].second); 2914 nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{ 2915 args}); // In3 (args) of FusedConv2D 2916 nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad 2917 // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D. 
2918 CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ), 2919 const_cast<const Node*>(pred), &nb); 2920 } else { 2921 nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad 2922 // Copy attributes from Pad and conv2D to PadWithConv2D. 2923 CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ), 2924 const_cast<const Node*>(pred), &nb); 2925 } 2926 2927 // Copy the device assigned to old node to new node. 2928 nb.Device(succ->def().device()); 2929 2930 // Create node. 2931 Node* new_node; 2932 TF_CHECK_OK(nb.Finalize(&**g, &new_node)); 2933 DCHECK(new_node); 2934 2935 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' 2936 // node are already copied in BuildNode. 2937 // We handle control edges now. 2938 for (const Edge* e : pred->in_edges()) { 2939 if (e->IsControlEdge()) { 2940 // Don't allow duplicate edge 2941 (*g)->AddControlEdge(e->src(), new_node, false); 2942 } 2943 } 2944 for (const Edge* e : succ->in_edges()) { 2945 if (e->IsControlEdge()) { 2946 // Don't allow duplicate edge 2947 (*g)->AddControlEdge(e->src(), new_node, false); 2948 } 2949 } 2950 2951 // Incoming edges are fixed, we will fix the outgoing edges now. 2952 // First, we will fix outgoing control edges from 'pred' node. 2953 for (const Edge* e : pred->out_edges()) { 2954 if (e->IsControlEdge()) { 2955 // Don't allow duplicate edge 2956 (*g)->AddControlEdge(new_node, e->dst(), false); 2957 } 2958 } 2959 2960 // Second, we will fix outgoing control and data edges from 'succ' node. 2961 for (const Edge* e : succ->out_edges()) { 2962 if (e->IsControlEdge()) { 2963 // Allow duplicate while adding control edge as it would fail (return 2964 // NULL) if we try to add duplicate edge. 2965 (*g)->AddControlEdge(new_node, e->dst(), false); 2966 } else { 2967 // Conv2D has only 1 output (at slot 0) and merged node also has only 1 2968 // output (at slot 0). 
2969 const int kPadWithConv2DOutputSlot = 0; 2970 (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(), 2971 e->dst_input()); 2972 } 2973 } 2974 2975 // Copy device assigned to old node to new node. 2976 // It's ok to use pred or succ as we have enforced a check that 2977 // both have same device assigned. 2978 new_node->set_assigned_device_name(pred->assigned_device_name()); 2979 2980 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString() 2981 << ", and node: " << succ->DebugString() 2982 << ", into node:" << new_node->DebugString(); 2983 2984 (*g)->RemoveNode(succ); 2985 (*g)->RemoveNode(pred); 2986 2987 return Status::OK(); 2988 } 2989 2990 Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( 2991 std::unique_ptr<Graph>* g, Node* m, Node* n) { 2992 CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad && 2993 n->type_string() == csinfo_.conv2d_grad_filter)) || 2994 ((n->type_string() == csinfo_.bias_add_grad && 2995 m->type_string() == csinfo_.conv2d_grad_filter)), 2996 true); 2997 2998 // If 'm' is BiasAddGrad, then 'n' is BackpropFilter. 2999 Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n; 3000 Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m; 3001 3002 // Sanity check for attributes from input nodes. 3003 DataType T_b, T_f; 3004 string data_format_b, data_format_f; 3005 TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b)); 3006 TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f)); 3007 TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b)); 3008 TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f)); 3009 if (data_format_b != data_format_f || T_b != T_f || 3010 badd->assigned_device_name() != fltr->assigned_device_name() || 3011 badd->def().device() != fltr->def().device()) { 3012 return Status(error::Code::INVALID_ARGUMENT, 3013 "data_format or T attribute or devices of " 3014 "Conv2DBackpropFilter and BiasAddGrad do not match. 
" 3015 "Will skip node merge optimization"); 3016 } 3017 3018 // We will use the node name of Conv2DBackpropFilter as the name of new node. 3019 // This is because BackpropFilterWithBias is going to emit bias output also. 3020 NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias); 3021 // Since Conv2DBackpropFilterWithBias has same number of inputs as 3022 // Conv2DBackpropFilter, we can just copy input edges directly. We dont need 3023 // to copy any data input of BiasAddGrad because that input also goes to 3024 // Conv2DBackpropFilter. 3025 const int fltr_ins = fltr->num_inputs(); 3026 gtl::InlinedVector<Node*, 4> fltr_control_edges; 3027 gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins); 3028 FillInputs(fltr, &fltr_control_edges, &fltr_in_edges); 3029 for (int idx = 0; idx < fltr_ins; idx++) { 3030 nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second); 3031 } 3032 3033 // Copy attributes from Conv2DBackpropFilter. 3034 CopyAttrsConv(const_cast<const Node*>(fltr), &nb); 3035 3036 // Copy the device assigned to old node to new node. 3037 nb.Device(fltr->def().device()); 3038 3039 // Create node. 3040 Node* new_node; 3041 TF_CHECK_OK(nb.Finalize(&**g, &new_node)); 3042 CHECK_NOTNULL(new_node); 3043 3044 // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to 3045 // new 'new_node' node are already copied in BuildNode. We handle control 3046 // edges now. 3047 for (const Edge* e : badd->in_edges()) { 3048 if (e->IsControlEdge()) { 3049 // Allow duplicate while adding control edge as it would fail (return 3050 // NULL) if we try to add duplicate edge. 3051 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); 3052 } 3053 } 3054 for (const Edge* e : fltr->in_edges()) { 3055 if (e->IsControlEdge()) { 3056 // Allow duplicate while adding control edge as it would fail (return 3057 // NULL) if we try to add duplicate edge. 
3058 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); 3059 } 3060 } 3061 3062 // Incoming edges are fixed, we will fix the outgoing edges now. 3063 // First, we will fix outgoing control edges from 'badd' node. 3064 // Conv2DBackpropFilter has 1 output -- filter_grad. 3065 // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and 3066 // bias_grad. But filter_grad is at same slot number (0) in both the 3067 // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while 3068 // it is at slot number 0 in BiasAddGrad. 3069 const int kMergedNodeFilterGradOutputIdx = 0; 3070 const int kMergedNodeBiasGradOutputIdx = 1; 3071 3072 for (const Edge* e : badd->out_edges()) { 3073 if (e->IsControlEdge()) { 3074 // Allow duplicate while adding control edge as it would fail (return 3075 // NULL) if we try to add duplicate edge. 3076 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); 3077 } else { 3078 CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx, 3079 e->dst(), e->dst_input())); 3080 } 3081 } 3082 3083 // Second, we will fix outgoing control and data edges from 'fltr' node. 3084 for (const Edge* e : fltr->out_edges()) { 3085 if (e->IsControlEdge()) { 3086 // We allow duplicate edge for this case since we already add control 3087 // edge from new_node in line 3990. Line below could be adding same 3088 // edge to same destination again. In such case, if we do not allow 3089 // duplicate edge, then this call will fail. 3090 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); 3091 } else { 3092 CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx, 3093 e->dst(), e->dst_input())); 3094 } 3095 } 3096 3097 // Copy device assigned to old node to new node. 3098 // It's ok to use badd or fltr as we have enforced a check that 3099 // both have same device assigned. 
3100 new_node->set_assigned_device_name(badd->assigned_device_name()); 3101 3102 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString() 3103 << ", and node: " << fltr->DebugString() 3104 << ", into node:" << new_node->DebugString(); 3105 3106 (*g)->RemoveNode(badd); 3107 (*g)->RemoveNode(fltr); 3108 3109 return Status::OK(); 3110 } 3111 3112 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m, 3113 Node* n) { 3114 CHECK_NOTNULL(m); 3115 CHECK_NOTNULL(n); 3116 3117 if (((m->type_string() == csinfo_.bias_add && 3118 n->type_string() == csinfo_.conv2d)) || 3119 ((n->type_string() == csinfo_.bias_add && 3120 m->type_string() == csinfo_.conv2d))) { 3121 return this->MergeConv2DWithBiasAdd(g, m, n); 3122 } 3123 if ((m->type_string() == csinfo_.pad && 3124 (n->type_string() == csinfo_.conv2d || 3125 (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) || 3126 (n->type_string() == csinfo_.pad && 3127 (m->type_string() == csinfo_.conv2d || 3128 (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) { 3129 return this->MergePadWithConv2D(g, m, n); 3130 } 3131 3132 if (((m->type_string() == csinfo_.bias_add_grad && 3133 n->type_string() == csinfo_.conv2d_grad_filter)) || 3134 ((n->type_string() == csinfo_.bias_add_grad && 3135 m->type_string() == csinfo_.conv2d_grad_filter))) { 3136 return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n); 3137 } 3138 3139 return Status(error::Code::UNIMPLEMENTED, 3140 "Unimplemented case for node merge optimization."); 3141 } 3142 3143 ////////////////////////////////////////////////////////////////////////// 3144 // Helper functions for node rewrite 3145 ////////////////////////////////////////////////////////////////////////// 3146 3147 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, 3148 Node* orig_node, 3149 const RewriteInfo* ri) { 3150 CHECK_NOTNULL(ri); 3151 CHECK_NOTNULL(orig_node); 3152 3153 VLOG(1) << "MklLayoutRewritePass: Original 
node:" << orig_node->DebugString(); 3154 3155 // Get all inputs. 3156 int num_inputs = orig_node->in_edges().size(); 3157 3158 // Drop count for control edges from inputs 3159 for (const Edge* e : orig_node->in_edges()) { 3160 if (e->IsControlEdge()) { 3161 num_inputs--; 3162 } 3163 } 3164 3165 gtl::InlinedVector<Node*, 4> control_edges; 3166 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs); 3167 FillInputs(orig_node, &control_edges, &inputs); 3168 3169 // Build new node. We use same name as original node, but change the op name. 3170 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str()); 3171 // Copy user-specified device assigned to original node to new node. 3172 nb.Device(orig_node->def().device()); 3173 // Set up new inputs to the rewritten node. 3174 Status s = SetUpInputs(g, inputs, &nb, orig_node); 3175 if (s != Status::OK()) { 3176 return s; 3177 } 3178 3179 const bool kPartialCopyAttrs = false; 3180 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs); 3181 3182 // Set the Mkl layer label for this op. 3183 if (DataTypeIsQuantized(orig_node->input_type(0)) || 3184 DataTypeIsQuantized(orig_node->output_type(0))) { 3185 nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel); 3186 } else { 3187 nb.Attr("_kernel", mkl_op_registry::kMklOpLabel); 3188 } 3189 // Finalize graph and get new node. 3190 Node* new_node = nullptr; 3191 TF_CHECK_OK(nb.Finalize(&**g, &new_node)); 3192 CHECK_NOTNULL(new_node); 3193 3194 // Incoming data edges from 'orig_node' node to new 'new_node' node are 3195 // already copied in BuildNode. We need to handle control edges now. 3196 for (const Edge* e : orig_node->in_edges()) { 3197 if (e->IsControlEdge()) { 3198 // Allow duplicate while adding control edge as it would fail (return 3199 // NULL) if we try to add duplicate edge. 
3200 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); 3201 } 3202 } 3203 3204 // Copy outgoing edges from 'orig_node' node to new 3205 // 'new_node' node, since the output also follows same ordering among 3206 // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow 3207 // tensors appropriately. Specifically, nth output of the original node 3208 // will become 2*nth output of the Mkl node for the interleaved ordering 3209 // of the tensors. For the contiguous ordering of the tensors, it will be n. 3210 // GetTensorDataIndex provides this mapping function. 3211 for (const Edge* e : orig_node->out_edges()) { 3212 if (e->IsControlEdge()) { 3213 // Allow duplicate while adding control edge as it would fail (return 3214 // NULL) if we try to add duplicate edge. 3215 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); 3216 } else { 3217 CHECK_NOTNULL((*g)->AddEdge( 3218 new_node, 3219 GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), 3220 e->dst(), e->dst_input())); 3221 } 3222 } 3223 3224 // Copy the runtime device assigned from original code to new node. 3225 new_node->set_assigned_device_name(orig_node->assigned_device_name()); 3226 3227 // Delete original node and mark new node as rewritten. 3228 (*g)->RemoveNode(orig_node); 3229 3230 VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString(); 3231 return Status::OK(); 3232 } 3233 3234 // TODO(mdfaijul): Is there any other elegent way to check for quantized ops 3235 // having attributes other than "T"? 3236 // Current implementation reflects only QuantizedConv2D and its fused Ops. 
3237 const MklLayoutRewritePass::RewriteInfo* 3238 MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const { 3239 DataType Tinput, Tfilter; 3240 if (!(GetNodeAttr(n->def(), "Tinput", &Tinput).ok() && 3241 GetNodeAttr(n->def(), "Tfilter", &Tfilter).ok())) { 3242 return nullptr; 3243 } 3244 if (mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), 3245 Tinput, Tfilter)) { 3246 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { 3247 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { 3248 return &*ri; 3249 } 3250 } 3251 } 3252 return nullptr; 3253 } 3254 3255 const MklLayoutRewritePass::RewriteInfo* 3256 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { 3257 CHECK_NOTNULL(n); 3258 3259 // QuntizedOps may have attributes other than "T", so decoupled the check 3260 // with a function, CheckForQuantizedNodeRewrite(const Node*). 3261 const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n); 3262 if (ri != nullptr) return ri; 3263 3264 // First check if node along with its type is supported by MKL layer. 3265 // We do not want to rewrite an op into Mkl op if types are not supported. 3266 // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to 3267 // MklRelu if type is INT32. 3268 DataType T; 3269 if (!GetNodeAttr(n->def(), "T", &T).ok()) { 3270 return nullptr; 3271 } 3272 3273 // We make an exception for __MklDummyConv2DWithBias, 3274 // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their 3275 // names do not match Mkl node names. 
3276 if (n->type_string() != csinfo_.conv2d_with_bias && 3277 n->type_string() != csinfo_.pad_with_conv2d && 3278 n->type_string() != csinfo_.pad_with_fused_conv2d && 3279 n->type_string() != csinfo_.conv2d_grad_filter_with_bias && 3280 n->type_string() != csinfo_.fused_conv2d && 3281 !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), 3282 T)) { 3283 return nullptr; 3284 } 3285 3286 // For elementwise node, we reuse the Eigen implementation and pass the MKL 3287 // metadata tensor through so we can avoid conversions. However, if all 3288 // incoming edges are in TF format, we don't need all this overhead, so 3289 // replace the elementwise node only if at least one of its parents is a MKL 3290 // node. 3291 // 3292 // Identity nodes can also skip replacement if they are not being served by 3293 // any MKL nodes. 3294 // 3295 // TODO(vrane): Add implementation for element-wise ops that doesn't reuse 3296 // eigen code to reduce cross-library dependency. 3297 VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string(); 3298 if (mkl_op_registry::IsMklElementWiseOp( 3299 mkl_op_registry::GetMklOpName(n->type_string()), T) || 3300 n->type_string().find("Identity") != string::npos) { 3301 VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string(); 3302 bool incoming_mkl_edge = false; 3303 int num_parent = 0; 3304 for (auto parent : n->in_edges()) { 3305 if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) { 3306 VLOG(1) << "ELEMENTWISE: parent " << num_parent++ 3307 << " is MKL op: " << parent->src()->type_string(); 3308 incoming_mkl_edge = true; 3309 break; 3310 } else { 3311 VLOG(1) << "ELEMENTWISE: parent " << num_parent++ 3312 << " is NON-MKL op: " << parent->src()->type_string(); 3313 } 3314 } 3315 if (incoming_mkl_edge == false) { 3316 VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which " 3317 "has no MKL " 3318 "parents."; 3319 return nullptr; 3320 } else { 3321 VLOG(1) << "ELEMENTWISE: Replacing 
elementwise node " << n->type_string() 3322 << " which has MKL parents"; 3323 } 3324 } 3325 3326 // We now check if rewrite rule applies for this op. If rewrite rule passes 3327 // for this op, then we rewrite it to Mkl op. 3328 // Find matching RewriteInfo and then check that rewrite rule applies. 3329 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { 3330 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { 3331 return &*ri; 3332 } 3333 } 3334 3335 // Else return not found. 3336 return nullptr; 3337 } 3338 3339 ////////////////////////////////////////////////////////////////////////// 3340 // Helper functions for node fusion 3341 ////////////////////////////////////////////////////////////////////////// 3342 Status MklLayoutRewritePass::FuseTransposeMklOpTranspose( 3343 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes, 3344 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs, 3345 string data_format) { 3346 Node* transpose_to_nhwc = nodes[0]; 3347 Node* mklop = nodes[1]; 3348 Node* transpose_to_nchw = nodes[2]; 3349 3350 const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs(); 3351 gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges; 3352 gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in( 3353 transpose_nhwc_num_inputs); 3354 FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges, 3355 &transpose_nhwc_in); 3356 3357 const int mklop_num_inputs = mklop->num_inputs(); 3358 gtl::InlinedVector<Node*, 4> mklop_control_edges; 3359 gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs); 3360 FillInputs(mklop, &mklop_control_edges, &mklop_in); 3361 3362 const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs(); 3363 gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges; 3364 gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in( 3365 transpose_nchw_num_inputs); 3366 FillInputs(transpose_to_nchw, &transpose_nchw_control_edges, 3367 
&transpose_nchw_in); 3368 3369 // We use same name as original node, but change the op 3370 // type. 3371 NodeBuilder nb(mklop->name(), mklop->type_string()); 3372 3373 // Storing the output slots of the input nodes. 3374 for (int i = 0; i < mklop_num_inputs; i++) { 3375 if (mklop_in[i].first == transpose_to_nhwc) { 3376 // Fill "x": 3377 nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second); 3378 } else { 3379 // Fill inputs other than "x": 3380 nb.Input(mklop_in[i].first, mklop_in[i].second); 3381 } 3382 } 3383 3384 copy_attrs(const_cast<const Node*>(mklop), &nb, true); 3385 nb.Attr("data_format", data_format); 3386 3387 // Copy the device assigned to old node to new node. 3388 nb.Device(mklop->def().device()); 3389 3390 // Create node. 3391 Node* new_node; 3392 TF_CHECK_OK(nb.Finalize(&**g, &new_node)); 3393 DCHECK(new_node); 3394 3395 // Fill outputs. 3396 for (const Edge* e : transpose_to_nchw->out_edges()) { 3397 if (!e->IsControlEdge()) { 3398 const int kTransposeWithMklOpOutputSlot = 0; 3399 auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot, 3400 e->dst(), e->dst_input()); 3401 DCHECK(new_edge); 3402 } 3403 } 3404 3405 // Copy device assigned to old node to new node. 
3406 new_node->set_assigned_device_name(mklop->assigned_device_name()); 3407 3408 // Copy requested_device and assigned_device_name_index 3409 new_node->set_requested_device(mklop->requested_device()); 3410 new_node->set_assigned_device_name_index(mklop->assigned_device_name_index()); 3411 3412 (*g)->RemoveNode(transpose_to_nhwc); 3413 (*g)->RemoveNode(mklop); 3414 (*g)->RemoveNode(transpose_to_nchw); 3415 3416 return Status::OK(); 3417 } 3418 3419 Status MklLayoutRewritePass::FuseNode( 3420 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes, 3421 const MklLayoutRewritePass::FusionInfo fi) { 3422 return fi.fuse_func(g, nodes, fi.copy_attrs); 3423 } 3424 3425 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo> 3426 MklLayoutRewritePass::CheckForNodeFusion(Node* a) const { 3427 // Stores matched nodes, in the same order as node_checkers. 3428 std::vector<Node*> nodes; 3429 3430 for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) { 3431 // 3432 // Make sure node "a" and its succeding nodes (b, c ...), match the pattern 3433 // defined in fusion info (ops[0], ops[1], ...), 3434 // a.k.a. "a->b->c" matches "op1->op2->op3" 3435 // 3436 3437 // Stores the first unvisted outgoing edge of each matched node in "nodes". 3438 std::stack<EdgeSet::const_iterator> current_neighbor_stack; 3439 nodes.clear(); 3440 3441 auto node_checker = fi->node_checkers.begin(); 3442 if (a != nullptr && (*node_checker)(a)) { 3443 nodes.push_back(a); 3444 current_neighbor_stack.push(a->out_edges().begin()); 3445 ++node_checker; 3446 } 3447 3448 while (!nodes.empty()) { 3449 auto& current_neighbor_iter = current_neighbor_stack.top(); 3450 3451 if (current_neighbor_iter != nodes.back()->out_edges().end()) { 3452 // Found an unvisited edge. Goes through the edge to get the neighbor. 3453 Node* neighbor_node = (*current_neighbor_iter)->dst(); 3454 ++current_neighbor_stack.top(); // Retrieves the next unvisited edge. 
3455 3456 if ((*node_checker)(neighbor_node)) { 3457 // Found a match. Stores the node and moves to the next checker. 3458 nodes.push_back(neighbor_node); 3459 current_neighbor_stack.push(neighbor_node->out_edges().begin()); 3460 if (++node_checker == fi->node_checkers.end()) { 3461 return make_tuple(true, nodes, *fi); 3462 } 3463 } 3464 } else { 3465 // Removes the current node since none of its neighbor leads to a 3466 // further match. 3467 nodes.pop_back(); 3468 current_neighbor_stack.pop(); 3469 --node_checker; 3470 } 3471 } 3472 } 3473 3474 return make_tuple(false, std::vector<Node*>(), FusionInfo()); 3475 } 3476 3477 /////////////////////////////////////////////////////////////////////////////// 3478 // Post-rewrite Mkl metadata fixup pass 3479 /////////////////////////////////////////////////////////////////////////////// 3480 bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, 3481 const Edge* e_data, 3482 const Edge* e_metadata) { 3483 if (g == nullptr || e_data == nullptr || e_metadata == nullptr) { 3484 return false; 3485 } 3486 3487 Node* n_data = e_data->src(); 3488 int n_data_op_slot = e_data->src_output(); 3489 int n_metadata_op_slot = 3490 GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs()); 3491 3492 // If the source of meta edge is a constant node (producing dummy Mkl metadata 3493 // tensor), then we will need to fix. 3494 if (IsConstant(e_metadata->src())) { 3495 Node* e_metadata_dst = e_metadata->dst(); 3496 int e_metadata_in_slot = e_metadata->dst_input(); 3497 CHECK_NOTNULL((*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst, 3498 e_metadata_in_slot)); 3499 3500 (*g)->RemoveEdge(e_metadata); 3501 return true; 3502 } 3503 3504 return false; 3505 } 3506 3507 bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g, 3508 Node* n) { 3509 bool result = false; 3510 3511 // If graph node is not Mkl node, then return. 
3512 DataType T = DT_INVALID; 3513 if (!GetNodeAttr(n->def(), "T", &T).ok() || 3514 !mkl_op_registry::IsMklOp(n->type_string(), T)) { 3515 return result; 3516 } 3517 3518 // If it is Mkl node, then check if the input edges to this node that carry 3519 // Mkl metadata are linked up correctly with the source node. 3520 3521 // For Mkl nodes, we generate twice the number of input tensors (n for Mkl 3522 // data tensors + n for Mkl metadata tensors). We need to check for correct 3523 // connection of n metadata tensors only. 3524 int num_data_inputs = n->num_inputs() / 2; 3525 for (int idx = 0; idx < num_data_inputs; idx++) { 3526 // Get the edge connecting input slot with index (idx). 3527 const Edge* e = nullptr; 3528 TF_CHECK_OK(n->input_edge(idx, &e)); 3529 3530 // If e is control edge, then skip. 3531 if (e->IsControlEdge()) { 3532 continue; 3533 } 3534 3535 // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl 3536 // node, then we don't need to do anything. 3537 Node* e_src = e->src(); 3538 if (GetNodeAttr(e_src->def(), "T", &T).ok() && 3539 mkl_op_registry::IsMklOp(e_src->type_string(), T)) { 3540 // Source node for edge 'e' is Mkl node. 3541 // Destination node and destination input slot of e is node 'n' and 'idx' 3542 // resp. 3543 CHECK_EQ(e->dst(), n); 3544 CHECK_EQ(e->dst_input(), idx); 3545 3546 // Let's get edge that carries Mkl metadata corresponding to Mkl data edge 3547 // 'e'. For that, let's first get the input slot of 'n' where the meta 3548 // edge will feed the value. 3549 int e_meta_in_slot = 3550 GetTensorMetaDataIndex(e->dst_input(), n->num_inputs()); 3551 const Edge* e_meta = nullptr; 3552 TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta)); 3553 3554 // Let's check if we need to fix this meta edge. 
3555 if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) { 3556 result = true; 3557 } 3558 } 3559 } 3560 3561 return result; 3562 } 3563 3564 /////////////////////////////////////////////////////////////////////////////// 3565 // Run function for the pass 3566 /////////////////////////////////////////////////////////////////////////////// 3567 3568 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) { 3569 bool result = false; 3570 CHECK_NOTNULL(g); 3571 3572 DumpGraph("Before running MklLayoutRewritePass", &**g); 3573 3574 std::vector<Node*> order; 3575 GetReversePostOrder(**g, &order); // This will give us topological sort. 3576 for (Node* n : order) { 3577 // If node is not an op or it cannot run on CPU device, then skip. 3578 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { 3579 continue; 3580 } 3581 3582 Node* m = nullptr; 3583 if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) { 3584 // Check if the node 'n' can be merged with any other node. If it can 3585 // be 'm' contains the node with which it can be merged. 3586 string n1_name = n->name(); 3587 string n2_name = m->name(); 3588 3589 VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and " 3590 << n2_name << " for merging"; 3591 3592 if (MergeNode(g, n, m) == Status::OK()) { 3593 VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and " 3594 << n2_name; 3595 result = true; 3596 } 3597 } 3598 } 3599 3600 DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g); 3601 3602 order.clear(); 3603 GetReversePostOrder(**g, &order); // This will give us topological sort. 3604 for (Node* n : order) { 3605 // If node is not an op or it cannot run on CPU device, then skip. 
3606 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { 3607 continue; 3608 } 3609 3610 auto check_result = CheckForNodeFusion(n); 3611 bool found_pattern = std::get<0>(check_result); 3612 std::vector<Node*> nodes = std::get<1>(check_result); 3613 const FusionInfo fi = std::get<2>(check_result); 3614 3615 // if "found_pattern" is true, we can do the fusion. 3616 if (found_pattern) { 3617 if (FuseNode(g, nodes, fi) == Status::OK()) { 3618 result = true; 3619 } 3620 } 3621 } 3622 DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g); 3623 3624 order.clear(); 3625 GetReversePostOrder(**g, &order); // This will give us topological sort. 3626 for (Node* n : order) { 3627 // If node is not an op or it cannot run on CPU device, then skip. 3628 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { 3629 continue; 3630 } 3631 3632 const RewriteInfo* ri = nullptr; 3633 // We will first search if node is to be rewritten. 3634 if ((ri = CheckForNodeRewrite(n)) != nullptr) { 3635 string node_name = n->name(); 3636 string op_name = n->type_string(); 3637 3638 VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name 3639 << " with op " << op_name << " for rewrite using" 3640 << " layout optimization."; 3641 3642 if (RewriteNode(g, n, ri) == Status::OK()) { 3643 VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name 3644 << " with op " << op_name << " for Mkl layout optimization."; 3645 result = true; 3646 } 3647 } 3648 } 3649 3650 DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g); 3651 3652 order.clear(); 3653 GetReversePostOrder(**g, &order); // This will give us topological sort. 3654 for (Node* n : order) { 3655 // If node is not an op or it cannot run on CPU device, then skip. 
3656 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { 3657 continue; 3658 } 3659 if (FixMklMetaDataEdges(g, n)) { 3660 string node_name = n->name(); 3661 string op_name = n->type_string(); 3662 3663 VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node " 3664 << node_name << " with op " << op_name; 3665 result = true; 3666 } 3667 } 3668 DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)", 3669 &**g); 3670 3671 return result; 3672 } 3673 3674 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) { 3675 return MklLayoutRewritePass().RunPass(g); 3676 } 3677 3678 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { 3679 if (options.graph == nullptr && options.partition_graphs == nullptr) { 3680 return Status::OK(); 3681 } 3682 if (DisableMKL()) { 3683 VLOG(2) << "TF-MKL: Disabling MKL"; 3684 return Status::OK(); 3685 } 3686 3687 auto process_graph = [&](std::unique_ptr<Graph>* g) { 3688 // Get the ownership of a graph 3689 std::unique_ptr<Graph>* ng = std::move(g); 3690 RunPass(ng); 3691 // Return the ownership of a graph back 3692 g->reset(ng->release()); 3693 }; 3694 3695 if (kMklLayoutRewritePassGroup != 3696 OptimizationPassRegistry::POST_PARTITIONING) { 3697 // For any pre-partitioning phase, a graph is stored in options.graph. 3698 process_graph(options.graph); 3699 } else { 3700 // For post partitioning phase, graphs are stored in 3701 // options.partition_graphs. 3702 for (auto& pg : *options.partition_graphs) { 3703 process_graph(&pg.second); 3704 } 3705 } 3706 3707 return Status::OK(); 3708 } 3709 3710 } // namespace tensorflow 3711 3712 #endif 3713