/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include <algorithm>
#include <functional>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"

#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_layout_pass.h"

namespace tensorflow {

#ifdef INTEL_MKL_ML

// This pass implements rewriting of the graph to support the following
// scenarios:
// (A) Merging nodes in the graph
// (B) Rewriting a node in the graph to a new node
//     Rewriting happens under the following two scenarios:
//     1) Propagating the Mkl layout as an additional output tensor from every
//        Mkl-supported NN layer. (We will loosely call a tensor that carries
//        the Mkl layout an Mkl tensor henceforth.)
//     2) Context-based rewrite: This is needed in order to optimize the
//        gradient ops of Conv2D+BiasAdd. The gradient op of both Conv2D and
//        MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into a
//        Conv2D-specific BiasAddGrad and a MatMul-specific BiasAddGrad.
//        This is a context-specific optimization, where the context is the
//        forward operator that the BiasAddGrad corresponds to.
//
// Example of A : Merging nodes in the graph
// -----------------------------------------
// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
//
//        O = Conv2D(A, B)
//        P = BiasAdd(O, C)
//
// We merge them into Conv2DWithBias as:
//
//        P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
//
// The meaning of A_m, B_m and C_m is explained in B.1.
//
// Merge rules:
//  - The merge for Conv2D and BiasAdd happens when the output of Conv2D
//    _only_ goes to BiasAdd.
//  - Also, the intersection of the attributes of both nodes must have the
//    same values.
//  - Both nodes must have been assigned to the same device (if any).
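//
// For example, the following graph is _not_ merged, because the output of
// Conv2D has a consumer other than BiasAdd (an illustrative sketch of the
// first merge rule above, not an exhaustive list of the checks performed):
//
//        O = Conv2D(A, B)
//        P = BiasAdd(O, C)
//        Q = Relu(O)         // second consumer of O prevents the merge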
//
// Example of B.1 : Rewriting nodes to Mkl nodes
// ---------------------------------------------
// Consider a Relu node. The current definition of the Relu node looks like:
//
//           O = Relu(A)
//
// Relu has 1 input (A) and 1 output (O).
//
// This rewrite pass will generate a new graph node for Relu (the new node is
// called MklRelu) as:
//
//          O, O_m = MklRelu(A, A_m)
//
// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
// the same as input A of Relu; output O is the same as output O of Relu. O_m
// is the additional output tensor that will be set by MklRelu, and it
// represents the Mkl tensor corresponding to O -- in other words, O_m is some
// kind of metadata for O. A_m is an additional input of MklRelu, and it
// represents the metadata for A -- just as O_m is metadata for O, A_m is
// metadata for A. MklRelu receives this metadata from the previous node in
// the graph.
//
// When the previous node in the graph is an Mkl node, A_m will represent a
// valid Mkl tensor. But when the previous node is not an Mkl node, A_m will
// represent a dummy Mkl tensor.
//
// Rewriting rules:
//  - Selection of a node for rewriting happens by registering the op type of
//    the node with the rewriting pass. If the op type is not registered, then
//    nodes of this op type are not rewritten.
//  - Number of inputs after rewriting:
//      Since the rewritten node gets an Mkl tensor for every input Tensorflow
//      tensor, the rewritten node gets 2*N inputs, where N is the number of
//      inputs of the original node.
//  - Number of outputs after rewriting:
//      Since the rewritten node generates an Mkl tensor for every output
//      Tensorflow tensor, the rewritten node generates 2*N outputs, where N
//      is the number of outputs of the original node.
//  - Ordering of Tensorflow tensors and Mkl tensors:
//      Since every rewritten node generates twice the number of inputs and
//      outputs, one could imagine various orderings among Tensorflow tensors
//      and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
//      inputs; then the new op '_MklConv2D' can take its inputs in the order
//      A, A_m, B, B_m, or in the order A, B, A_m, B_m. In general, many such
//      permutations are possible.
//
//      So the question is: which order do we follow? We support two types of
//      orderings: (1) interleaved and (2) contiguous. Interleaved ordering
//      follows an intuitive order where an Mkl tensor follows the
//      corresponding Tensorflow tensor immediately. In the context of the
//      above example, it would be: A, A_m, B, B_m. Note that the ordering
//      rule applies to both the inputs and the outputs. Contiguous ordering
//      means all the Tensorflow tensors come first, followed by all the Mkl
//      tensors. We use contiguous ordering as the default.
//
// Graph rewrite algorithm:
//      Algorithm: Graph Rewrite
//      Input: Graph G, names of the nodes to rewrite and their new names
//      Output: Modified Graph G' if any nodes are modified, G otherwise.
//      Start:
//        N = Topological_Sort(G) // N is a set of nodes in toposort order.
//        foreach node n in N
//        do
//          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input?
//          then
//            E  = set of <incoming edge and its src_output slot> of n
//            E' = {}   // a new set of edges for the rewritten node
//            foreach <e,s> in E
//            do
//              E' U {<e,s>}  // First copy the edge that generates the
//                            // Tensorflow tensor as it is.
//              m = Source node of edge e
//              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
//              then
//                E' U {<m,s+1>}  // If yes, then m will generate an Mkl
//                                // tensor as an additional output.
//              else
//                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
//                                                 // Mkl tensor.
//                E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
//              fi
//            done
//            n' = Build_New_Node(G, new_name, E')
//            Mark_Rewritten(n')  // Mark the new node as being rewritten.
//          fi
//        done
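//
// As a small worked sketch of this algorithm (with the default contiguous
// ordering), consider a two-node graph O = Conv2D(A, B); R = Relu(O), where A
// and B come from non-Mkl nodes:
//
//   1. Conv2D is an Mkl op. Neither of its inputs was rewritten, so two dummy
//      Mkl tensors A_m and B_m are generated and the node is rebuilt as
//          O, O_m = _MklConv2D(A, B, A_m, B_m)
//   2. Relu is an Mkl op and its input comes from the node rewritten in step
//      1, so the Mkl tensor output of _MklConv2D is consumed directly:
//          R, R_m = _MklRelu(O, O_m)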
//
// Explanation:
//      For graph rewrite, we visit the nodes of the input graph in
//      topological sort order. With this ordering, we visit nodes in a
//      top-to-bottom fashion. We need this order because while visiting a
//      node we want all of its input nodes to have already been visited (and
//      rewritten, if applicable). This is because if we need to rewrite a
//      given node, then all of its input nodes need to be fixed (in other
//      words, they cannot be deleted later).
//
//      While visiting a node, we first check if the op type of the node is an
//      Mkl op. If it is, then we rewrite that node after constructing new
//      inputs to the node. If the op type of the node is not an Mkl op, then
//      we do not rewrite that node.
//
// Handling workspace propagation for certain ops:
//
//      Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require a
//      workspace to be passed from their respective forward ops. Workspace
//      tensors provide memory for storing results of intermediate operations
//      that are helpful in backward propagation. TensorFlow does not have a
//      notion of a workspace and as a result does not allow producing
//      additional outputs from these forward ops. For these ops, we need to
//      add two extra edges between a forward op and its corresponding
//      backward op - the first extra edge carries a workspace tensor and the
//      second one carries an Mkl tensor for the workspace tensor.
//
//      Example:
//
//      A typical graph for MaxPool and its gradient looks like:
//
//      A = MaxPool(T)
//      B = MaxPoolGrad(X, A, Y)
//
//      We will transform this graph to propagate the workspace as
//      (with the contiguous ordering):
//
//      A, W, A_m, W_m = MklMaxPool(T, T_m)
//      B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
//
//      Here W is the workspace tensor. Transformed tensor names with the
//      suffix _m are Mkl tensors, and this transformation has been done using
//      the algorithm discussed earlier. The transformation for workspace
//      propagation only adds the extra outputs (W, W_m) to a forward op and
//      connects them to the corresponding backward ops.
//
//      Terms:
//
//      Forward op name = name of the op in the forward pass
//        where a workspace tensor originates (MaxPool in this example)
//      Backward op name = name of the op in the backward pass that receives
//        a workspace tensor from the forward op (MaxPoolGrad in the example)
//      Slot = position of the output or input slot that will be
//        used by the workspace tensor (1 for MklMaxPool as W is the 2nd
//        output of MaxPool (0 is the 1st); 3 for MklMaxPoolGrad)
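//
//      For reference, this MaxPool/MaxPoolGrad pair corresponds to the
//      wsinfo_ entry registered in the constructor below (see WorkSpaceInfo
//      for the field meanings):
//
//        {MaxPool, MaxPoolGrad, fwd_slot = 0, bwd_slot = 1,
//         ws_fwd_slot = 1, ws_bwd_slot = 3}
//
//      i.e., output 0 of MaxPool (tensor A) feeds input 1 of MaxPoolGrad,
//      the workspace W is produced as output 1 of MklMaxPool, and it is
//      consumed as input 3 of MklMaxPoolGrad.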
//
//      Question:
//
//      How do we associate a backward op with a forward op? There can be more
//      than one op of the same type in the graph.
//
//      In this example, we associate MaxPoolGrad with MaxPool. But there
//      could be more than one MaxPool op. To solve this problem, we look for
//      a _direct_ edge between a forward op and a backward op (tensor A flows
//      along this edge in the example).
//
//      How do we transform forward and backward ops when there is no direct
//      edge between them? In such a case, we generate dummy tensors for the
//      workspace tensors. For the example, the transformation of MaxPool will
//      be exactly the same as when there is a direct edge between the forward
//      and the backward op --- it is just that MaxPool won't generate any
//      workspace tensor. For MaxPoolGrad, the transformation will also be the
//      same, but instead of connecting W and W_m to the outputs of MaxPool,
//      we will produce dummy tensors for them, and we will set the
//      workspace_enabled attribute to false.
//
// Example of B.2 : Context-based node rewrite
// -------------------------------------------
// Consider the BiasAddGrad op as:
//
//        O = _MklConv2D(A, B, C, A_m, B_m, C_m)
//        P = BiasAddGrad(O)
//
// Then we rewrite it as:
//
//        P = Conv2DWithBiasBackpropBias(O, O_m)
//
// The rewrite of BiasAddGrad into Conv2DWithBiasBackpropBias takes place
// depending on the matching 'context'. The term context loosely refers to
// which forward op the BiasAddGrad is _associated_ with. If it is
// _MklConv2DWithBias, then we consider it the Conv2D context; if it is
// MatMul, then it is the MatMul context.

class MklLayoutRewritePass : public GraphOptimizationPass {
 public:
  MklLayoutRewritePass() {
    // NOTE: names are alphabetically sorted.
    csinfo_.addn = "AddN";
    csinfo_.avg_pool = "AvgPool";
    csinfo_.avg_pool_grad = "AvgPoolGrad";
    csinfo_.bias_add = "BiasAdd";
    csinfo_.bias_add_grad = "BiasAddGrad";
    csinfo_.concat = "Concat";
    csinfo_.concatv2 = "ConcatV2";
    csinfo_.conv2d = "Conv2D";
    csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
    csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
    csinfo_.fused_batch_norm = "FusedBatchNorm";
    csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
    csinfo_.identity = "Identity";
    csinfo_.lrn = "LRN";
    csinfo_.lrn_grad = "LRNGrad";
    csinfo_.matmul = "MatMul";
    csinfo_.max_pool = "MaxPool";
    csinfo_.max_pool_grad = "MaxPoolGrad";
    csinfo_.mkl_conv2d = "_MklConv2D";
    csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
    csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
    csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
    csinfo_.mkl_conv2d_with_bias_backprop_bias =
        "_MklConv2DWithBiasBackpropBias";
    csinfo_.relu = "Relu";
    csinfo_.relu_grad = "ReluGrad";
    csinfo_.reshape = "Reshape";
    csinfo_.split = "Split";
    // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
    // in MklUtil.h (the IsMklElementWiseOp method) to ensure that the
    // MklInputConversion op is added before it.
    csinfo_.add = "Add";
    csinfo_.maximum = "Maximum";
    csinfo_.mul = "Mul";
    csinfo_.squared_difference = "SquaredDifference";
    csinfo_.sub = "Sub";
    // End - element-wise ops. See note above.

    // NOTE: names are alphabetically sorted.
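    // Each rinfo_ entry below is {original op name, new op name,
    // attribute-copying function, rewrite predicate, optional context}.
    // mkl_op_registry::GetMklOpName() maps an original op name to its Mkl
    // variant (e.g. "Conv2D" to "_MklConv2D"), matching the _Mkl* names
    // recorded in csinfo_ above.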
300 rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), 301 CopyAttrsAddN, AddNRewrite, nullptr}); 302 rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), 303 CopyAttrsDataType, AlwaysRewrite, nullptr}); 304 rinfo_.push_back({csinfo_.avg_pool, 305 mkl_op_registry::GetMklOpName(csinfo_.avg_pool), 306 CopyAttrsPooling, AlwaysRewrite, nullptr}); 307 rinfo_.push_back({csinfo_.avg_pool_grad, 308 mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad), 309 CopyAttrsPooling, AlwaysRewrite, nullptr}); 310 // BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending 311 // on if context contains Conv2D. 312 rinfo_.push_back({csinfo_.bias_add_grad, 313 csinfo_.mkl_conv2d_with_bias_backprop_bias, 314 CopyAttrsBiasAddGrad, ContextMatchRewrite, 315 &biasaddgrad_conv2dwithbias_context_}); 316 // BiasAddGrad gets written into BiasAddGrad depending on if context 317 // contains MatMul. 318 rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul, 319 CopyAttrsBiasAddGrad, ContextMatchRewrite, 320 &biasaddgrad_matmul_context_}); 321 rinfo_.push_back({csinfo_.concat, 322 mkl_op_registry::GetMklOpName(csinfo_.concat), 323 CopyAttrsConcat, AlwaysRewrite, nullptr}); 324 rinfo_.push_back({csinfo_.concatv2, 325 mkl_op_registry::GetMklOpName(csinfo_.concatv2), 326 CopyAttrsConcatV2, AlwaysRewrite, nullptr}); 327 rinfo_.push_back({csinfo_.conv2d, 328 mkl_op_registry::GetMklOpName(csinfo_.conv2d), 329 CopyAttrsConv2D, AlwaysRewrite, nullptr}); 330 rinfo_.push_back({csinfo_.conv2d_grad_filter, 331 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), 332 CopyAttrsConv2D, AlwaysRewrite, nullptr}); 333 rinfo_.push_back({csinfo_.conv2d_grad_input, 334 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), 335 CopyAttrsConv2D, AlwaysRewrite, nullptr}); 336 rinfo_.push_back({csinfo_.fused_batch_norm, 337 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), 338 CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); 339 rinfo_.push_back( 340 {csinfo_.fused_batch_norm_grad, 341 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), 342 CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); 343 rinfo_.push_back({csinfo_.identity, 344 mkl_op_registry::GetMklOpName(csinfo_.identity), 345 CopyAttrsIdentity, AlwaysRewrite, nullptr}); 346 rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), 347 CopyAttrsLRN, AlwaysRewrite, nullptr}); 348 rinfo_.push_back({csinfo_.lrn_grad, 349 mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), 350 CopyAttrsLRN, AlwaysRewrite, nullptr}); 351 rinfo_.push_back({csinfo_.max_pool, 352 mkl_op_registry::GetMklOpName(csinfo_.max_pool), 353 CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr}); 354 rinfo_.push_back({csinfo_.max_pool_grad, 355 mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), 356 CopyAttrsPooling, AlwaysRewrite, nullptr}); 357 rinfo_.push_back({csinfo_.maximum, 358 mkl_op_registry::GetMklOpName(csinfo_.maximum), 359 CopyAttrsDataType, AlwaysRewrite, nullptr}); 360 rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), 361 CopyAttrsDataType, AlwaysRewrite, nullptr}); 362 rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), 363 CopyAttrsDataType, AlwaysRewrite, nullptr}); 364 rinfo_.push_back({csinfo_.relu_grad, 365 mkl_op_registry::GetMklOpName(csinfo_.relu_grad), 366 CopyAttrsDataType, AlwaysRewrite, nullptr}); 367 rinfo_.push_back({csinfo_.reshape, 368 mkl_op_registry::GetMklOpName(csinfo_.reshape), 369 CopyAttrsReshape, AlwaysRewrite, nullptr}); 370 
rinfo_.push_back({csinfo_.squared_difference, 371 mkl_op_registry::GetMklOpName(csinfo_.squared_difference), 372 CopyAttrsDataType, AlwaysRewrite, nullptr}); 373 rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub), 374 CopyAttrsDataType, AlwaysRewrite, nullptr}); 375 376 // Add info about which ops to add workspace edge to and the slots. 377 wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); 378 wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3}); 379 380 // Add a rule for merging nodes 381 minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0, 382 csinfo_.mkl_conv2d_with_bias}); 383 384 biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul, 385 IsBiasAddGradInMatMulContext}; 386 387 biasaddgrad_conv2dwithbias_context_ = { 388 csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias, 389 IsBiasAddGradInConv2DWithBiasContext}; 390 391 cinfo_.push_back(&biasaddgrad_matmul_context_); 392 cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_); 393 } 394 395 // Standard interface to run pass 396 Status Run(const GraphOptimizationPassOptions& options); 397 398 // Helper function which does most of heavy lifting for rewriting 399 // Mkl nodes to propagate Mkl tensor as additional output 400 // 401 // Extracts common functionality between Run public interface and 402 // test interface. 403 // 404 // @return true, if and only if graph is mutated; false otherwise. 405 bool RunPass(std::unique_ptr<Graph>* g); 406 407 /// Structure to specify the context information used in a node rewrite rule 408 typedef struct { 409 string node; // Name of the node to be rewritten 410 string fwd; // Name of the node in the forward pass that this node 411 // corresponds to 412 std::function<bool(const Node*, const Node**, void* c)> context_match_fn; 413 } ContextInfo; 414 415 /// Structure to specify the name of an original node, its new name after 416 /// rewrite, the number of inputs to the original node, the function to 417 /// be used to copy attributes for the op, and the rule (if any) which 418 /// must hold for rewriting the node 419 typedef struct { 420 string name; // Original name of op of the node in the graph 421 string new_name; // New name of the op of the node in the graph 422 // A function handler to copy attributes from an old node to a new node. 423 std::function<void(const Node*, NodeBuilder*)> copy_attrs; 424 // A rule under which to rewrite this node 425 std::function<bool(const Node*, const ContextInfo* c)> rewrite_rule; 426 // ContextInfo, if any, to be used for rewrite 427 ContextInfo* context; 428 } RewriteInfo; 429 430 /// Structure to specify a forward op, a backward op, and the slot numbers 431 /// in the forward and backward ops where we will add a workspace edge. 
432 typedef struct { 433 string fwd_op; // Name of a forward op in the graph 434 string bwd_op; // Name of a backward op in the graph 435 int fwd_slot; // Output slot in the forward op node where actual 436 // output tensor resides 437 int bwd_slot; // Input slot in the backward op node where actual 438 // input tensor resides 439 int ws_fwd_slot; // Output slot in the forward op node where workspace 440 // edge is added 441 int ws_bwd_slot; // Input slot in the backward op node where workspace 442 // edge is added 443 } WorkSpaceInfo; 444 445 /// Structure to specify information used in node merge 446 typedef struct { 447 string pred; // Predecessor node string 448 string succ; // Successor node string 449 int op; // The operand no the predecessor node corresponds 450 // to the successor node 451 string new_node; // Name of the node after merge 452 } MergeInfo; 453 454 /// Structure to store all constant strings 455 /// NOTE: names are alphabetically sorted. 456 typedef struct { 457 string addn; 458 string add; 459 string avg_pool; 460 string avg_pool_grad; 461 string bias_add; 462 string bias_add_grad; 463 string concat; 464 string concatv2; 465 string conv2d; 466 string conv2d_grad_input; 467 string conv2d_grad_filter; 468 string fused_batch_norm; 469 string fused_batch_norm_grad; 470 string identity; 471 string lrn; 472 string lrn_grad; 473 string matmul; 474 string max_pool; 475 string max_pool_grad; 476 string maximum; 477 string mkl_conv2d; 478 string mkl_conv2d_grad_input; 479 string mkl_conv2d_grad_filter; 480 string mkl_conv2d_with_bias; 481 string mkl_conv2d_with_bias_backprop_bias; 482 string mul; 483 string relu; 484 string relu_grad; 485 string reshape; 486 string split; 487 string squared_difference; 488 string sub; 489 } ConstStringsInfo; 490 491 private: 492 /// Maintain info about nodes to rewrite 493 std::vector<RewriteInfo> rinfo_; 494 495 /// Maintain info about nodes to add workspace edge 496 std::vector<WorkSpaceInfo> wsinfo_; 497 498 /// Maintain info about nodes to be merged 499 std::vector<MergeInfo> minfo_; 500 501 /// Maintain info about nodes to rewrite 502 static std::vector<ContextInfo*> cinfo_; 503 504 /// Maintain structure of constant strings 505 static ConstStringsInfo csinfo_; 506 507 /// Context variables used in referencing rules 508 static ContextInfo biasaddgrad_matmul_context_; 509 static ContextInfo biasaddgrad_conv2dwithbias_context_; 510 511 private: 512 // Is OpDef::ArgDef a list type? It could be N * T or list(type). 513 // Refer to opdef.proto for details of list type. 514 inline bool ArgIsList(const OpDef::ArgDef& arg) const { 515 return !arg.type_list_attr().empty() || !arg.number_attr().empty(); 516 } 517 518 // Get length of a list in 'n' if 'arg' is of list type. Refer to 519 // description of ArgIsList for definition of list type. 520 inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) { 521 CHECK_EQ(ArgIsList(arg), true); 522 int N = 0; 523 const string attr_name = !arg.type_list_attr().empty() 524 ? arg.type_list_attr() 525 : arg.number_attr(); 526 if (!arg.type_list_attr().empty()) { 527 std::vector<DataType> value; 528 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value)); 529 N = value.size(); 530 } else { 531 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N)); 532 } 533 return N; 534 } 535 536 // Can op represented by node 'n' run on DEVICE_CPU? 537 // Op can run on CPU with MKL if the runtime assigned device or the 538 // user requested device contains device CPU, or both are empty. 
539 bool CanOpRunOnCPUDevice(const Node* n) { 540 bool result = true; 541 string reason; 542 543 // Substring that should be checked for in device name for CPU device. 544 const char* const kCPUDeviceSubStr = "CPU"; 545 546 // If Op has been specifically assigned to a non-CPU device, then No. 547 if (!n->assigned_device_name().empty() && 548 !StringPiece(n->assigned_device_name()).contains(kCPUDeviceSubStr)) { 549 result = false; 550 reason = "Op has been assigned a runtime device that is not CPU."; 551 } 552 553 // If user has specifically assigned this op to a non-CPU device, then No. 554 if (!n->def().device().empty() && 555 !StringPiece(n->def().device()).contains(kCPUDeviceSubStr)) { 556 result = false; 557 reason = "User has assigned a device that is not CPU."; 558 } 559 560 if (result == false) { 561 VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node " 562 << n->type_string() << ", reason: " << reason; 563 } 564 565 // Otherwise Yes. 566 return result; 567 } 568 569 // Return a node that can be merged with input node 'n' 570 // 571 // @return pointer to the node if we can find such a 572 // node. Otherwise, it returns nullptr. 573 Node* CheckForNodeMerge(const Node* n) const; 574 575 // Merge predecessor node with its successor. 576 // Currently, we merge Conv2D with BiasAdd only. 577 // 578 // Input nodes succ and pred may be deleted if the call to 579 // this function is successful. Attempt to use the pointers 580 // after the call to function may result in undefined behaviors. 581 // 582 // @input g - input graph, succ - successor node, pred - predecessor node 583 // @return Status::OK(), if merging is successful and supported. 584 // Returns appropriate Status error code otherwise. 585 // Graph is updated in case nodes are merged. Otherwise, it is 586 // not updated. 587 Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred); 588 589 // Check if the node 'n' has any applicable rewrite rule 590 // We check for 2 scenarios for rewrite. 591 // 592 // @return RewriteInfo* for the applicable rewrite rule 593 const RewriteInfo* CheckForNodeRewrite(const Node* n) const; 594 595 // Default rewrite rule to be used in scenario 1 for rewrite. 596 // @return - true (since we want to always rewrite) 597 static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) { 598 return true; 599 } 600 601 // Check if we are performing pooling on depth or batch. If it is, then we 602 // do not rewrite MaxPool node to Mkl version. 603 // @return - true (if it is not a depth/batch wise pooling case); 604 // false otherwise. 605 static bool NonDepthBatchWisePoolRewrite(const Node* n, 606 const ContextInfo* c) { 607 CHECK_NOTNULL(n); 608 609 string data_format_str; 610 TensorFormat data_format; 611 std::vector<int32> ksize, strides; 612 CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); 613 CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); 614 CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); 615 CHECK_EQ(FormatFromString(data_format_str, &data_format), true); 616 617 // Condition that specifies non-batch-wise and non-depth-wise pooling. 
618 if (GetTensorDim(ksize, data_format, 'N') == 1 && 619 GetTensorDim(strides, data_format, 'N') == 1 && 620 GetTensorDim(ksize, data_format, 'C') == 1 && 621 GetTensorDim(strides, data_format, 'C') == 1) { 622 return true; 623 } 624 625 return false; 626 } 627 628 static bool AddNRewrite(const Node* n, const ContextInfo* c) { 629 CHECK_NOTNULL(n); 630 631 int num; 632 CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true); 633 634 // Condition that specifies non-batch-wise and non-depth-wise pooling. 635 if (num == 2) { 636 return true; 637 } 638 639 return false; 640 } 641 // Is BiasAddGrad node in 'n' is associated with Conv2DWithBias node 642 // specified in contextinfo 'ci'. Function updates fwd_node to point 643 // to Conv2DWithBias node if 'n' is associated with Conv2DWithBias. 644 // 645 // Association checks for one of the following graphs: 646 // 647 // Graph A: 648 // 649 // _ = Conv2DWithBias(F, I, _) 650 // .. 651 // _ = Conv2DBackpropFilter(F, _, G) 652 // _ = Conv2DBackpropInput(_, I, G) 653 // _ = BiasAddGrad(G) 654 // 655 // OR 656 // 657 // Graph B: 658 // 659 // _ = Conv2DWithBias(F, _, _) 660 // .. 661 // _ = Conv2DBackpropFilter(F, _, G) 662 // _ = BiasAddGrad(G) 663 // 664 // Here F, G, and I are graph nodes; _ represents graph nodes that we 665 // don't care here. 666 // 667 // @return - true (if BiasAddGrad is associated with Conv2DWithBias); 668 // false otherwise. 669 static bool IsBiasAddGradInConv2DWithBiasContext(const Node* n, 670 const Node** fwd_node, 671 void* ci) { 672 CHECK_NOTNULL(n); 673 CHECK_NOTNULL(fwd_node); 674 CHECK_NOTNULL(ci); 675 *fwd_node = nullptr; 676 677 CHECK_EQ(n->type_string(), csinfo_.bias_add_grad); 678 679 // Get the only 1 input of BiasAddGrad. 680 CHECK_EQ(n->num_inputs(), 1); 681 const Node* bias_add_grad_inp = nullptr; 682 TF_CHECK_OK(n->input_node(0, &bias_add_grad_inp)); 683 CHECK_NOTNULL(bias_add_grad_inp); 684 685 // Check if this input also goes to BackpropFilter and BackpropInput 686 // as 3rd input. 687 bool found_backprop_input = false; 688 bool found_backprop_filter = false; 689 Node* backprop_filter_node = nullptr; 690 Node* backprop_input_node = nullptr; 691 692 for (const Edge* e : bias_add_grad_inp->out_edges()) { 693 Node* third_input = nullptr; 694 if (e->dst()->type_string() == csinfo_.conv2d_grad_input || 695 e->dst()->type_string() == csinfo_.mkl_conv2d_grad_input) { 696 // Third input (index 2) of BackpropInput 697 TF_CHECK_OK(e->dst()->input_node(2, &third_input)); 698 // Third input (index 2) of BackpropInput must be same as the input 699 // of BiasAddGrad. 700 if (third_input == bias_add_grad_inp) { 701 found_backprop_input = true; 702 backprop_input_node = e->dst(); 703 } 704 } 705 706 if (e->dst()->type_string() == csinfo_.conv2d_grad_filter || 707 e->dst()->type_string() == csinfo_.mkl_conv2d_grad_filter) { 708 // Third input (index 2) of BackpropFilter 709 TF_CHECK_OK(e->dst()->input_node(2, &third_input)); 710 // Third input (index 2) of BackpropFilter must be same as the input 711 // of BiasAddGrad. 712 if (third_input == bias_add_grad_inp) { 713 found_backprop_filter = true; 714 backprop_filter_node = e->dst(); 715 } 716 } 717 718 // If we found both the nodes, then we can stop the search. 719 if (found_backprop_input && found_backprop_filter) { 720 break; 721 } 722 } 723 724 // If BackpropFilter node is not found, then this is not 725 // Conv2DWithBias context. For 2nd graph in the example above, only 726 // BackpropFilter would be present. 
727 if (!found_backprop_filter) { 728 return false; 729 } 730 731 // Otherwise, we found the nodes. 732 CHECK_NOTNULL(backprop_filter_node); 733 if (found_backprop_input) { 734 CHECK_NOTNULL(backprop_input_node); 735 } 736 737 // Now that we confirmed that this is Conv2DWithBias context, we need to 738 // get access to the forward node (Conv2DWithBias). 2nd input of 739 // Conv2DWithBias is same as the 2nd input of Conv2DBackpropInput; 1st 740 // input of Conv2DWithBias is same as the 1st input of Conv2DBackpropFilter 741 // (This comes from definition of gradient computation for Conv2D). 742 if (found_backprop_input) { 743 // Graph A in the example. 744 Node* second_inp_of_input = nullptr; 745 Node* first_inp_of_filter = nullptr; 746 TF_CHECK_OK(backprop_input_node->input_node(1, &second_inp_of_input)); 747 TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter)); 748 CHECK_NOTNULL(second_inp_of_input); 749 CHECK_NOTNULL(first_inp_of_filter); 750 751 // Now we need to find out Conv2DWithBias node from these input nodes. 752 // Conv2DWithBias node is the node that accepts both the nodes 753 // second_inp_of_input and first_inp_of_filter in 2nd and 1st input slots. 754 for (const Edge* fe : first_inp_of_filter->out_edges()) { 755 if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && 756 fe->dst_input() == 0) { 757 for (const Edge* ie : second_inp_of_input->out_edges()) { 758 if (ie->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && 759 ie->dst_input() == 1 && fe->dst() == ie->dst()) { 760 VLOG(1) << "MklLayoutRewritePass: found " 761 << fe->dst()->DebugString() 762 << " as the forward node for matching context, backward" 763 << " node is: " << n->DebugString(); 764 *fwd_node = fe->dst(); 765 return true; 766 } 767 } 768 } 769 } 770 } else { 771 // We did not find BackpropInput, so we work with BackpropFilter only. 772 // Graph B in the example. 773 Node* first_inp_of_filter = nullptr; 774 TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter)); 775 CHECK_NOTNULL(first_inp_of_filter); 776 777 // Now we need to find out Conv2DWithBias node from first input of 778 // BackpropFIlter. Conv2DWithBias node is the node that accepts 779 // first_inp_of_filter in 1st input slot. 780 for (const Edge* fe : first_inp_of_filter->out_edges()) { 781 if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && 782 fe->dst_input() == 0) { 783 VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString() 784 << " as the forward node for matching context, backward" 785 << " node is: " << n->DebugString(); 786 *fwd_node = fe->dst(); 787 return true; 788 } 789 } 790 } 791 792 return false; 793 } 794 795 // Is BiasAddGrad node in 'n' is associated with MatMul node 796 // specified in contextinfo 'ci'. Function does not update fwd_node. 797 // 798 // @return - true (if BiasAddGrad is associated with MatMul); 799 // false otherwise. 800 static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node, 801 void* ci) { 802 return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci)); 803 } 804 805 // Rewrite rule that uses context-information for matching, 806 // used in scenario 2. 807 // 808 // @input - Node 'n' for which to search for matching context 809 // @input - The context 'c' under which to rewrite 810 // @return - true if we can rewrite node under context 'c'; 811 // false otherwise. 812 static bool ContextMatchRewrite(const Node* n, const ContextInfo* c); 813 814 // Helper function that searches the matching contextinfo for the node. 
815 // 816 // @input n - Node (gradient op) whose contextinfo is to be searched, 817 // fwd_node - pointer to node from the forward pass that this node 818 // belongs to. fwd_node cannot be NULL. 819 // @return Matching contextinfo in case a match is found; null otherwise. 820 // Also updates *fwd_node with pointer to forward node that this 821 // context matches. 822 static const ContextInfo* SearchMatchingContext(const Node* n, 823 const Node** fwd_node); 824 825 // Rewrites input node to a new node specified by its matching rewrite info. 826 // 827 // Method first searches matching rewrite info for input node and then 828 // uses that info to rewrite. 829 // 830 // Input node may be deleted in case of rewrite. Attempt to use the node 831 // after the call can result in undefined behaviors. 832 // 833 // @input g - input graph, n - Node to be rewritten, 834 // ri - matching rewriteinfo 835 // @return Status::OK(), if the input node is rewritten; 836 // Returns appropriate Status error code otherwise. 837 // Graph is updated in case the input node is rewritten. 838 // Otherwise, it is not updated. 839 Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri); 840 841 // Get nodes that will feed a list of TF tensors to the new 842 // node that we are constructing. 843 // 844 // @input g - input graph, 845 // @input inputs - inputs to old node that we are using for constructing 846 // new inputs, 847 // @input input_idx - the index in the 'inputs' vector pointing to the 848 // current input that we have processed so far 849 // @output input_idx - index will be incremented by the number of nodes 850 // from 'inputs' that are processed 851 // @input list_length - The expected length of list of TF tensors 852 // @output output_nodes - the list of new nodes creating TF tensors 853 // 854 // @return None 855 void GetNodesProducingTFTensorList( 856 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, 857 int* input_idx, int list_length, 858 std::vector<NodeBuilder::NodeOut>* output_nodes); 859 860 // Get nodes that will feed a list of Mkl tensors to the new 861 // node that we are constructing. 862 // 863 // @input g - input graph, 864 // @input orig_node - Original node that we are rewriting 865 // @input inputs - inputs to old node that we are using for constructing 866 // new inputs, 867 // @input input_idx - the index in the 'inputs' vector pointing to the 868 // current input that we have processed so far 869 // @output input_idx - index will be incremented by the number of nodes 870 // from 'inputs' that are processed 871 // @input list_length - The expected length of list of Mkl tensors 872 // @output output_nodes - the list of new nodes creating Mkl tensors 873 // 874 // @return None 875 void GetNodesProducingMklTensorList( 876 std::unique_ptr<Graph>* g, Node* orig_node, 877 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, 878 int* input_idx, int list_length, 879 std::vector<NodeBuilder::NodeOut>* output_nodes); 880 881 // Get a node that will feed an Mkl tensor to the new 882 // node that we are constructing. The output node could be (1) 'n' 883 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor 884 // if 'n' is not an Mkl layer. 
885 // 886 // @input g - input graph, 887 // @input orig_node - Original node that we are rewriting, 888 // @input n - Node based on which we are creating Mkl node, 889 // @input n_output_slot - the output slot of node 'n' 890 // which is feeding to the node that we are constructing 891 // @output mkl_node - the new node that will feed Mkl tensor 892 // @output mkl_node_output_slot - the slot number of mkl_node that 893 // will feed the tensor 894 // @return None 895 void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, 896 Node* n, int n_output_slot, Node** mkl_node, 897 int* mkl_node_output_slot); 898 899 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' 900 // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are 901 // set up in contiguous fashion. 'workspace_tensors' carry graph nodes 902 // producing workspace edges if 'are_workspace_tensors_available' is true. 903 // Otherwise, 'workspace_tensors' is empty vector. 904 // 905 // For details, refer to 'Ordering of inputs after rewriting' section in the 906 // documentation above. 907 // 908 // Returns Status::OK() if setting up inputs is successful, otherwise 909 // returns appropriate status code. 910 int SetUpContiguousInputs( 911 std::unique_ptr<Graph>* g, 912 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, 913 NodeBuilder* nb, Node* old_node, 914 std::vector<NodeBuilder::NodeOut>* workspace_tensors, 915 bool are_workspace_tensors_available); 916 917 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' 918 // in graph 'g'. Original node is input in 'orig_node'. 919 // 920 // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors' 921 // section in the documentation above. 922 // 923 // Returns Status::OK() if setting up inputs is successful, otherwise 924 // returns appropriate status code. 925 Status SetUpInputs(std::unique_ptr<Graph>* g, 926 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, 927 NodeBuilder* nb, Node* orig_node); 928 929 // Add workspace edge on the input or output side of Node 'orig_node' by using 930 // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate 931 // adding workspace edge then do not add it. Workspace Tensorflow and Mkl 932 // tensors, if they need to be added, will be set into these tensors. 933 // If we set workspace tensors, then are_ws_tensors_added should be true. 934 void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node, 935 NodeBuilder* nb, 936 std::vector<NodeBuilder::NodeOut>* ws_tensors, 937 bool* are_ws_tensors_added); 938 939 // Functions specific to operators to copy attributes 940 // We need operator-specific function to copy attributes because the framework 941 // does not provide any generic function for it. 942 // NOTE: names are alphabetically sorted. 
943 static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb); 944 static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb); 945 static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb); 946 static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); 947 static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); 948 static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb); 949 static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); 950 static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb); 951 static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); 952 static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); 953 static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); 954 static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); 955 956 // Generate a graph node in graph 'g' representing a dummy Mkl tensor node, 957 // using node for original node 'orig_node' and return it in '*out'. 958 // TODO(nhasabni) We should move this to mkl_util.h 959 void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out, 960 Node* orig_node); 961 void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out, 962 Node* orig_node); 963 }; 964 965 MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; 966 MklLayoutRewritePass::ContextInfo 967 MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; 968 MklLayoutRewritePass::ContextInfo 969 MklLayoutRewritePass::biasaddgrad_matmul_context_; 970 std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_; 971 972 // We register Mkl rewrite pass for phase 1 in post partitioning group. 973 // We register it here so that we get a complete picture of all users of Mkl 974 // nodes. Do not change the ordering of the Mkl passes. 975 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = 976 OptimizationPassRegistry::POST_PARTITIONING; 977 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); 978 979 ////////////////////////////////////////////////////////////////////////// 980 // Helper functions for creating new node 981 ////////////////////////////////////////////////////////////////////////// 982 983 static void FillInputs(const Node* n, 984 gtl::InlinedVector<Node*, 4>* control_edges, 985 gtl::InlinedVector<std::pair<Node*, int>, 4>* in) { 986 control_edges->clear(); 987 for (const Edge* e : n->in_edges()) { 988 if (e->IsControlEdge()) { 989 control_edges->push_back(e->src()); 990 } else { 991 (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output()); 992 } 993 } 994 std::sort(control_edges->begin(), control_edges->end()); 995 if (n->op_def().is_commutative()) { 996 // For commutative inputs, we sort the input by the input Node* 997 // to get a canonical ordering (so that add(a,b) and add(b, a) will 998 // hash to the same value if is_commutative is true for 'add'). 
999 std::sort(in->begin(), in->end()); 1000 } 1001 } 1002 1003 void MklLayoutRewritePass::GetNodesProducingTFTensorList( 1004 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, 1005 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { 1006 CHECK_LT(*input_idx, inputs.size()); 1007 CHECK_GT(list_length, 0); 1008 CHECK_NOTNULL(output_nodes); 1009 output_nodes->reserve(list_length); 1010 1011 while (list_length != 0) { 1012 CHECK_GT(list_length, 0); 1013 CHECK_LT(*input_idx, inputs.size()); 1014 Node* n = inputs[*input_idx].first; 1015 int slot = inputs[*input_idx].second; 1016 // If input node 'n' is just producing a single tensor at 1017 // output slot 'slot' then we just add that single node. 1018 output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); 1019 (*input_idx)++; 1020 list_length--; 1021 } 1022 } 1023 1024 // TODO(nhasabni) We should move this to mkl_util.h. 1025 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, 1026 Node** out, Node* orig_node) { 1027 // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent 1028 // dummy Mkl tensor. 8 = 2*size_t. 1029 const DataType dt = DataTypeToEnum<uint8>::v(); 1030 TensorProto proto; 1031 proto.set_dtype(dt); 1032 uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; 1033 proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)), 1034 8); 1035 TensorShape dummy_shape({8}); 1036 dummy_shape.AsProto(proto.mutable_tensor_shape()); 1037 TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") 1038 .Attr("value", proto) 1039 .Attr("dtype", dt) 1040 .Device(orig_node->def().device()) // We place this node on 1041 // the same device as the 1042 // device of the original 1043 // node. 1044 .Finalize(&**g, out)); 1045 1046 // If number of inputs to the original node is > 0, then we add 1047 // control dependency between 1st input (index 0) of the original node and 1048 // the dummy Mkl node. This is needed because control-flow ops such as Enter, 1049 // Merge, etc, require frame_name of the dummy Mkl node to be same as the 1050 // rewritten node. Adding control edge between 1st input of the original node 1051 // and the dummy Mkl node ensures that the dummy node is in the same frame 1052 // as the original node. Choosing 1st input is not necessary - any input of 1053 // the original node is fine because all the inputs of a node are always in 1054 // the same frame. 1055 if (orig_node->num_inputs() > 0) { 1056 Node* orig_input0 = nullptr; 1057 TF_CHECK_OK( 1058 orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); 1059 CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); 1060 } 1061 1062 (*out)->set_assigned_device_name(orig_node->assigned_device_name()); 1063 } 1064 1065 void MklLayoutRewritePass::GetNodesProducingMklTensorList( 1066 std::unique_ptr<Graph>* g, Node* orig_node, 1067 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, 1068 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { 1069 CHECK_LT(*input_idx, inputs.size()); 1070 CHECK_GT(list_length, 0); 1071 CHECK_NOTNULL(output_nodes); 1072 output_nodes->reserve(list_length); 1073 1074 while (list_length != 0) { 1075 CHECK_GT(list_length, 0); 1076 CHECK_LT(*input_idx, inputs.size()); 1077 Node* n = inputs[*input_idx].first; 1078 int slot = inputs[*input_idx].second; 1079 // If 'n' is producing a single tensor, then create a single Mkl tensor 1080 // node. 
1081 Node* mkl_node = nullptr; 1082 int mkl_node_output_slot = 0; 1083 GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, 1084 &mkl_node_output_slot); 1085 output_nodes->push_back( 1086 NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); 1087 (*input_idx)++; 1088 list_length--; 1089 } 1090 } 1091 1092 // Get an input node that will feed Mkl tensor to the new 1093 // node that we are constructing. An input node could be (1) 'n' 1094 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor 1095 // if 'n' is not an Mkl layer. 1096 void MklLayoutRewritePass::GetNodeProducingMklTensor( 1097 std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot, 1098 Node** mkl_node, int* mkl_node_output_slot) { 1099 CHECK_NOTNULL(n); 1100 CHECK_NOTNULL(mkl_node); 1101 CHECK_NOTNULL(mkl_node_output_slot); 1102 1103 // If this is an MKL op, then it will create extra output for MKL layout. 1104 DataType T; 1105 if (GetNodeAttr(n->def(), "T", &T).ok() && 1106 mkl_op_registry::IsMklOp(n->type_string(), T)) { 1107 // If this is an MKL op, then it will generate an edge that will receive 1108 // Mkl tensor from a node. 1109 // output slot number for Mkl tensor would be N+slot number of TensorFlow 1110 // tensor, where N is total number of TensorFlow tensors. 1111 *mkl_node = n; 1112 *mkl_node_output_slot = 1113 GetTensorMetaDataIndex(n_output_slot, n->num_outputs()); 1114 } else { 1115 // If we have not visited the node and rewritten it, then we need 1116 // to create a dummy node that will feed a dummy Mkl tensor to this node. 1117 // DummyMklTensor node has no input and generates only 1 output 1118 // (dummy Mkl tensor) as output slot number 0. 1119 GetDummyMklTensorNode(g, mkl_node, orig_node); 1120 CHECK_NOTNULL(*mkl_node); 1121 *mkl_node_output_slot = 0; 1122 } 1123 } 1124 1125 int MklLayoutRewritePass::SetUpContiguousInputs( 1126 std::unique_ptr<Graph>* g, 1127 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, 1128 NodeBuilder* nb, Node* old_node, 1129 std::vector<NodeBuilder::NodeOut>* workspace_tensors, 1130 bool are_workspace_tensors_available) { 1131 CHECK_NOTNULL(workspace_tensors); 1132 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); 1133 1134 // TODO(nhasabni): Temporary solution to connect filter input of 1135 // BackpropInput with the converted filter from Conv2D. 1136 bool do_connect_conv2d_backprop_input_filter = false; 1137 Node* conv2d_node = nullptr; 1138 // Filter node is 2nd input (slot index 1) of Conv2D. 1139 int kConv2DFilterInputSlotIdx = 1; 1140 int kConv2DBackpropInputFilterInputSlotIdx = 1; 1141 int kConv2DFilterOutputSlotIdx = 1; 1142 if (old_node->type_string() == csinfo_.conv2d_grad_input) { 1143 // We need to find Conv2D node from Conv2DBackpropInput. 1144 // For that let's first find filter node that is 2nd input (slot 1) 1145 // of BackpropInput. 1146 Node* filter_node = nullptr; 1147 old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, &filter_node); 1148 CHECK_NOTNULL(filter_node); 1149 1150 // Now check which nodes receive from filter_node. Filter feeds as 1151 // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias. 1152 for (const Edge* e : filter_node->out_edges()) { 1153 if (e->dst()->type_string() == csinfo_.mkl_conv2d && 1154 e->dst_input() == kConv2DFilterInputSlotIdx 1155 /* filter is 2nd input of Conv2D and _MklConv2D. 
*/) { 1156 if (conv2d_node != nullptr) { 1157 VLOG(1) << "MklLayoutRewritePass: unusual case of same filter" 1158 << " feeding multiple Conv2D nodes: " 1159 << filter_node->DebugString(); 1160 // We will not connect filter input of Conv2DBackpropInput 1161 // to be safe here. 1162 do_connect_conv2d_backprop_input_filter = false; 1163 break; 1164 } else { 1165 conv2d_node = e->dst(); 1166 do_connect_conv2d_backprop_input_filter = true; 1167 } 1168 } 1169 } 1170 } 1171 1172 // Number of input slots to original op 1173 // Input slots are represented by .Input() calls in REGISTER_OP. 1174 int old_node_input_slots = old_node->op_def().input_arg_size(); 1175 // Actual number of inputs can be greater than or equal to number 1176 // of Input slots because inputs of type list could be unfolded. 1177 CHECK_GE(old_node_inputs.size(), old_node_input_slots); 1178 int nn_slot_idx = 0; // slot index for inputs of new node 1179 1180 // Let's copy all inputs (TF tensors) of original node to new node. 1181 int iidx = 0; 1182 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { 1183 // An input slot could be a single tensor or a list. We need 1184 // to handle this case accordingly. 1185 CHECK_LT(iidx, old_node_inputs.size()); 1186 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); 1187 if (ArgIsList(arg)) { 1188 std::vector<NodeBuilder::NodeOut> new_node_inputs; 1189 int N = GetTensorListLength(arg, old_node); 1190 GetNodesProducingTFTensorList(old_node_inputs, &iidx, N, 1191 &new_node_inputs); 1192 nb->Input(new_node_inputs); 1193 nn_slot_idx++; 1194 } else { 1195 // Special case for connecting filter input of Conv2DBackpropInput 1196 if (do_connect_conv2d_backprop_input_filter && 1197 iidx == kConv2DBackpropInputFilterInputSlotIdx) { 1198 nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx); 1199 } else { 1200 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); 1201 } 1202 iidx++; 1203 nn_slot_idx++; 1204 } 1205 } 1206 1207 // If workspace tensors are available for this op and we are using 1208 // contiguous ordering then we need to add Tensorflow tensor for 1209 // workspace here because Tensorflow tensor for workspace is the 1210 // last tensor in the list of Tensorflow tensors. 1211 if (are_workspace_tensors_available) { 1212 CHECK_EQ(workspace_tensors->size(), 2); 1213 // Tensorflow tensor 1214 nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index); 1215 nn_slot_idx++; 1216 } 1217 1218 // Let's now setup all Mkl inputs to new node. 1219 // Number of Mkl inputs must be same as number of TF inputs. 1220 iidx = 0; 1221 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { 1222 // An input slot could be a single tensor or a list. We need 1223 // to handle this case accordingly. 
1224 CHECK_LT(iidx, old_node_inputs.size()); 1225 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); 1226 if (ArgIsList(arg)) { 1227 std::vector<NodeBuilder::NodeOut> new_node_inputs; 1228 int N = GetTensorListLength(arg, old_node); 1229 GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, 1230 &new_node_inputs); 1231 nb->Input(new_node_inputs); 1232 nn_slot_idx++; 1233 } else { 1234 Node* mkl_node = nullptr; 1235 int mkl_node_output_slot = 0; 1236 // Special case for connecting filter input of Conv2DBackpropInput 1237 if (do_connect_conv2d_backprop_input_filter && 1238 iidx == kConv2DBackpropInputFilterInputSlotIdx) { 1239 GetNodeProducingMklTensor(g, old_node, conv2d_node, 1240 kConv2DFilterOutputSlotIdx, &mkl_node, 1241 &mkl_node_output_slot); 1242 } else { 1243 GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, 1244 old_node_inputs[iidx].second, &mkl_node, 1245 &mkl_node_output_slot); 1246 } 1247 nb->Input(mkl_node, mkl_node_output_slot); 1248 iidx++; 1249 nn_slot_idx++; 1250 } 1251 } 1252 1253 // If workspace tensors are available for this op and we are using 1254 // contiguous ordering then we need to add Mkl tensor for 1255 // workspace here because Mkl tensor for workspace is the 1256 // last tensor in the list of Mkl tensors. 1257 if (are_workspace_tensors_available) { 1258 CHECK_EQ(workspace_tensors->size(), 2); 1259 // Mkl tensor 1260 nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index); 1261 nn_slot_idx++; 1262 } 1263 1264 return nn_slot_idx; 1265 } 1266 1267 Status MklLayoutRewritePass::SetUpInputs( 1268 std::unique_ptr<Graph>* g, 1269 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, 1270 NodeBuilder* nb, Node* old_node) { 1271 // Let's check if we need to add workspace tensors for this node. 1272 // We add workspace edge only for MaxPool, LRN and BatchNorm. 1273 std::vector<NodeBuilder::NodeOut> workspace_tensors; 1274 bool are_workspace_tensors_available = false; 1275 AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors, 1276 &are_workspace_tensors_available); 1277 1278 int new_node_input_slots = 0; 1279 if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { 1280 // TODO(nhasabni): implement this function just for same of completion. 1281 // We do not use interleaved ordering right now. 1282 return Status( 1283 error::Code::UNIMPLEMENTED, 1284 "Interleaved ordering of tensors is currently not supported."); 1285 } else { 1286 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); 1287 new_node_input_slots = SetUpContiguousInputs( 1288 g, old_node_inputs, nb, old_node, &workspace_tensors, 1289 are_workspace_tensors_available); 1290 } 1291 1292 // Sanity check 1293 int old_node_input_slots = old_node->op_def().input_arg_size(); 1294 if (!are_workspace_tensors_available) { 1295 // If we are not adding workspace tensors for this op, then the total 1296 // number of input slots to the new node _must_ be 2 times the number 1297 // of input slots to the original node: N original Tensorflow tensors and 1298 // N for Mkl tensors corresponding to each Tensorflow tensors. 
1299 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2); 1300 } else { 1301 // If we are adding workspace tensors for this op, then the total 1302 // The total number of input slots to new node _must_ be 2 times the number 1303 // of input slots to the original node: N original Tensorflow tensors and 1304 // N for Mkl tensors corresponding to each Tensorflow tensors plus 2 1305 // (for workspace Tensorflow tensor and workspace Mkl tensor). 1306 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2); 1307 } 1308 1309 return Status::OK(); 1310 } 1311 1312 ////////////////////////////////////////////////////////////////////////// 1313 // Helper functions related to workspace pass 1314 ////////////////////////////////////////////////////////////////////////// 1315 1316 // TODO(nhasabni) We should move this to mkl_util.h. 1317 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( 1318 std::unique_ptr<Graph>* g, Node** out, Node* orig_node) { 1319 // We use a tensor of shape {1} and value 0 to represent 1320 // dummy float tensor. We need this as a dummy workspace tensor. 1321 // Workspace tensor has type float. 1322 const DataType dt = DataTypeToEnum<float>::v(); 1323 TensorProto proto; 1324 proto.set_dtype(dt); 1325 float zero[1] = {0}; 1326 proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)), 1327 4); 1328 TensorShape dummy_shape({1}); 1329 dummy_shape.AsProto(proto.mutable_tensor_shape()); 1330 TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") 1331 .Attr("value", proto) 1332 .Attr("dtype", dt) 1333 .Device(orig_node->def().device()) // We place this node on 1334 // same the device as the 1335 // device of the original 1336 // node. 1337 .Finalize(&**g, out)); 1338 1339 // If number of inputs to the original node is > 0, then we add 1340 // control dependency between 1st input (index 0) of the original node and 1341 // the dummy Mkl node. This is needed because control-flow ops such as Enter, 1342 // Merge, etc, require frame_name of the dummy Mkl node to be same as the 1343 // rewritten node. Adding control edge between 1st input of the original node 1344 // and the dummy Mkl node ensures that the dummy node is in the same frame 1345 // as the original node. Choosing 1st input is not necessary - any input of 1346 // the original node is fine because all the inputs of a node are always in 1347 // the same frame. 1348 if (orig_node->num_inputs() > 0) { 1349 Node* orig_input0 = nullptr; 1350 TF_CHECK_OK( 1351 orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); 1352 CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); 1353 } 1354 1355 (*out)->set_assigned_device_name(orig_node->assigned_device_name()); 1356 } 1357 1358 void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( 1359 std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb, 1360 std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) { 1361 bool workspace_edge_added = false; // Default initializer 1362 CHECK_NOTNULL(are_ws_tensors_added); 1363 *are_ws_tensors_added = false; // Default initializer 1364 1365 DataType T; 1366 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1367 for (auto ws : wsinfo_) { 1368 if (orig_node->type_string() == ws.fwd_op && 1369 mkl_op_registry::IsMklOp( 1370 mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { 1371 // If this op is a fwd op, then we need to check if there is an 1372 // edge from this node's fwd_slot to bwdop's bwd_slot. 
1373 // an edge, then we just add an attribute on this node for setting
1374 // workspace_enabled to true. We don't add the actual workspace edge
1375 // in this node. The actual workspace edge gets added in the backward
1376 // op for this node.
1377 for (const Edge* e : orig_node->out_edges()) {
1378 if (e->src_output() == ws.fwd_slot &&
1379 e->dst()->type_string() == ws.bwd_op &&
1380 e->dst_input() == ws.bwd_slot) {
1381 nb->Attr("workspace_enabled", true);
1382 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
1383 << orig_node->type_string();
1384 workspace_edge_added = true;
1385 // We found the edge that we were looking for, so break.
1386 break;
1387 }
1388 }
1389
1390 if (!workspace_edge_added) {
1391 // If we are here, then we did not find a backward operator for this
1392 // node.
1393 nb->Attr("workspace_enabled", false);
1394 }
1395 } else if (orig_node->type_string() == ws.bwd_op &&
1396 mkl_op_registry::IsMklOp(
1397 mkl_op_registry::GetMklOpName(orig_node->type_string()),
1398 T)) {
1399 // If this op is a bwd op, then we need to add a workspace edge and
1400 // its Mkl tensor edge between its corresponding fwd op and this
1401 // op. The corresponding fwd op is specified in the 'fwd_op' field of
1402 // the workspace info. fwd_slot and bwd_slot in the workspace info
1403 // specify which slots connect the forward and backward op.
1404 // Once all these criteria match, we add a workspace edge between
1405 // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
1406 // determined by the interleaved/contiguous ordering. The function
1407 // DataIndexToMetaDataIndex tells us the location of the Mkl tensor
1408 // given the location of the Tensorflow tensor.
1409 for (const Edge* e : orig_node->in_edges()) {
1410 if (e->src_output() == ws.fwd_slot &&
1411 // We would have rewritten the forward op, so we need to use
1412 // GetMklOpName call to get its Mkl name.
1413 e->src()->type_string() ==
1414 mkl_op_registry::GetMklOpName(ws.fwd_op) &&
1415 e->dst_input() == ws.bwd_slot) {
1416 nb->Attr("workspace_enabled", true);
1417 CHECK_NOTNULL(ws_tensors);
1418 // Add workspace edge between fwd op and bwd op.
1419 ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
1420 // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
1421 ws_tensors->push_back(NodeBuilder::NodeOut(
1422 e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
1423 e->src()->num_outputs())));
1424 *are_ws_tensors_added = true;
1425 // In terms of input ordering, we add these inputs here because the
1426 // workspace edge (and its Mkl tensor) is the last edge in both the
1427 // fwd op and the bwd op. All inputs before the workspace tensor
1428 // have already been added by the SetUpInputs function.
1429 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
1430 << orig_node->type_string();
1431 workspace_edge_added = true;
1432 // We found the edge that we were looking for, so break.
1433 break;
1434 }
1435 }
1436
1437 // If we are here, it means we did not find a fwd op that feeds this
1438 // bwd op. In this case, we need to generate dummy tensors for the
1439 // workspace input and the Mkl tensor for the workspace, and set
1440 // workspace_enabled to false.
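// Illustrative sketch of the bwd-op case above. Assuming a wsinfo_ entry of
// the form {MaxPool, MaxPoolGrad, fwd_slot=0, bwd_slot=1, ws_fwd_slot=1,
// ws_bwd_slot=3}, the transformed graph looks like:
//
//   A, W, A_m, W_m = _MklMaxPool(T, T_m)
//   B, B_m         = _MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
//
// The data edge _MklMaxPool:0 -> _MklMaxPoolGrad:1 identifies the matching
// forward op. The workspace tensor W is taken from output slot ws_fwd_slot
// (1) of the forward op, and its Mkl tensor from slot
// DataIndexToMetaDataIndex(1, num_outputs), i.e. the metadata output that
// follows all the Tensorflow outputs under contiguous ordering.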
1441 if (!workspace_edge_added) { 1442 nb->Attr("workspace_enabled", false); 1443 Node* dmt_ws = nullptr; // Dummy tensor for workspace 1444 Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace 1445 GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node); 1446 GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node); 1447 CHECK_NOTNULL(dmt_ws); 1448 CHECK_NOTNULL(dmt_mkl_ws); 1449 CHECK_NOTNULL(ws_tensors); 1450 // We add dummy tensor as workspace tensor. 1451 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0)); 1452 // We add dummy tensor as Mkl tensor for workspace tensor. 1453 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0)); 1454 *are_ws_tensors_added = true; 1455 VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for " 1456 << orig_node->type_string(); 1457 } 1458 } else { 1459 // If this node does not match any workspace info, then we do not 1460 // do anything special for workspace propagation for it. 1461 } 1462 } 1463 } 1464 1465 ////////////////////////////////////////////////////////////////////////// 1466 // Op-specific functions to copy attributes from old node to new node 1467 ////////////////////////////////////////////////////////////////////////// 1468 1469 void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, 1470 NodeBuilder* nb) { 1471 DataType T; 1472 string data_format; 1473 string padding; 1474 std::vector<int32> strides; 1475 bool use_cudnn_on_gpu; 1476 1477 // Get all attributes from old node. 1478 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1479 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 1480 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 1481 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 1482 TF_CHECK_OK( 1483 GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); 1484 1485 // Add attributes to new node. 1486 nb->Attr("T", T); 1487 nb->Attr("strides", strides); 1488 nb->Attr("padding", padding); 1489 nb->Attr("data_format", data_format); 1490 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); 1491 } 1492 1493 void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, 1494 NodeBuilder* nb) { 1495 DataType T; 1496 int N; 1497 1498 // Get all attributes from old node. 1499 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1500 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); 1501 1502 // Add attributes to new node. 1503 nb->Attr("T", T); 1504 nb->Attr("N", N); 1505 } 1506 1507 void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node, 1508 NodeBuilder* nb) { 1509 DataType T; 1510 string data_format; 1511 std::vector<int32> strides; 1512 1513 // Get all attributes from old node. 1514 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1515 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 1516 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 1517 1518 // Add attributes to new node. 1519 nb->Attr("T", T); 1520 nb->Attr("strides", strides); 1521 nb->Attr("data_format", data_format); 1522 } 1523 1524 void MklLayoutRewritePass::CopyAttrsIdentity(const Node* orig_node, 1525 NodeBuilder* nb) { 1526 DataType T; 1527 1528 // Get all attributes from old node. 1529 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1530 // Add attributes to new node. 
1531 nb->Attr("T", T); 1532 } 1533 1534 void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, 1535 NodeBuilder* nb) { 1536 DataType T; 1537 int depth_radius; 1538 float bias; 1539 float alpha; 1540 float beta; 1541 1542 // Get all attributes from old node. 1543 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1544 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius)); 1545 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias)); 1546 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha)); 1547 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta)); 1548 1549 // Add attributes to new node. 1550 nb->Attr("T", T); 1551 nb->Attr("depth_radius", depth_radius); 1552 nb->Attr("bias", bias); 1553 nb->Attr("alpha", alpha); 1554 nb->Attr("beta", beta); 1555 } 1556 1557 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node, 1558 NodeBuilder* nb) { 1559 DataType T; 1560 string data_format; 1561 string padding; 1562 std::vector<int32> ksize, strides; 1563 1564 // Get all attributes from old node. 1565 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1566 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize)); 1567 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 1568 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 1569 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 1570 1571 // Add attributes to new node. 1572 nb->Attr("T", T); 1573 nb->Attr("ksize", ksize); 1574 nb->Attr("strides", strides); 1575 nb->Attr("padding", padding); 1576 nb->Attr("data_format", data_format); 1577 } 1578 1579 void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, 1580 NodeBuilder* nb) { 1581 DataType T; 1582 1583 // Get all attributes from old node. 1584 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1585 1586 // Add attributes to new node. 1587 nb->Attr("T", T); 1588 } 1589 1590 void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, 1591 NodeBuilder* nb) { 1592 DataType T; 1593 DataType Tshape; 1594 1595 // Get all attributes from old node. 1596 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1597 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); 1598 // Add attributes to new node. 1599 nb->Attr("T", T); 1600 nb->Attr("Tshape", Tshape); 1601 } 1602 1603 void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, 1604 NodeBuilder* nb) { 1605 DataType T; 1606 string data_format; 1607 int num_split; 1608 1609 // Get all attributes from old node. 1610 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1611 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split)); 1612 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 1613 1614 // Add attributes to new node. 1615 nb->Attr("T", T); 1616 nb->Attr("num_split", num_split); 1617 nb->Attr("data_format", data_format); 1618 } 1619 1620 void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node, 1621 NodeBuilder* nb) { 1622 DataType T; 1623 int N; 1624 1625 // Get all attributes from old node. 1626 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1627 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); 1628 1629 // Add attributes to new node. 1630 nb->Attr("T", T); 1631 nb->Attr("N", N); 1632 } 1633 1634 void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node, 1635 NodeBuilder* nb) { 1636 DataType T; 1637 int N; 1638 DataType tidx; 1639 1640 // Get all attributes from old node. 
1641 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1642 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); 1643 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx)); 1644 1645 // Add attributes to new node. 1646 nb->Attr("T", T); 1647 nb->Attr("N", N); 1648 nb->Attr("Tidx", tidx); 1649 } 1650 1651 void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node, 1652 NodeBuilder* nb) { 1653 DataType T; 1654 float epsilon; 1655 string data_format; 1656 bool is_training; 1657 1658 // Get all attributes from old node. 1659 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 1660 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon)); 1661 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 1662 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training)); 1663 1664 // Add attributes to new node. 1665 nb->Attr("T", T); 1666 nb->Attr("epsilon", epsilon); 1667 nb->Attr("data_format", data_format); 1668 nb->Attr("is_training", is_training); 1669 } 1670 1671 ////////////////////////////////////////////////////////////////////////// 1672 // Helper functions related to node merge pass 1673 ////////////////////////////////////////////////////////////////////////// 1674 1675 Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { 1676 // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite 1677 // once we support BiasAddGrad as Mkl layer. 1678 1679 // Search for all matching mergeinfo. 1680 // We allow more than one match for extensibility. 1681 std::vector<const MergeInfo*> matching_mi; 1682 for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) { 1683 if (a->type_string() == mi->succ) { 1684 matching_mi.push_back(&*mi); 1685 } 1686 } 1687 1688 for (const MergeInfo* mi : matching_mi) { 1689 const int N_in = a->num_inputs(); 1690 if (mi->op >= N_in) { 1691 continue; 1692 } 1693 1694 // Get the control edges and input of node 1695 gtl::InlinedVector<Node*, 4> a_control_edges; 1696 gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in); 1697 FillInputs(a, &a_control_edges, &a_in); 1698 1699 // Get operand op of the operator 1700 Node* b = nullptr; 1701 b = a_in[mi->op].first; 1702 if (b == nullptr || (b->type_string() != mi->pred)) { 1703 // NOTE: Should the first check be assert? 1704 continue; 1705 } 1706 1707 const int B_in = b->num_inputs(); 1708 gtl::InlinedVector<Node*, 4> b_control_edges; 1709 gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in); 1710 FillInputs(b, &b_control_edges, &b_in); 1711 1712 // Shouldn't merge if a and b have different control edges. 1713 if (a_control_edges != b_control_edges) { 1714 continue; 1715 } else { 1716 // We found a match. 1717 return b; 1718 } 1719 } 1720 1721 return nullptr; 1722 } 1723 1724 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, 1725 Node* pred) { 1726 CHECK_NOTNULL(succ); 1727 CHECK_NOTNULL(pred); 1728 1729 if (succ->type_string() == csinfo_.bias_add && 1730 pred->type_string() == csinfo_.mkl_conv2d) { 1731 // 1. Get all attributes from input nodes. 
1732 DataType T_pred, T_succ;
1733 string padding;
1734 std::vector<int32> strides;
1735 string data_format_pred, data_format_succ;
1736 bool use_cudnn_on_gpu;
1737 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
1738 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
1739 TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
1740 TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
1741 TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
1742 TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
1743 TF_CHECK_OK(
1744 GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
1745 // We check that the data formats of both succ and pred are the same.
1746 // We expect them to match, so this could be enforced as an assert.
1747 // But an assert can be too strict, so we enforce it as a check.
1748 // If the check fails, then we do not merge the two nodes.
1749 // We also do the same check for devices.
1750 if (data_format_pred != data_format_succ || T_pred != T_succ ||
1751 pred->assigned_device_name() != succ->assigned_device_name() ||
1752 pred->def().device() != succ->def().device()) {
1753 return Status(error::Code::INVALID_ARGUMENT,
1754 "data_format or T attribute or devices of Conv2D and "
1755 "BiasAdd do not match. Will skip node merge optimization");
1756 }
1757
1758 const int succ_num = succ->num_inputs();
1759 gtl::InlinedVector<Node*, 4> succ_control_edges;
1760 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
1761 FillInputs(succ, &succ_control_edges, &succ_in);
1762
1763 const int pred_num = pred->num_inputs();
1764 gtl::InlinedVector<Node*, 4> pred_control_edges;
1765 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
1766 FillInputs(pred, &pred_control_edges, &pred_in);
1767
1768 // We need to ensure that there is only 1 edge between Conv2D and BiasAdd.
1769 // Otherwise, merging is semantically incorrect.
1770 if (pred->out_edges().size() != 1) {
1771 return Status(error::Code::INVALID_ARGUMENT,
1772 "Conv2D has multiple outputs. "
1773 "Will skip node merge optimization");
1774 }
1775
1776 for (const Edge* e : pred->out_edges()) {
1777 if (e->dst() != succ) {
1778 return Status(error::Code::INVALID_ARGUMENT,
1779 "Conv2D does not feed to BiasAdd. "
1780 "Will skip node merge optimization");
1781 }
1782 }
1783
1784 // 2. Get inputs from both the nodes.
1785 // Find the 2 inputs from the Conv2D node and the bias from the BiasAdd node.
1786 // Get operands 0 and 1 of Conv2D and their Mkl tensors.
1787 CHECK_EQ(pred->in_edges().size(), 4); // _MklConv2D must have 4 inputs.
1788 // Get operand 1 of BiasAdd.
1789 // BiasAdd must have 2 inputs: Conv, bias.
1790 CHECK_EQ(succ->in_edges().size(), 2);
1791 Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3
1792 int oper3_mkl_slot = 0; // For dummy MKL tensor node, output slot is 0.
1793 GetDummyMklTensorNode(g, &oper3_mkl, pred); // Get a dummy Mkl tensor node,
1794 // as BiasAdd does not have an Mkl tensor as input.
1795 CHECK_NOTNULL(oper3_mkl);
1796
1797 // We will use the node name of BiasAdd as the name of the new node.
1798 // Build the new node. We use the same name as the original node, but change
1799 // the op name.
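// Illustrative sketch of the merged node's inputs (tensor names below are
// placeholders for this example, not taken from the graph): with
// O = _MklConv2D(A, B, A_m, B_m) and P = BiasAdd(O, C), the contiguous
// (default) branch below builds
//
//   P = _MklConv2DWithBias(A, B, C, A_m, B_m, C_m)
//
// where A_m and B_m are the Mkl tensors already feeding _MklConv2D, and C_m
// is the dummy Mkl tensor (oper3_mkl) created above for the bias input.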
1800 NodeBuilder nb(succ->name(), csinfo_.mkl_conv2d_with_bias); 1801 if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { 1802 nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D 1803 // pred_in[1] will be Mkl tensor for In1 if we follow interleaved 1804 // ordering, and it will be 2nd Tensorflow tensor for Conv2D if 1805 // we follow contiguous ordering. 1806 nb.Input(pred_in[1].first, pred_in[1].second); // Mkl for In1 1807 nb.Input(pred_in[2].first, pred_in[2].second); // In2 of Conv2D 1808 nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 1809 nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd 1810 nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd 1811 } else { 1812 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); 1813 nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D 1814 // pred_in[1] will be Mkl tensor for In1 if we follow interleaved 1815 // ordering, and it will be 2nd Tensorflow tensor for Conv2D if 1816 // we follow contiguous ordering. 1817 nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D 1818 nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd 1819 nb.Input(pred_in[2].first, pred_in[2].second); // Mkl for In1 of Conv2D 1820 nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 of Conv2D 1821 nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd 1822 } 1823 1824 // Copy attributes from Conv2D to Conv2DWithBias. 1825 CopyAttrsConv2D(const_cast<const Node*>(pred), &nb); 1826 1827 // Copy the device assigned to old node to new node. 1828 nb.Device(succ->def().device()); 1829 1830 // Create node. 1831 Node* new_node; 1832 nb.Finalize(&**g, &new_node); 1833 CHECK_NOTNULL(new_node); 1834 1835 // Set the Mkl layer label for this op. 1836 new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel); 1837 1838 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' 1839 // node are already copied in BuildNode. We handle control edges now. 1840 for (const Edge* e : pred->in_edges()) { 1841 if (e->IsControlEdge()) { 1842 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); 1843 } 1844 } 1845 for (const Edge* e : succ->in_edges()) { 1846 if (e->IsControlEdge()) { 1847 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); 1848 } 1849 } 1850 1851 // Incoming edges are fixed, we will fix the outgoing edges now. 1852 // First, we will fix outgoing control edges from 'pred' node. 1853 // We don't need to handle outgoing data edges from 'pred' node 1854 // because pred has only 1 output going to succ node (we enforced 1855 // this check for merge already). 1856 for (const Edge* e : pred->out_edges()) { 1857 if (e->IsControlEdge()) { 1858 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); 1859 } 1860 } 1861 1862 // Second, we will fix outgoing control and data edges from 'succ' node. 1863 for (const Edge* e : succ->out_edges()) { 1864 if (e->IsControlEdge()) { 1865 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); 1866 } else { 1867 CHECK_NOTNULL( 1868 (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input())); 1869 } 1870 } 1871 1872 // Copy device assigned to old node to new node. 1873 // It's ok to use pred or succ as we have enforced a check that 1874 // both have same device assigned. 
1875 new_node->set_assigned_device_name(pred->assigned_device_name()); 1876 1877 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString() 1878 << ", and node: " << succ->DebugString() 1879 << ", into node:" << new_node->DebugString(); 1880 1881 (*g)->RemoveNode(succ); 1882 (*g)->RemoveNode(pred); 1883 1884 return Status::OK(); 1885 } 1886 1887 return Status(error::Code::UNIMPLEMENTED, 1888 "Unimplemented case for node merge optimization."); 1889 } 1890 1891 ////////////////////////////////////////////////////////////////////////// 1892 // Helper functions for node rewrite 1893 ////////////////////////////////////////////////////////////////////////// 1894 1895 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, 1896 Node* orig_node, 1897 const RewriteInfo* ri) { 1898 CHECK_NOTNULL(ri); 1899 CHECK_NOTNULL(orig_node); 1900 1901 VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString(); 1902 1903 // Check if this is scenario 2 (context-based rewrite). 1904 // Get the matching ContextInfo if it is. 1905 const Node* fwd_node = nullptr; 1906 const ContextInfo* ci = nullptr; 1907 bool is_context_based_rewrite = false; 1908 if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) { 1909 is_context_based_rewrite = true; 1910 1911 // Sanity checks for context-based rewrite (if any) 1912 if (orig_node->type_string() == csinfo_.bias_add_grad && 1913 ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) { 1914 CHECK_NOTNULL(fwd_node); 1915 DataType orig_T, ctx_T; 1916 string orig_data_format, ctx_data_format; 1917 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T)); 1918 TF_CHECK_OK( 1919 GetNodeAttr(orig_node->def(), "data_format", &orig_data_format)); 1920 TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T)); 1921 TF_CHECK_OK( 1922 GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format)); 1923 1924 if (orig_data_format != ctx_data_format || orig_T != ctx_T || 1925 orig_node->assigned_device_name() != 1926 fwd_node->assigned_device_name() || 1927 orig_node->def().device() != fwd_node->def().device()) { 1928 return Status( 1929 error::Code::INVALID_ARGUMENT, 1930 "data_format or T attribute or devices of BiasAddGrad and " 1931 "Conv2D do not match. Will skip node rewrite optimization"); 1932 } 1933 } else if (orig_node->type_string() == csinfo_.bias_add_grad && 1934 ri->new_name == csinfo_.matmul) { 1935 // When BiasAddGrad has MatMul in context, we do not do any rewrite 1936 // and leave BiasAddGrad as it is. But we check for this condition 1937 // when we check for node rewrite rule. So we should not even come 1938 // here for MatMul. So we will fail now. 1939 return Status( 1940 error::Code::INVALID_ARGUMENT, 1941 "No rewrite is required for BiasAddGrad for MatMul context."); 1942 } 1943 } 1944 1945 // Get all inputs. 1946 int num_inputs = orig_node->in_edges().size(); 1947 1948 // Drop count for control edges from inputs 1949 for (const Edge* e : orig_node->in_edges()) { 1950 if (e->IsControlEdge()) { 1951 num_inputs--; 1952 } 1953 } 1954 1955 gtl::InlinedVector<Node*, 4> control_edges; 1956 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs); 1957 FillInputs(orig_node, &control_edges, &inputs); 1958 1959 // Build new node. We use same name as original node, but change the op name. 1960 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str()); 1961 // Copy user-specified device assigned to original node to new node. 
1962 nb.Device(orig_node->def().device()); 1963 // Set up new inputs to the rewritten node. 1964 Status s = SetUpInputs(g, inputs, &nb, orig_node); 1965 if (s != Status::OK()) { 1966 return s; 1967 } 1968 1969 // Copy attributes from original node to new node (for scenario 1). 1970 // For context-based rewrite, we use context to copy the attributes. 1971 if (is_context_based_rewrite) { 1972 if (orig_node->type_string() == csinfo_.bias_add_grad && 1973 ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) { 1974 CHECK_NOTNULL(fwd_node); 1975 ri->copy_attrs(fwd_node, &nb); 1976 } else { 1977 return Status(error::Code::UNIMPLEMENTED, 1978 "Unimplemented case for node rewrite optimization."); 1979 } 1980 } else { 1981 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb); 1982 } 1983 // Set the Mkl layer label for this op. 1984 nb.Attr("_kernel", mkl_op_registry::kMklOpLabel); 1985 1986 // Finalize graph and get new node. 1987 Node* new_node = nullptr; 1988 TF_CHECK_OK(nb.Finalize(&**g, &new_node)); 1989 CHECK_NOTNULL(new_node); 1990 1991 // Incoming data edges from 'orig_node' node to new 'new_node' node are 1992 // already copied in BuildNode. We need to handle control edges now. 1993 for (const Edge* e : orig_node->in_edges()) { 1994 if (e->IsControlEdge()) { 1995 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); 1996 } 1997 } 1998 1999 // Copy outgoing edges from 'orig_node' node to new 2000 // 'new_node' node, since the output also follows same ordering among 2001 // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow 2002 // tensors appropriately. Specifically, nth output of the original node 2003 // will become 2*nth output of the Mkl node for the interleaved ordering 2004 // of the tensors. For the contiguous ordering of the tensors, it will be n. 2005 // GetTensorDataIndex provides this mapping function. 2006 for (const Edge* e : orig_node->out_edges()) { 2007 if (e->IsControlEdge()) { 2008 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); 2009 } else { 2010 CHECK_NOTNULL((*g)->AddEdge( 2011 new_node, 2012 GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), 2013 e->dst(), e->dst_input())); 2014 } 2015 } 2016 2017 // Copy the runtime device assigned from original code to new node. 2018 new_node->set_assigned_device_name(orig_node->assigned_device_name()); 2019 2020 // Delete original node and mark new node as rewritten. 2021 (*g)->RemoveNode(orig_node); 2022 2023 VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString(); 2024 return Status::OK(); 2025 } 2026 2027 const MklLayoutRewritePass::ContextInfo* 2028 MklLayoutRewritePass::SearchMatchingContext(const Node* n, 2029 const Node** fwd_node) { 2030 CHECK_NOTNULL(n); 2031 CHECK_NOTNULL(fwd_node); 2032 *fwd_node = nullptr; 2033 2034 // Search for matching contextinfo based on node name and call 2035 // callback function using matching contextinfo. 2036 // There could be more than one matching contextinfos but whichever 2037 // matches first is returned. 
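// Illustrative sketch (the concrete ContextInfo values here are only an
// example): with a ContextInfo whose 'node' is "BiasAddGrad" and whose 'fwd'
// is "_MklConv2DWithBias", a BiasAddGrad node n matches when its
// context_match_fn can locate a _MklConv2DWithBias forward op in the
// surrounding subgraph; that forward op is returned through *fwd_node and is
// later used by RewriteNode when copying attributes for the rewritten node.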
2038 for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) { 2039 if (n->type_string() == (*ci)->node && 2040 (*ci)->context_match_fn(n, fwd_node, *ci)) { 2041 VLOG(1) << "Found context as matching: " << (*ci)->fwd; 2042 return *ci; 2043 } 2044 } 2045 return nullptr; 2046 } 2047 2048 bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n, 2049 const ContextInfo* c) { 2050 const Node* fwd_node = nullptr; 2051 return SearchMatchingContext(n, &fwd_node) == c; 2052 } 2053 2054 const MklLayoutRewritePass::RewriteInfo* 2055 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { 2056 CHECK_NOTNULL(n); 2057 2058 // First check if node along with its type is supported by MKL layer. 2059 // We do not want to rewrite an op into Mkl op if types are not supported. 2060 // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to 2061 // MklRelu if type is INT32. 2062 DataType T; 2063 if (!GetNodeAttr(n->def(), "T", &T).ok()) { 2064 return nullptr; 2065 } 2066 2067 // BiasAddGrad is not an Mkl layer, so we make an exception for it. 2068 if (n->type_string() != csinfo_.bias_add_grad) { 2069 if (!mkl_op_registry::IsMklOp( 2070 mkl_op_registry::GetMklOpName(n->type_string()), T)) { 2071 return nullptr; 2072 } 2073 } 2074 2075 // For elementwise node, we reuse the Eigen implementation and pass the MKL 2076 // metadata tensor through so we can avoid conversions. However, if all 2077 // incoming edges are in TF format, we don't need all this overhead, so 2078 // replace the elementwise node only if at least one of its parents is a MKL 2079 // node. 2080 // 2081 // TODO(vrane): Add implementation for element-wise ops that doesn't reuse 2082 // eigen code to reduce cross-library dependency. 2083 if (mkl_op_registry::IsMklElementWiseOp( 2084 mkl_op_registry::GetMklOpName(n->type_string()), T)) { 2085 bool incoming_mkl_edge = false; 2086 for (auto parent : n->in_edges()) { 2087 if (mkl_op_registry::IsMklOp( 2088 mkl_op_registry::GetMklOpName(parent->src()->type_string()), T)) { 2089 incoming_mkl_edge = true; 2090 break; 2091 } else { 2092 VLOG(1) << "Non-MKL parent is: " << parent->src()->type_string(); 2093 } 2094 } 2095 if (incoming_mkl_edge == false) { 2096 VLOG(1) << "Skipping replacement of elementwise node which has no MKL " 2097 "parents."; 2098 return nullptr; 2099 } 2100 } 2101 2102 // We support 2 types of node rewrites: 2103 // 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context. 2104 // 2. Rewriting an op to Mkl op always 2105 // We return true if any of these 2 conditions is met. 2106 2107 // Find matching RewriteInfo and then check that rewrite rule applies. 2108 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { 2109 if (n->type_string().compare(ri->name) == 0 && 2110 ri->rewrite_rule(n, ri->context)) { 2111 // If we are rewriting BiasAddGrad into BiasAddGrad for MatMul context, 2112 // then we just return directly. 2113 if (n->type_string() == csinfo_.bias_add_grad && 2114 ri->context->fwd == csinfo_.matmul && 2115 ri->new_name == csinfo_.bias_add_grad) { 2116 return nullptr; 2117 } 2118 return &*ri; 2119 } 2120 } 2121 2122 // Else return not found. 
2123 return nullptr; 2124 } 2125 2126 /////////////////////////////////////////////////////////////////////////////// 2127 // Run function for the pass 2128 /////////////////////////////////////////////////////////////////////////////// 2129 2130 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) { 2131 bool result = false; 2132 CHECK_NOTNULL(g); 2133 2134 DumpGraph("Before running MklLayoutRewritePass", &**g); 2135 2136 std::vector<Node*> order; 2137 GetReversePostOrder(**g, &order); // This will give us topological sort. 2138 2139 for (Node* n : order) { 2140 // If node is not an op or it cannot run on CPU device, then skip. 2141 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { 2142 continue; 2143 } 2144 2145 const RewriteInfo* ri = nullptr; 2146 Node* predn = nullptr; 2147 // We will first search if node is to be rewritten 2148 if ((ri = CheckForNodeRewrite(n)) != nullptr) { 2149 string node_name = n->name(); 2150 string op_name = n->type_string(); 2151 2152 VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name 2153 << " with op " << op_name << " for rewrite using" 2154 << " layout optimization."; 2155 2156 if (RewriteNode(g, n, ri) == Status::OK()) { 2157 VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name 2158 << " with op " << op_name << " for Mkl layout optimization."; 2159 result = true; 2160 } 2161 } else if ((predn = CheckForNodeMerge(n)) != nullptr) { 2162 // Otherwise, we will check if the node is to be merged. 2163 string n1_name = n->name(); 2164 string n2_name = predn->name(); 2165 2166 VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and " 2167 << n2_name << " for merging"; 2168 2169 if (MergeNode(g, n, predn) == Status::OK()) { 2170 VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and " 2171 << n2_name; 2172 result = true; 2173 } 2174 } 2175 } 2176 2177 DumpGraph("After running MklLayoutRewritePass", &**g); 2178 2179 return result; 2180 } 2181 2182 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) { 2183 return MklLayoutRewritePass().RunPass(g); 2184 } 2185 2186 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { 2187 if (options.graph == nullptr && options.partition_graphs == nullptr) { 2188 return Status::OK(); 2189 } 2190 2191 auto process_graph = [&](std::unique_ptr<Graph>* g) { 2192 // Get the ownership of a graph 2193 std::unique_ptr<Graph>* ng = std::move(g); 2194 RunPass(ng); 2195 // Return the ownership of a graph back 2196 g->reset(ng->release()); 2197 }; 2198 2199 if (kMklLayoutRewritePassGroup != 2200 OptimizationPassRegistry::POST_PARTITIONING) { 2201 // For any pre-partitioning phase, a graph is stored in options.graph. 2202 process_graph(options.graph); 2203 } else { 2204 // For post partitioning phase, graphs are stored in 2205 // options.partition_graphs. 2206 for (auto& pg : *options.partition_graphs) { 2207 process_graph(&pg.second); 2208 } 2209 } 2210 2211 return Status::OK(); 2212 } 2213 2214 #else // INTEL_MKL_ML 2215 2216 // This pass implements rewriting of graph to support following scenarios: 2217 // (A) Merging nodes in the graph 2218 // (B) Rewriting a node in the graph to a new node 2219 // Rewrite happens under following scenario: 2220 // - Propagating Mkl layout as an additional output tensor 2221 // (we will loosely call a tensor that carries Mkl layout as Mkl tensor 2222 // henceforth.) from every Mkl supported NN layer. 
2223 // 2224 // Example of A : Merging nodes in the graph 2225 // ----------------------------------------- 2226 // Currently, we merge Conv2D+AddBias together. Consider Conv2D and BiasAdd as: 2227 // 2228 // O = Conv2D(A, B) 2229 // P = BiasAdd(O, C) 2230 // 2231 // We merge them into Conv2DWithBias as: 2232 // P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m) 2233 // 2234 // The meaning of A_m, B_m and C_m is explained in B.1. 2235 // 2236 // Merge rules: 2237 // - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_ 2238 // goes to BiasAdd. 2239 // - Also, the intersection of attributes of both the nodes must have same 2240 // values. 2241 // - Both the nodes must have been assigned to same device (if any). 2242 // 2243 // Example of B.1 : Rewriting nodes to Mkl nodes 2244 // --------------------------------------------- 2245 // Consider a Relu node. Current definition of Relu node looks like: 2246 // 2247 // O = Relu(A) 2248 // 2249 // Relu has 1 input (A), and 1 output (O). 2250 // 2251 // This rewrite pass will generate a new graph node for Relu (new node is 2252 // called MklRelu) as: 2253 // 2254 // O, O_m = MklRelu(A, A_m) 2255 // 2256 // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is 2257 // same as input A of Relu; output O is same as output O of Relu. O_m is the 2258 // additional output tensor that will be set by MklRelu, and it represents 2259 // Mkl tensor corresponding to O -- in other words, O_m is some kind of 2260 // metadata for O. A_m is additional input of Relu, and it represents metadata 2261 // for A - as O_m is metadata for O, A_m is metadata for A. MklRelu receives 2262 // this metadata from previous node in the graph. 2263 // 2264 // When a previous node in the graph is an Mkl node, A_m will represent a valid 2265 // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent 2266 // a dummy Mkl tensor. 2267 // 2268 // Rewriting rules: 2269 // - Selection of a node for rewriting happens by registering the op type of 2270 // the node with the rewriting pass. If the op type is not registered, then 2271 // all nodes of this op type will not be rewritten. 2272 // - Number of inputs after rewriting: 2273 // Since for every input Tensorflow tensor, the rewritten node gets Mkl 2274 // tensor(s), rewritten node gets 2*N inputs, where N is the number of 2275 // inputs for the original node. 2276 // - Number of outputs after rewriting: 2277 // Since for every output Tensorflow tensor, the rewritten node generates 2278 // Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the 2279 // number of outputs of the original node. 2280 // - Ordering of Tensorflow tensors and Mkl tensors: 2281 // Since every rewritten node generates twice the number of inputs and 2282 // outputs, one could imagine various orderings among Tensorflow tensors 2283 // and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as 2284 // inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m 2285 // in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m 2286 // order. Among N inputs one can get N! permutations. 2287 // 2288 // So the question is: which order do we follow? We support 2 types of 2289 // orderings: (1) interleaved, and (2) contiguous. Interleaved ordering 2290 // follows an intuitive order where an Mkl tensor follows the 2291 // corresponding Tensorflow tensor immediately. In the context of the 2292 // above example, it will be: A, A_m, B, B_m. 
Note that the ordering rule 2293 // applies to both the inputs and outputs. Contiguous ordering means 2294 // all the Tensorflow tensors are contiguous followed by all the Mkl 2295 // tensors. We use contiguous ordering as default. 2296 // 2297 // Graph rewrite algorithm: 2298 // Algorithm: Graph Rewrite 2299 // Input: Graph G, Names of the nodes to rewrite and their new names 2300 // Output: Modified Graph G' if the nodes are modified, G otherwise. 2301 // Start: 2302 // N = Topological_Sort(G) // N is a set of nodes in toposort order. 2303 // foreach node n in N 2304 // do 2305 // if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input. 2306 // then 2307 // E = set of <incoming edge and its src_output slot> of n 2308 // E' = {} // a new set of edges for rewritten node 2309 // foreach <e,s> in E 2310 // do 2311 // E' U {<e,s>} // First copy edge which generates Tensorflow 2312 // // tensor as it is 2313 // m = Source node of edge e 2314 // if Is_Rewritten(m) // Did we rewrite this node in this pass? 2315 // then 2316 // E' U {<m,s+1>} // If yes, then m will generate an Mkl 2317 // // tensor as an additional output. 2318 // else 2319 // d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy 2320 // // Mkl tensor. 2321 // E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot. 2322 // fi 2323 // done 2324 // n' = Build_New_Node(G,new_name,E') 2325 // Mark_Rewritten(n') // Mark the new node as being rewritten. 2326 // fi 2327 // done 2328 // 2329 // Explanation: 2330 // For graph rewrite, we visit nodes of the input graph in the 2331 // topological sort order. With this ordering, we visit nodes in the 2332 // top-to-bottom fashion. We need this order because while visiting a 2333 // node we want that all of its input nodes are visited and rewritten if 2334 // applicable. This is because if we need to rewrite a given node 2335 // then all of its input nodes need to be fixed (in other words they 2336 // cannot be deleted later.) 2337 // 2338 // While visiting a node, we first check if the op type of the node is 2339 // an Mkl op. If it is, then we rewrite that node after constructing 2340 // new inputs to the node. If the op type of the node is not Mkl op, 2341 // then we do not rewrite that node. 2342 // 2343 // Handling workspace propagation for certain ops: 2344 // 2345 // Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require 2346 // passing of a workspace from their respective forward ops. Workspace 2347 // tensors provide memory for storing results of intermediate operations 2348 // which are helpful in backward propagation. TensorFlow does not have 2349 // a notion of a workspace and as a result does not allow producing 2350 // additional outputs from these forward ops. For these ops, we need 2351 // to add 2 extra edges between forward ops and their corresponding 2352 // backward ops - the first extra edge carries a workspace tensor and 2353 // the second one carries an Mkl tensor for the workspace tensor. 2354 // 2355 // Example: 2356 // 2357 // Typical graph for MaxPool and its gradient looks like: 2358 // 2359 // A = MaxPool(T) 2360 // B = MaxPoolGrad(X, A, Y) 2361 // 2362 // We will transform this graph to propagate the workspace as: 2363 // (with the contiguous ordering) 2364 // 2365 // A, W, A_m, W_m = MklMaxPool(T, T_m) 2366 // B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m) 2367 // 2368 // Here W is the workspace tensor. 
Transformed tensor names with the 2369 // suffix _m are Mkl tensors, and this transformation has been done 2370 // using the algorithm discussed earlier. The transformation for 2371 // workspace propagation only adds extra outputs (W, W_m) for a forward 2372 // op and connects them to the corresponding backward ops. 2373 // 2374 // Terms: 2375 // 2376 // Forward op name = name of the op in the forward pass 2377 // where a workspace tensor originates (MaxPool in this example) 2378 // Backward op name = name of the op in the backward pass that receives 2379 // a workspace tensor from the forward op (MaxPoolGrad in the example) 2380 // Slot = Position of the output or input slot that will be 2381 // used by the workspace tensor (1 for MklMaxPool as W is the 2nd 2382 // output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad) 2383 // 2384 // Question: 2385 // 2386 // How do we associate a backward op to a forward op? There can be more 2387 // than one op with the exact same name. 2388 // 2389 // In this example, we associate MaxPoolGrad with MaxPool. But there 2390 // could be more than one MaxPool ops. To solve this problem, we look 2391 // for _direct_ edge between a forward op and a backward op (tensor A is 2392 // flowing along this edge in the example). 2393 // 2394 // How do we transform forward and backward ops when there is no direct 2395 // edge between them? In such a case, we generate dummy tensors for 2396 // workspace tensors. For the example, transformation of MaxPool will 2397 // be exactly same as it would be when there is a direct edge between 2398 // the forward and the backward op --- it is just that MaxPool won't 2399 // generate any workspace tensor. For MaxPoolGrad, the transformation 2400 // will also be same, but instead of connecting W and W_m with the 2401 // outputs of MaxPool, we will produce dummy tensors for them, and we 2402 // will set workspace_enabled attribute to false. 2403 // 2404 class MklLayoutRewritePass : public GraphOptimizationPass { 2405 public: 2406 MklLayoutRewritePass() { 2407 // NOTE: names are alphabetically sorted. 2408 csinfo_.addn = "AddN"; 2409 csinfo_.avg_pool = "AvgPool"; 2410 csinfo_.avg_pool_grad = "AvgPoolGrad"; 2411 csinfo_.bias_add = "BiasAdd"; 2412 csinfo_.bias_add_grad = "BiasAddGrad"; 2413 csinfo_.concat = "Concat"; 2414 csinfo_.concatv2 = "ConcatV2"; 2415 csinfo_.conv2d = "Conv2D"; 2416 csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias"; 2417 csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; 2418 csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; 2419 csinfo_.conv2d_grad_filter_with_bias = 2420 "__MklDummyConv2DBackpropFilterWithBias"; 2421 csinfo_.fused_batch_norm = "FusedBatchNorm"; 2422 csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; 2423 csinfo_.identity = "Identity"; 2424 csinfo_.lrn = "LRN"; 2425 csinfo_.lrn_grad = "LRNGrad"; 2426 csinfo_.matmul = "MatMul"; 2427 csinfo_.max_pool = "MaxPool"; 2428 csinfo_.max_pool_grad = "MaxPoolGrad"; 2429 csinfo_.mkl_conv2d = "_MklConv2D"; 2430 csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput"; 2431 csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; 2432 csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; 2433 csinfo_.mkl_conv2d_grad_filter_with_bias = 2434 "_MklConv2DBackpropFilterWithBias"; 2435 csinfo_.relu = "Relu"; 2436 csinfo_.relu_grad = "ReluGrad"; 2437 csinfo_.tanh = "Tanh"; 2438 csinfo_.tanh_grad = "TanhGrad"; 2439 csinfo_.reshape = "Reshape"; 2440 csinfo_.softmax = "Softmax"; 2441 csinfo_.split = "Split"; 2442 // Element-wise ops. 
Ensure you also add any new ops to IsOpElementWise 2443 // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the 2444 // MklInputConversion op is added before it. 2445 csinfo_.add = "Add"; 2446 csinfo_.maximum = "Maximum"; 2447 csinfo_.mul = "Mul"; 2448 csinfo_.squared_difference = "SquaredDifference"; 2449 csinfo_.sub = "Sub"; 2450 // End - element-wise ops. See note above. 2451 2452 // NOTE: names are alphabetically sorted. 2453 rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), 2454 CopyAttrsAddN, AddNRewrite}); 2455 rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), 2456 CopyAttrsDataType, AlwaysRewrite}); 2457 rinfo_.push_back({csinfo_.avg_pool, 2458 mkl_op_registry::GetMklOpName(csinfo_.avg_pool), 2459 CopyAttrsPooling, AlwaysRewrite}); 2460 rinfo_.push_back({csinfo_.avg_pool_grad, 2461 mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad), 2462 CopyAttrsPooling, AlwaysRewrite}); 2463 rinfo_.push_back({csinfo_.concat, 2464 mkl_op_registry::GetMklOpName(csinfo_.concat), 2465 CopyAttrsConcat, AlwaysRewrite}); 2466 rinfo_.push_back({csinfo_.concatv2, 2467 mkl_op_registry::GetMklOpName(csinfo_.concatv2), 2468 CopyAttrsConcatV2, AlwaysRewrite}); 2469 rinfo_.push_back({csinfo_.conv2d, 2470 mkl_op_registry::GetMklOpName(csinfo_.conv2d), 2471 CopyAttrsConv2D, AlwaysRewrite}); 2472 rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias, 2473 CopyAttrsConv2D, AlwaysRewrite}); 2474 rinfo_.push_back({csinfo_.conv2d_grad_filter, 2475 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), 2476 CopyAttrsConv2D, AlwaysRewrite}); 2477 rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias, 2478 csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D, 2479 AlwaysRewrite}); 2480 rinfo_.push_back({csinfo_.conv2d_grad_input, 2481 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), 2482 CopyAttrsConv2D, AlwaysRewrite}); 2483 rinfo_.push_back({csinfo_.fused_batch_norm, 2484 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), 2485 CopyAttrsFusedBatchNorm, AlwaysRewrite}); 2486 rinfo_.push_back( 2487 {csinfo_.fused_batch_norm_grad, 2488 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), 2489 CopyAttrsFusedBatchNorm, AlwaysRewrite}); 2490 rinfo_.push_back({csinfo_.identity, 2491 mkl_op_registry::GetMklOpName(csinfo_.identity), 2492 CopyAttrsDataType, AlwaysRewrite}); 2493 rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), 2494 CopyAttrsLRN, AlwaysRewrite}); 2495 rinfo_.push_back({csinfo_.lrn_grad, 2496 mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), 2497 CopyAttrsLRN, AlwaysRewrite}); 2498 rinfo_.push_back({csinfo_.max_pool, 2499 mkl_op_registry::GetMklOpName(csinfo_.max_pool), 2500 CopyAttrsPooling, NonDepthBatchWisePoolRewrite}); 2501 rinfo_.push_back({csinfo_.max_pool_grad, 2502 mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), 2503 CopyAttrsPooling, AlwaysRewrite}); 2504 2505 rinfo_.push_back({csinfo_.maximum, 2506 mkl_op_registry::GetMklOpName(csinfo_.maximum), 2507 CopyAttrsDataType, AlwaysRewrite}); 2508 rinfo_.push_back({csinfo_.mul, 2509 mkl_op_registry::GetMklOpName(csinfo_.mul), 2510 CopyAttrsDataType, AlwaysRewrite}); 2511 rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), 2512 CopyAttrsDataType, AlwaysRewrite}); 2513 rinfo_.push_back({csinfo_.relu_grad, 2514 mkl_op_registry::GetMklOpName(csinfo_.relu_grad), 2515 CopyAttrsDataType, AlwaysRewrite}); 2516 /* 2517 rinfo_.push_back({csinfo_.tanh, 2518 
mkl_op_registry::GetMklOpName(csinfo_.tanh), 2519 CopyAttrsDataType, AlwaysRewrite}); 2520 rinfo_.push_back({csinfo_.tanh_grad, 2521 mkl_op_registry::GetMklOpName(csinfo_.tanh_grad), 2522 CopyAttrsDataType, AlwaysRewrite}); 2523 */ 2524 rinfo_.push_back({csinfo_.reshape, 2525 mkl_op_registry::GetMklOpName(csinfo_.reshape), 2526 CopyAttrsReshape, AlwaysRewrite}); 2527 rinfo_.push_back({csinfo_.softmax, 2528 mkl_op_registry::GetMklOpName(csinfo_.softmax), 2529 CopyAttrsDataType, AlwaysRewrite}); 2530 2531 rinfo_.push_back({csinfo_.squared_difference, 2532 mkl_op_registry::GetMklOpName(csinfo_.squared_difference), 2533 CopyAttrsDataType, AlwaysRewrite}); 2534 rinfo_.push_back({csinfo_.sub, 2535 mkl_op_registry::GetMklOpName(csinfo_.sub), 2536 CopyAttrsDataType, AlwaysRewrite}); 2537 2538 // Add info about which ops to add workspace edge to and the slots. 2539 wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); 2540 wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3}); 2541 2542 // Add a rule for merging nodes 2543 minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add, 2544 csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd}); 2545 2546 minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad, 2547 csinfo_.conv2d_grad_filter_with_bias, 2548 GetConv2DBackpropFilterOrBiasAddGrad}); 2549 } 2550 2551 // Standard interface to run pass 2552 Status Run(const GraphOptimizationPassOptions& options); 2553 2554 // Helper function which does most of heavy lifting for rewriting 2555 // Mkl nodes to propagate Mkl tensor as additional output 2556 // 2557 // Extracts common functionality between Run public interface and 2558 // test interface. 2559 // 2560 // @return true, if and only if graph is mutated; false otherwise. 2561 bool RunPass(std::unique_ptr<Graph>* g); 2562 2563 /// Structure to specify the name of an original node, its new name after 2564 /// rewrite, the number of inputs to the original node, the function to 2565 /// be used to copy attributes for the op, and the rule (if any) which 2566 /// must hold for rewriting the node 2567 typedef struct { 2568 string name; // Original name of op of the node in the graph 2569 string new_name; // New name of the op of the node in the graph 2570 // A function handler to copy attributes from an old node to a new node. 2571 std::function<void(const Node*, NodeBuilder*)> copy_attrs; 2572 // A rule under which to rewrite this node 2573 std::function<bool(const Node*)> rewrite_rule; 2574 } RewriteInfo; 2575 2576 /// Structure to specify a forward op, a backward op, and the slot numbers 2577 /// in the forward and backward ops where we will add a workspace edge. 2578 typedef struct { 2579 string fwd_op; // Name of a forward op in the graph 2580 string bwd_op; // Name of a backward op in the graph 2581 int fwd_slot; // Output slot in the forward op node where actual 2582 // output tensor resides 2583 int bwd_slot; // Input slot in the backward op node where actual 2584 // input tensor resides 2585 int ws_fwd_slot; // Output slot in the forward op node where workspace 2586 // edge is added 2587 int ws_bwd_slot; // Input slot in the backward op node where workspace 2588 // edge is added 2589 } WorkSpaceInfo; 2590 2591 /// Structure to specify information used in node merge of 2 operators 2592 typedef struct { 2593 string op1; // Node string for one operator. 2594 string op2; // Node string for second operator. 
2595 string new_node; // Name of the node after merge 2596 // Function that enables user of the node merger to specify how to find 2597 // second operator given the first operator. 2598 std::function<Node*(const Node*)> get_node_to_be_merged; 2599 } MergeInfo; 2600 2601 /// Structure to store all constant strings 2602 /// NOTE: names are alphabetically sorted. 2603 typedef struct { 2604 string addn; 2605 string add; 2606 string avg_pool; 2607 string avg_pool_grad; 2608 string bias_add; 2609 string bias_add_grad; 2610 string concat; 2611 string concatv2; 2612 string conv2d; 2613 string conv2d_with_bias; 2614 string conv2d_grad_input; 2615 string conv2d_grad_filter; 2616 string conv2d_grad_filter_with_bias; 2617 string fused_batch_norm; 2618 string fused_batch_norm_grad; 2619 string identity; 2620 string lrn; 2621 string lrn_grad; 2622 string matmul; 2623 string max_pool; 2624 string max_pool_grad; 2625 string maximum; 2626 string mkl_conv2d; 2627 string mkl_conv2d_grad_input; 2628 string mkl_conv2d_grad_filter; 2629 string mkl_conv2d_grad_filter_with_bias; 2630 string mkl_conv2d_with_bias; 2631 string mul; 2632 string relu; 2633 string relu_grad; 2634 string tanh; 2635 string tanh_grad; 2636 string reshape; 2637 string softmax; 2638 string split; 2639 string squared_difference; 2640 string sub; 2641 } ConstStringsInfo; 2642 2643 private: 2644 /// Maintain info about nodes to rewrite 2645 std::vector<RewriteInfo> rinfo_; 2646 2647 /// Maintain info about nodes to add workspace edge 2648 std::vector<WorkSpaceInfo> wsinfo_; 2649 2650 /// Maintain info about nodes to be merged 2651 std::vector<MergeInfo> minfo_; 2652 2653 /// Maintain structure of constant strings 2654 static ConstStringsInfo csinfo_; 2655 2656 private: 2657 // Is OpDef::ArgDef a list type? It could be N * T or list(type). 2658 // Refer to opdef.proto for details of list type. 2659 inline bool ArgIsList(const OpDef::ArgDef& arg) const { 2660 return !arg.type_list_attr().empty() || !arg.number_attr().empty(); 2661 } 2662 2663 // Get length of a list in 'n' if 'arg' is of list type. Refer to 2664 // description of ArgIsList for definition of list type. 2665 inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) { 2666 CHECK_EQ(ArgIsList(arg), true); 2667 int N = 0; 2668 const string attr_name = !arg.type_list_attr().empty() 2669 ? arg.type_list_attr() 2670 : arg.number_attr(); 2671 if (!arg.type_list_attr().empty()) { 2672 std::vector<DataType> value; 2673 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value)); 2674 N = value.size(); 2675 } else { 2676 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N)); 2677 } 2678 return N; 2679 } 2680 2681 // Can op represented by node 'n' run on DEVICE_CPU? 2682 // Op can run on CPU with MKL if the runtime assigned device or the 2683 // user requested device contains device CPU, or both are empty. 2684 bool CanOpRunOnCPUDevice(const Node* n) { 2685 bool result = true; 2686 string reason; 2687 2688 // Substring that should be checked for in device name for CPU device. 2689 const char* const kCPUDeviceSubStr = "CPU"; 2690 2691 // If Op has been specifically assigned to a non-CPU device, then No. 2692 if (!n->assigned_device_name().empty() && 2693 !StringPiece(n->assigned_device_name()).contains(kCPUDeviceSubStr)) { 2694 result = false; 2695 reason = "Op has been assigned a runtime device that is not CPU."; 2696 } 2697 2698 // If user has specifically assigned this op to a non-CPU device, then No. 
2699 if (!n->def().device().empty() && 2700 !StringPiece(n->def().device()).contains(kCPUDeviceSubStr)) { 2701 result = false; 2702 reason = "User has assigned a device that is not CPU."; 2703 } 2704 2705 if (result == false) { 2706 VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node " 2707 << n->type_string() << ", reason: " << reason; 2708 } 2709 2710 // Otherwise Yes. 2711 return result; 2712 } 2713 2714 // Return a node that can be merged with input node 'n' 2715 // 2716 // @return pointer to the node if we can find such a 2717 // node. Otherwise, it returns nullptr. 2718 Node* CheckForNodeMerge(const Node* n) const; 2719 2720 // Merge node 'm' with node 'n'. 2721 // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with 2722 // Conv2DBackpropFilter. 2723 // 2724 // Input nodes m and n may be deleted if the call to 2725 // this function is successful. Attempt to use the pointers 2726 // after the call to function may result in undefined behaviors. 2727 // 2728 // @input g - input graph, m - graph node, n - graph node to be merged with m 2729 // @return Status::OK(), if merging is successful and supported. 2730 // Returns appropriate Status error code otherwise. 2731 // Graph is updated in case nodes are merged. Otherwise, it is 2732 // not updated. 2733 Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n); 2734 2735 // Helper function to merge different nodes 2736 Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n); 2737 Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g, 2738 Node* m, Node* n); 2739 2740 // Find BiasAdd or Conv2D node that can be merged with input node 'm'. 2741 // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be 2742 // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd 2743 // node that can be merged with 'm'. 2744 static Node* GetConv2DOrBiasAdd(const Node* m) { 2745 CHECK_NOTNULL(m); 2746 Node* n = nullptr; 2747 2748 if (m->type_string() == csinfo_.bias_add) { 2749 // If a is BiasAdd, then Conv2D is 0th input of BiasAdd. 2750 TF_CHECK_OK(m->input_node(0, &n)); 2751 } else { 2752 CHECK_EQ(m->type_string(), csinfo_.conv2d); 2753 // Go over all output edges and search for BiasAdd Node. 2754 // 0th input of BiasAdd is Conv2D. 2755 for (const Edge* e : m->out_edges()) { 2756 if (!e->IsControlEdge() && 2757 e->dst()->type_string() == csinfo_.bias_add && 2758 e->dst_input() == 0) { 2759 n = e->dst(); 2760 break; 2761 } 2762 } 2763 } 2764 2765 if (n == nullptr) { 2766 VLOG(1) << "MklLayoutRewritePass: Could not find matching " 2767 << "Conv2D and BiasAdd node for merging. Input node: " 2768 << m->DebugString(); 2769 } 2770 2771 return n; 2772 } 2773 2774 // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input 2775 // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists 2776 // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad, 2777 // then check if there exists Conv2DBackpropFilter node that can be merged 2778 // with 'm'. 2779 // 2780 // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad 2781 // would look like: 2782 // 2783 // _ = Conv2DBackpropFilter(F, _, G) 2784 // _ = BiasAddGrad(G) 2785 // 2786 // So 1st input of BiasAddGrad connects with 3rd input of 2787 // Conv2DBackpropFilter and vice versa. 
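// For reference, a successful match here feeds the merge registered in
// minfo_: Conv2DBackpropFilter and BiasAddGrad are merged into
// __MklDummyConv2DBackpropFilterWithBias, which the rewrite pass then
// rewrites to _MklConv2DBackpropFilterWithBias (see the corresponding
// rinfo_ entry above).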
2788 static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
2789   CHECK_NOTNULL(m);
2790   Node* n = nullptr;
2791
2792   if (m->type_string() == csinfo_.bias_add_grad) {
2793     // Get 1st input 'g' of BiasAddGrad.
2794     Node* g = nullptr;
2795     TF_CHECK_OK(m->input_node(0, &g));
2796     // Now traverse all outgoing edges from g that have destination node as
2797     // Conv2DBackpropFilter.
2798     for (const Edge* e : g->out_edges()) {
2799       if (!e->IsControlEdge() &&
2800           e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
2801           e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
2802         n = e->dst();
2803         break;
2804       }
2805     }
2806   } else {
2807     CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
2808     // Get 3rd input 'g' of Conv2DBackpropFilter.
2809     Node* g = nullptr;
2810     TF_CHECK_OK(m->input_node(2, &g));
2811     // Now traverse all outgoing edges from g that have destination node as
2812     // BiasAddGrad.
2813     for (const Edge* e : g->out_edges()) {
2814       if (!e->IsControlEdge() &&
2815           e->dst()->type_string() == csinfo_.bias_add_grad &&
2816           e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
2817         n = e->dst();
2818         break;
2819       }
2820     }
2821   }
2822
2823   if (n == nullptr) {
2824     VLOG(1) << "MklLayoutRewritePass: Could not find matching "
2825             << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
2826             << "Input node: " << m->DebugString();
2827   }
2828   return n;
2829 }
2830
2831 // Check if the node 'n' has any applicable rewrite rule.
2832 // We check for 2 scenarios for rewrite.
2833 //
2834 // @return RewriteInfo* for the applicable rewrite rule
2835 const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
2836
2837 // Default rewrite rule to be used in scenario 1 for rewrite.
2838 // @return - true (since we want to always rewrite)
2839 static bool AlwaysRewrite(const Node* n) { return true; }
2840
2841 // Check if we are performing pooling on the depth or batch dimension. If
2842 // so, we do not rewrite the MaxPool node to its Mkl version.
2843 // @return - true (if it is not a depth/batch wise pooling case);
2844 //           false otherwise.
2845 static bool NonDepthBatchWisePoolRewrite(const Node* n) {
2846   CHECK_NOTNULL(n);
2847
2848   string data_format_str;
2849   TensorFormat data_format;
2850   std::vector<int32> ksize, strides;
2851   CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
2852   CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
2853   CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
2854   CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
2855
2856   // Condition that specifies non-batch-wise and non-depth-wise pooling.
2857   if (GetTensorDim(ksize, data_format, 'N') == 1 &&
2858       GetTensorDim(strides, data_format, 'N') == 1 &&
2859       GetTensorDim(ksize, data_format, 'C') == 1 &&
2860       GetTensorDim(strides, data_format, 'C') == 1) {
2861     return true;
2862   }
2863
2864   return false;
2865 }
2866
2867 static bool AddNRewrite(const Node* n) {
2868   CHECK_NOTNULL(n);
2869
2870   int num;
2871   CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true);
2872
2873   // Rewrite AddN to its Mkl version only when it has exactly 2 inputs.
2874   if (num == 2) {
2875     return true;
2876   }
2877
2878   return false;
2879 }
2880
2881 // Rewrites input node to a new node specified by its matching rewrite info.
2882 //
2883 // Method first searches matching rewrite info for input node and then
2884 // uses that info to rewrite.
2885 //
2886 // Input node may be deleted in case of rewrite.
Attempt to use the node 2887 // after the call can result in undefined behaviors. 2888 // 2889 // @input g - input graph, n - Node to be rewritten, 2890 // ri - matching rewriteinfo 2891 // @return Status::OK(), if the input node is rewritten; 2892 // Returns appropriate Status error code otherwise. 2893 // Graph is updated in case the input node is rewritten. 2894 // Otherwise, it is not updated. 2895 Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri); 2896 2897 // Get nodes that will feed a list of TF tensors to the new 2898 // node that we are constructing. 2899 // 2900 // @input g - input graph, 2901 // @input inputs - inputs to old node that we are using for constructing 2902 // new inputs, 2903 // @input input_idx - the index in the 'inputs' vector pointing to the 2904 // current input that we have processed so far 2905 // @output input_idx - index will be incremented by the number of nodes 2906 // from 'inputs' that are processed 2907 // @input list_length - The expected length of list of TF tensors 2908 // @output output_nodes - the list of new nodes creating TF tensors 2909 // 2910 // @return None 2911 void GetNodesProducingTFTensorList( 2912 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, 2913 int* input_idx, int list_length, 2914 std::vector<NodeBuilder::NodeOut>* output_nodes); 2915 2916 // Get nodes that will feed a list of Mkl tensors to the new 2917 // node that we are constructing. 2918 // 2919 // @input g - input graph, 2920 // @input orig_node - Original node that we are rewriting 2921 // @input inputs - inputs to old node that we are using for constructing 2922 // new inputs, 2923 // @input input_idx - the index in the 'inputs' vector pointing to the 2924 // current input that we have processed so far 2925 // @output input_idx - index will be incremented by the number of nodes 2926 // from 'inputs' that are processed 2927 // @input list_length - The expected length of list of Mkl tensors 2928 // @output output_nodes - the list of new nodes creating Mkl tensors 2929 // 2930 // @return None 2931 void GetNodesProducingMklTensorList( 2932 std::unique_ptr<Graph>* g, Node* orig_node, 2933 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, 2934 int* input_idx, int list_length, 2935 std::vector<NodeBuilder::NodeOut>* output_nodes); 2936 2937 // Get a node that will feed an Mkl tensor to the new 2938 // node that we are constructing. The output node could be (1) 'n' 2939 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor 2940 // if 'n' is not an Mkl layer. 2941 // 2942 // @input g - input graph, 2943 // @input orig_node - Original node that we are rewriting, 2944 // @input n - Node based on which we are creating Mkl node, 2945 // @input n_output_slot - the output slot of node 'n' 2946 // which is feeding to the node that we are constructing 2947 // @output mkl_node - the new node that will feed Mkl tensor 2948 // @output mkl_node_output_slot - the slot number of mkl_node that 2949 // will feed the tensor 2950 // @return None 2951 void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, 2952 Node* n, int n_output_slot, Node** mkl_node, 2953 int* mkl_node_output_slot); 2954 2955 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' 2956 // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are 2957 // set up in contiguous fashion. 'workspace_tensors' carry graph nodes 2958 // producing workspace edges if 'are_workspace_tensors_available' is true. 
2959 // Otherwise, 'workspace_tensors' is empty vector. 2960 // 2961 // For details, refer to 'Ordering of inputs after rewriting' section in the 2962 // documentation above. 2963 // 2964 // Returns Status::OK() if setting up inputs is successful, otherwise 2965 // returns appropriate status code. 2966 int SetUpContiguousInputs( 2967 std::unique_ptr<Graph>* g, 2968 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, 2969 NodeBuilder* nb, Node* old_node, 2970 std::vector<NodeBuilder::NodeOut>* workspace_tensors, 2971 bool are_workspace_tensors_available); 2972 2973 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' 2974 // in graph 'g'. Original node is input in 'orig_node'. 2975 // 2976 // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors' 2977 // section in the documentation above. 2978 // 2979 // Returns Status::OK() if setting up inputs is successful, otherwise 2980 // returns appropriate status code. 2981 Status SetUpInputs(std::unique_ptr<Graph>* g, 2982 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, 2983 NodeBuilder* nb, Node* orig_node); 2984 2985 // Add workspace edge on the input or output side of Node 'orig_node' by using 2986 // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate 2987 // adding workspace edge then do not add it. Workspace Tensorflow and Mkl 2988 // tensors, if they need to be added, will be set into these tensors. 2989 // If we set workspace tensors, then are_ws_tensors_added should be true. 2990 void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node, 2991 NodeBuilder* nb, 2992 std::vector<NodeBuilder::NodeOut>* ws_tensors, 2993 bool* are_ws_tensors_added); 2994 2995 // Functions specific to operators to copy attributes 2996 // We need operator-specific function to copy attributes because the framework 2997 // does not provide any generic function for it. 2998 // NOTE: names are alphabetically sorted. 2999 static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb); 3000 static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb); 3001 static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb); 3002 static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); 3003 static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); 3004 static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb); 3005 static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); 3006 static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); 3007 static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); 3008 static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); 3009 static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); 3010 3011 // Generate a graph node in graph 'g' representing a dummy Mkl tensor node, 3012 // using node for original node 'orig_node' and return it in '*out'. 3013 // TODO(nhasabni) We should move this to mkl_util.h 3014 void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out, 3015 Node* orig_node); 3016 void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out, 3017 Node* orig_node); 3018 }; 3019 3020 MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; 3021 3022 // We register Mkl rewrite pass for phase 1 in post partitioning group. 3023 // We register it here so that we get a complete picture of all users of Mkl 3024 // nodes. Do not change the ordering of the Mkl passes. 
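// Within a grouping, the optimization pass registry runs passes in ascending
// phase order, so registering at phase 1 places this pass after any phase-0
// passes in the POST_PARTITIONING group.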
3025 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = 3026 OptimizationPassRegistry::POST_PARTITIONING; 3027 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); 3028 3029 ////////////////////////////////////////////////////////////////////////// 3030 // Helper functions for creating new node 3031 ////////////////////////////////////////////////////////////////////////// 3032 3033 static void FillInputs(const Node* n, 3034 gtl::InlinedVector<Node*, 4>* control_edges, 3035 gtl::InlinedVector<std::pair<Node*, int>, 4>* in) { 3036 control_edges->clear(); 3037 for (const Edge* e : n->in_edges()) { 3038 if (e->IsControlEdge()) { 3039 control_edges->push_back(e->src()); 3040 } else { 3041 (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output()); 3042 } 3043 } 3044 std::sort(control_edges->begin(), control_edges->end()); 3045 if (n->op_def().is_commutative()) { 3046 // For commutative inputs, we sort the input by the input Node* 3047 // to get a canonical ordering (so that add(a,b) and add(b, a) will 3048 // hash to the same value if is_commutative is true for 'add'). 3049 std::sort(in->begin(), in->end()); 3050 } 3051 } 3052 3053 void MklLayoutRewritePass::GetNodesProducingTFTensorList( 3054 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, 3055 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { 3056 CHECK_LT(*input_idx, inputs.size()); 3057 CHECK_GT(list_length, 0); 3058 CHECK_NOTNULL(output_nodes); 3059 output_nodes->reserve(list_length); 3060 3061 while (list_length != 0) { 3062 CHECK_GT(list_length, 0); 3063 CHECK_LT(*input_idx, inputs.size()); 3064 Node* n = inputs[*input_idx].first; 3065 int slot = inputs[*input_idx].second; 3066 // If input node 'n' is just producing a single tensor at 3067 // output slot 'slot' then we just add that single node. 3068 output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); 3069 (*input_idx)++; 3070 list_length--; 3071 } 3072 } 3073 3074 // TODO(nhasabni) We should move this to mkl_util.h. 3075 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, 3076 Node** out, Node* orig_node) { 3077 // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent 3078 // dummy Mkl tensor. 8 = 2*size_t. 3079 const DataType dt = DataTypeToEnum<uint8>::v(); 3080 TensorProto proto; 3081 proto.set_dtype(dt); 3082 uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; 3083 proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)), 3084 8); 3085 TensorShape dummy_shape({8}); 3086 dummy_shape.AsProto(proto.mutable_tensor_shape()); 3087 TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") 3088 .Attr("value", proto) 3089 .Attr("dtype", dt) 3090 .Device(orig_node->def().device()) // We place this node on 3091 // the same device as the 3092 // device of the original 3093 // node. 3094 .Finalize(&**g, out)); 3095 3096 // If number of inputs to the original node is > 0, then we add 3097 // control dependency between 1st input (index 0) of the original node and 3098 // the dummy Mkl node. This is needed because control-flow ops such as Enter, 3099 // Merge, etc, require frame_name of the dummy Mkl node to be same as the 3100 // rewritten node. Adding control edge between 1st input of the original node 3101 // and the dummy Mkl node ensures that the dummy node is in the same frame 3102 // as the original node. 
Choosing 1st input is not necessary - any input of 3103 // the original node is fine because all the inputs of a node are always in 3104 // the same frame. 3105 if (orig_node->num_inputs() > 0) { 3106 Node* orig_input0 = nullptr; 3107 TF_CHECK_OK( 3108 orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); 3109 // Allow duplicate while adding control edge as it would fail (return 3110 // NULL) if we try to add duplicate edge. 3111 CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); 3112 } 3113 3114 (*out)->set_assigned_device_name(orig_node->assigned_device_name()); 3115 } 3116 3117 void MklLayoutRewritePass::GetNodesProducingMklTensorList( 3118 std::unique_ptr<Graph>* g, Node* orig_node, 3119 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, 3120 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { 3121 CHECK_LT(*input_idx, inputs.size()); 3122 CHECK_GT(list_length, 0); 3123 CHECK_NOTNULL(output_nodes); 3124 output_nodes->reserve(list_length); 3125 3126 while (list_length != 0) { 3127 CHECK_GT(list_length, 0); 3128 CHECK_LT(*input_idx, inputs.size()); 3129 Node* n = inputs[*input_idx].first; 3130 int slot = inputs[*input_idx].second; 3131 // If 'n' is producing a single tensor, then create a single Mkl tensor 3132 // node. 3133 Node* mkl_node = nullptr; 3134 int mkl_node_output_slot = 0; 3135 GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, 3136 &mkl_node_output_slot); 3137 output_nodes->push_back( 3138 NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); 3139 (*input_idx)++; 3140 list_length--; 3141 } 3142 } 3143 3144 // Get an input node that will feed Mkl tensor to the new 3145 // node that we are constructing. An input node could be (1) 'n' 3146 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor 3147 // if 'n' is not an Mkl layer. 3148 void MklLayoutRewritePass::GetNodeProducingMklTensor( 3149 std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot, 3150 Node** mkl_node, int* mkl_node_output_slot) { 3151 CHECK_NOTNULL(n); 3152 CHECK_NOTNULL(mkl_node); 3153 CHECK_NOTNULL(mkl_node_output_slot); 3154 3155 // If this is an MKL op, then it will create extra output for MKL layout. 3156 DataType T; 3157 if (GetNodeAttr(n->def(), "T", &T).ok() && 3158 mkl_op_registry::IsMklOp(n->type_string(), T)) { 3159 // If this is an MKL op, then it will generate an edge that will receive 3160 // Mkl tensor from a node. 3161 // output slot number for Mkl tensor would be N+slot number of TensorFlow 3162 // tensor, where N is total number of TensorFlow tensors. 3163 *mkl_node = n; 3164 *mkl_node_output_slot = 3165 GetTensorMetaDataIndex(n_output_slot, n->num_outputs()); 3166 } else { 3167 // If we have not visited the node and rewritten it, then we need 3168 // to create a dummy node that will feed a dummy Mkl tensor to this node. 3169 // DummyMklTensor node has no input and generates only 1 output 3170 // (dummy Mkl tensor) as output slot number 0. 
3171 GetDummyMklTensorNode(g, mkl_node, orig_node); 3172 CHECK_NOTNULL(*mkl_node); 3173 *mkl_node_output_slot = 0; 3174 } 3175 } 3176 3177 int MklLayoutRewritePass::SetUpContiguousInputs( 3178 std::unique_ptr<Graph>* g, 3179 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, 3180 NodeBuilder* nb, Node* old_node, 3181 std::vector<NodeBuilder::NodeOut>* workspace_tensors, 3182 bool are_workspace_tensors_available) { 3183 CHECK_NOTNULL(workspace_tensors); 3184 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); 3185 3186 // TODO(nhasabni): Temporary solution to connect filter input of 3187 // BackpropInput with the converted filter from Conv2D. 3188 bool do_connect_conv2d_backprop_input_filter = false; 3189 Node* conv2d_node = nullptr; 3190 // Filter node is 2nd input (slot index 1) of Conv2D. 3191 int kConv2DFilterInputSlotIdx = 1; 3192 int kConv2DBackpropInputFilterInputSlotIdx = 1; 3193 int kConv2DFilterOutputSlotIdx = 1; 3194 if (old_node->type_string() == csinfo_.conv2d_grad_input) { 3195 // We need to find Conv2D node from Conv2DBackpropInput. 3196 // For that let's first find filter node that is 2nd input (slot 1) 3197 // of BackpropInput. 3198 Node* filter_node = nullptr; 3199 old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, &filter_node); 3200 CHECK_NOTNULL(filter_node); 3201 3202 // Now check which nodes receive from filter_node. Filter feeds as 3203 // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias. 3204 for (const Edge* e : filter_node->out_edges()) { 3205 if ((e->dst()->type_string() == csinfo_.mkl_conv2d || 3206 e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias) && 3207 e->dst_input() == kConv2DFilterInputSlotIdx 3208 /* filter is 2nd input of Conv2D and _MklConv2D. */) { 3209 if (conv2d_node != nullptr) { 3210 VLOG(1) << "MklLayoutRewritePass: unusual case of same filter" 3211 << " feeding multiple Conv2D nodes: " 3212 << filter_node->DebugString(); 3213 // We will not connect filter input of Conv2DBackpropInput 3214 // to be safe here. 3215 do_connect_conv2d_backprop_input_filter = false; 3216 break; 3217 } else { 3218 conv2d_node = e->dst(); 3219 do_connect_conv2d_backprop_input_filter = true; 3220 } 3221 } 3222 } 3223 } 3224 3225 // Number of input slots to original op 3226 // Input slots are represented by .Input() calls in REGISTER_OP. 3227 int old_node_input_slots = old_node->op_def().input_arg_size(); 3228 // Actual number of inputs can be greater than or equal to number 3229 // of Input slots because inputs of type list could be unfolded. 3230 CHECK_GE(old_node_inputs.size(), old_node_input_slots); 3231 int nn_slot_idx = 0; // slot index for inputs of new node 3232 3233 // Let's copy all inputs (TF tensors) of original node to new node. 3234 int iidx = 0; 3235 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { 3236 // An input slot could be a single tensor or a list. We need 3237 // to handle this case accordingly. 
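// (For example, AddN declares a single list input "inputs: N * T"; it shows
// up here as one input slot whose length is read from the "N" attribute via
// GetTensorListLength.)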
3238 CHECK_LT(iidx, old_node_inputs.size()); 3239 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); 3240 if (ArgIsList(arg)) { 3241 std::vector<NodeBuilder::NodeOut> new_node_inputs; 3242 int N = GetTensorListLength(arg, old_node); 3243 GetNodesProducingTFTensorList(old_node_inputs, &iidx, N, 3244 &new_node_inputs); 3245 nb->Input(new_node_inputs); 3246 nn_slot_idx++; 3247 } else { 3248 // Special case for connecting filter input of Conv2DBackpropInput 3249 if (do_connect_conv2d_backprop_input_filter && 3250 iidx == kConv2DBackpropInputFilterInputSlotIdx) { 3251 nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx); 3252 } else { 3253 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); 3254 } 3255 iidx++; 3256 nn_slot_idx++; 3257 } 3258 } 3259 3260 // If workspace tensors are available for this op and we are using 3261 // contiguous ordering then we need to add Tensorflow tensor for 3262 // workspace here because Tensorflow tensor for workspace is the 3263 // last tensor in the list of Tensorflow tensors. 3264 if (are_workspace_tensors_available) { 3265 CHECK_EQ(workspace_tensors->size(), 2); 3266 // Tensorflow tensor 3267 nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index); 3268 nn_slot_idx++; 3269 } 3270 3271 // Let's now setup all Mkl inputs to a new node. 3272 // Number of Mkl inputs must be same as number of TF inputs. 3273 iidx = 0; 3274 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { 3275 // An input slot could be a single tensor or a list. We need 3276 // to handle this case accordingly. 3277 CHECK_LT(iidx, old_node_inputs.size()); 3278 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); 3279 if (ArgIsList(arg)) { 3280 std::vector<NodeBuilder::NodeOut> new_node_inputs; 3281 int N = GetTensorListLength(arg, old_node); 3282 GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, 3283 &new_node_inputs); 3284 nb->Input(new_node_inputs); 3285 nn_slot_idx++; 3286 } else { 3287 Node* mkl_node = nullptr; 3288 int mkl_node_output_slot = 0; 3289 // Special case for connecting filter input of Conv2DBackpropInput 3290 if (do_connect_conv2d_backprop_input_filter && 3291 iidx == kConv2DBackpropInputFilterInputSlotIdx) { 3292 GetNodeProducingMklTensor(g, old_node, conv2d_node, 3293 kConv2DFilterOutputSlotIdx, &mkl_node, 3294 &mkl_node_output_slot); 3295 } else { 3296 GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, 3297 old_node_inputs[iidx].second, &mkl_node, 3298 &mkl_node_output_slot); 3299 } 3300 nb->Input(mkl_node, mkl_node_output_slot); 3301 iidx++; 3302 nn_slot_idx++; 3303 } 3304 } 3305 3306 // If workspace tensors are available for this op and we are using 3307 // contiguous ordering then we need to add Mkl tensor for 3308 // workspace here because Mkl tensor for workspace is the 3309 // last tensor in the list of Mkl tensors. 3310 if (are_workspace_tensors_available) { 3311 CHECK_EQ(workspace_tensors->size(), 2); 3312 // Mkl tensor 3313 nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index); 3314 nn_slot_idx++; 3315 } 3316 3317 return nn_slot_idx; 3318 } 3319 3320 Status MklLayoutRewritePass::SetUpInputs( 3321 std::unique_ptr<Graph>* g, 3322 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, 3323 NodeBuilder* nb, Node* old_node) { 3324 // Let's check if we need to add workspace tensors for this node. 3325 // We add workspace edge only for MaxPool, LRN and BatchNorm. 
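// (wsinfo_ records, for each such forward/backward op pair, which forward
// output slot and backward input slot the workspace travels between; see
// AddWorkSpaceEdgeIfNeeded below for how those slots are used.)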
3326   std::vector<NodeBuilder::NodeOut> workspace_tensors;
3327   bool are_workspace_tensors_available = false;
3328   AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
3329                            &are_workspace_tensors_available);
3330
3331   int new_node_input_slots = 0;
3332   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
3333     // TODO(nhasabni): implement this function just for the sake of completeness.
3334     // We do not use interleaved ordering right now.
3335     return Status(
3336         error::Code::UNIMPLEMENTED,
3337         "Interleaved ordering of tensors is currently not supported.");
3338   } else {
3339     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
3340     new_node_input_slots = SetUpContiguousInputs(
3341         g, old_node_inputs, nb, old_node, &workspace_tensors,
3342         are_workspace_tensors_available);
3343   }
3344
3345   // Sanity check
3346   int old_node_input_slots = old_node->op_def().input_arg_size();
3347   if (!are_workspace_tensors_available) {
3348     // If we are not adding workspace tensors for this op, then the total
3349     // number of input slots to the new node _must_ be 2 times the number
3350     // of input slots to the original node: N original Tensorflow tensors and
3351     // N Mkl tensors corresponding to each Tensorflow tensor.
3352     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
3353   } else {
3354     // If we are adding workspace tensors for this op, then the total
3355     // number of input slots to the new node _must_ be 2 times the number
3356     // of input slots to the original node: N original Tensorflow tensors and
3357     // N Mkl tensors corresponding to each Tensorflow tensor, plus 2
3358     // (for the workspace Tensorflow tensor and the workspace Mkl tensor).
3359     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
3360   }
3361
3362   return Status::OK();
3363 }
3364
3365 //////////////////////////////////////////////////////////////////////////
3366 // Helper functions related to workspace pass
3367 //////////////////////////////////////////////////////////////////////////
3368
3369 // TODO(nhasabni) We should move this to mkl_util.h.
3370 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
3371     std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
3372   // We use a tensor of shape {1} and value 0 to represent
3373   // dummy float tensor. We need this as a dummy workspace tensor.
3374   // Workspace tensor has type uint8.
3375   const DataType dt = DataTypeToEnum<uint8>::v();
3376   TensorProto proto;
3377   proto.set_dtype(dt);
3378   float zero[1] = {0};
3379   proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
3380                            4);
3381   TensorShape dummy_shape({1});
3382   dummy_shape.AsProto(proto.mutable_tensor_shape());
3383   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
3384                   .Attr("value", proto)
3385                   .Attr("dtype", dt)
3386                   .Device(orig_node->def().device())  // We place this node on
3387                                                       // the same device as the
3388                                                       // device of the original
3389                                                       // node.
3390                   .Finalize(&**g, out));
3391
3392   // If number of inputs to the original node is > 0, then we add
3393   // control dependency between 1st input (index 0) of the original node and
3394   // the dummy Mkl node. This is needed because control-flow ops such as Enter,
3395   // Merge, etc, require frame_name of the dummy Mkl node to be same as the
3396   // rewritten node. Adding control edge between 1st input of the original node
3397   // and the dummy Mkl node ensures that the dummy node is in the same frame
3398   // as the original node.
Choosing 1st input is not necessary - any input of 3399 // the original node is fine because all the inputs of a node are always in 3400 // the same frame. 3401 if (orig_node->num_inputs() > 0) { 3402 Node* orig_input0 = nullptr; 3403 TF_CHECK_OK( 3404 orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); 3405 // Allow duplicate while adding control edge as it would fail (return 3406 // NULL) if we try to add duplicate edge. 3407 CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); 3408 } 3409 3410 (*out)->set_assigned_device_name(orig_node->assigned_device_name()); 3411 } 3412 3413 void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( 3414 std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb, 3415 std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) { 3416 bool workspace_edge_added = false; // Default initializer 3417 CHECK_NOTNULL(are_ws_tensors_added); 3418 *are_ws_tensors_added = false; // Default initializer 3419 3420 DataType T; 3421 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3422 for (auto ws : wsinfo_) { 3423 if (orig_node->type_string() == ws.fwd_op && 3424 mkl_op_registry::IsMklOp( 3425 mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { 3426 // If this op is a fwd op, then we need to check if there is an 3427 // edge from this node's fwd_slot to bwdop's bwd_slot. If there is 3428 // an edge, then we just add an attribute on this node for setting 3429 // workspace_passed to true. We don't add actual workspace edge 3430 // in this node. Actual workspace edge gets added in the backward 3431 // op for this node. 3432 for (const Edge* e : orig_node->out_edges()) { 3433 if (e->src_output() == ws.fwd_slot && 3434 e->dst()->type_string() == ws.bwd_op && 3435 e->dst_input() == ws.bwd_slot) { 3436 nb->Attr("workspace_enabled", true); 3437 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " 3438 << orig_node->type_string(); 3439 workspace_edge_added = true; 3440 // We found the edge that we were looking for, so break. 3441 break; 3442 } 3443 } 3444 3445 if (!workspace_edge_added) { 3446 // If we are here, then we did not find backward operator for this 3447 // node. 3448 nb->Attr("workspace_enabled", false); 3449 } 3450 } else if (orig_node->type_string() == ws.bwd_op && 3451 mkl_op_registry::IsMklOp( 3452 mkl_op_registry::GetMklOpName(orig_node->type_string()), 3453 T)) { 3454 // If this op is a bwd op, then we need to add workspace edge and 3455 // it's Mkl tensor edge between its corresponding fwd op and this 3456 // op. Corresponding fwd op is specified in 'fwd_op' field of 3457 // workspace info. fwd_slot and bwd_slot in workspace info specify 3458 // an edge between which slots connect forward and backward op. 3459 // Once all these criteria match, we add a workspace edge between 3460 // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is 3461 // determined by interleaved/contiguous ordering. Function 3462 // DataIndexToMetaDataIndex tells us the location of Mkl tensor 3463 // from the location of the Tensorflow tensor. 3464 for (const Edge* e : orig_node->in_edges()) { 3465 if (e->src_output() == ws.fwd_slot && 3466 // We would have rewritten the forward op, so we need to use 3467 // GetMklOpName call to get its Mkl name. 3468 e->src()->type_string() == 3469 mkl_op_registry::GetMklOpName(ws.fwd_op) && 3470 e->dst_input() == ws.bwd_slot) { 3471 nb->Attr("workspace_enabled", true); 3472 CHECK_NOTNULL(ws_tensors); 3473 // Add workspace edge between fwd op and bwd op. 
3474 ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot)); 3475 // Add Mkl tensor edge for workspace edge between fwd op and bwd op. 3476 ws_tensors->push_back(NodeBuilder::NodeOut( 3477 e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot, 3478 e->src()->num_outputs()))); 3479 *are_ws_tensors_added = true; 3480 // In terms of input ordering, we add these calls to add Input 3481 // here because workspace edge (and its Mkl tensor) is the last 3482 // edge in the fwdop and bwdop. So all inputs before workspace 3483 // tensor have been added by SetUpInputs function. 3484 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " 3485 << orig_node->type_string(); 3486 workspace_edge_added = true; 3487 // We found the edge that we were looking for, so break. 3488 break; 3489 } 3490 } 3491 3492 // If we are here means we did not find fwd op that feeds to this 3493 // bwd op. So in this case, we need to generate dummy tensors for 3494 // workspace input and Mkl tensor for workspace, and set 3495 // workspace_enabled to false. 3496 if (!workspace_edge_added) { 3497 nb->Attr("workspace_enabled", false); 3498 Node* dmt_ws = nullptr; // Dummy tensor for workspace 3499 Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace 3500 GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node); 3501 GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node); 3502 CHECK_NOTNULL(dmt_ws); 3503 CHECK_NOTNULL(dmt_mkl_ws); 3504 CHECK_NOTNULL(ws_tensors); 3505 // We add dummy tensor as workspace tensor. 3506 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0)); 3507 // We add dummy tensor as Mkl tensor for workspace tensor. 3508 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0)); 3509 *are_ws_tensors_added = true; 3510 VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for " 3511 << orig_node->type_string(); 3512 } 3513 } else { 3514 // If this node does not match any workspace info, then we do not 3515 // do anything special for workspace propagation for it. 3516 } 3517 } 3518 } 3519 3520 ////////////////////////////////////////////////////////////////////////// 3521 // Op-specific functions to copy attributes from old node to new node 3522 ////////////////////////////////////////////////////////////////////////// 3523 3524 void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, 3525 NodeBuilder* nb) { 3526 DataType T; 3527 string data_format; 3528 string padding; 3529 std::vector<int32> strides; 3530 bool use_cudnn_on_gpu; 3531 3532 // Get all attributes from old node. 3533 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3534 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 3535 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 3536 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 3537 TF_CHECK_OK( 3538 GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); 3539 3540 // Add attributes to new node. 3541 nb->Attr("T", T); 3542 nb->Attr("strides", strides); 3543 nb->Attr("padding", padding); 3544 nb->Attr("data_format", data_format); 3545 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); 3546 } 3547 3548 void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, 3549 NodeBuilder* nb) { 3550 DataType T; 3551 int N; 3552 3553 // Get all attributes from old node. 3554 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3555 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); 3556 3557 // Add attributes to new node. 
3558 nb->Attr("T", T); 3559 nb->Attr("N", N); 3560 } 3561 3562 void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node, 3563 NodeBuilder* nb) { 3564 DataType T; 3565 string data_format; 3566 std::vector<int32> strides; 3567 3568 // Get all attributes from old node. 3569 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3570 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 3571 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 3572 3573 // Add attributes to new node. 3574 nb->Attr("T", T); 3575 nb->Attr("strides", strides); 3576 nb->Attr("data_format", data_format); 3577 } 3578 3579 void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, 3580 NodeBuilder* nb) { 3581 DataType T; 3582 int depth_radius; 3583 float bias; 3584 float alpha; 3585 float beta; 3586 3587 // Get all attributes from old node. 3588 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3589 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius)); 3590 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias)); 3591 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha)); 3592 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta)); 3593 3594 // Add attributes to new node. 3595 nb->Attr("T", T); 3596 nb->Attr("depth_radius", depth_radius); 3597 nb->Attr("bias", bias); 3598 nb->Attr("alpha", alpha); 3599 nb->Attr("beta", beta); 3600 } 3601 3602 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node, 3603 NodeBuilder* nb) { 3604 DataType T; 3605 string data_format; 3606 string padding; 3607 std::vector<int32> ksize, strides; 3608 3609 // Get all attributes from old node. 3610 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3611 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize)); 3612 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); 3613 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); 3614 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 3615 3616 // Add attributes to new node. 3617 nb->Attr("T", T); 3618 nb->Attr("ksize", ksize); 3619 nb->Attr("strides", strides); 3620 nb->Attr("padding", padding); 3621 nb->Attr("data_format", data_format); 3622 } 3623 3624 void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, 3625 NodeBuilder* nb) { 3626 DataType T; 3627 3628 // Get all attributes from old node. 3629 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3630 3631 // Add attributes to new node. 3632 nb->Attr("T", T); 3633 } 3634 3635 void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, 3636 NodeBuilder* nb) { 3637 DataType T; 3638 DataType Tshape; 3639 3640 // Get all attributes from old node. 3641 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3642 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); 3643 // Add attributes to new node. 3644 nb->Attr("T", T); 3645 nb->Attr("Tshape", Tshape); 3646 } 3647 3648 void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, 3649 NodeBuilder* nb) { 3650 DataType T; 3651 string data_format; 3652 int num_split; 3653 3654 // Get all attributes from old node. 3655 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3656 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split)); 3657 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 3658 3659 // Add attributes to new node. 
3660 nb->Attr("T", T); 3661 nb->Attr("num_split", num_split); 3662 nb->Attr("data_format", data_format); 3663 } 3664 3665 void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node, 3666 NodeBuilder* nb) { 3667 DataType T; 3668 int N; 3669 3670 // Get all attributes from old node. 3671 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3672 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); 3673 3674 // Add attributes to new node. 3675 nb->Attr("T", T); 3676 nb->Attr("N", N); 3677 } 3678 3679 void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node, 3680 NodeBuilder* nb) { 3681 DataType T; 3682 int N; 3683 DataType tidx; 3684 3685 // Get all attributes from old node. 3686 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3687 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); 3688 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx)); 3689 3690 // Add attributes to new node. 3691 nb->Attr("T", T); 3692 nb->Attr("N", N); 3693 nb->Attr("Tidx", tidx); 3694 } 3695 3696 void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node, 3697 NodeBuilder* nb) { 3698 DataType T; 3699 float epsilon; 3700 string data_format; 3701 bool is_training; 3702 3703 // Get all attributes from old node. 3704 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); 3705 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon)); 3706 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); 3707 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training)); 3708 3709 // Add attributes to new node. 3710 nb->Attr("T", T); 3711 nb->Attr("epsilon", epsilon); 3712 nb->Attr("data_format", data_format); 3713 nb->Attr("is_training", is_training); 3714 } 3715 3716 ////////////////////////////////////////////////////////////////////////// 3717 // Helper functions related to node merge pass 3718 ////////////////////////////////////////////////////////////////////////// 3719 3720 Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { 3721 // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite 3722 // once we support BiasAddGrad as Mkl layer. 3723 3724 // Search for all matching mergeinfo. 3725 // We allow more than one match for extensibility. 3726 std::vector<const MergeInfo*> matching_mi; 3727 for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) { 3728 if (a->type_string() == mi->op1 || a->type_string() == mi->op2) { 3729 matching_mi.push_back(&*mi); 3730 } 3731 } 3732 3733 for (const MergeInfo* mi : matching_mi) { 3734 // Get the operand with which 'a' can be merged. 3735 Node* b = nullptr; 3736 if ((b = mi->get_node_to_be_merged(a)) == nullptr) { 3737 continue; 3738 } 3739 3740 // Get the control edges and input of node 3741 const int N_in = a->num_inputs(); 3742 gtl::InlinedVector<Node*, 4> a_control_edges; 3743 gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in); 3744 FillInputs(a, &a_control_edges, &a_in); 3745 3746 const int B_in = b->num_inputs(); 3747 gtl::InlinedVector<Node*, 4> b_control_edges; 3748 gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in); 3749 FillInputs(b, &b_control_edges, &b_in); 3750 3751 // Shouldn't merge if a and b have different control edges. 3752 if (a_control_edges != b_control_edges) { 3753 continue; 3754 } else { 3755 // We found a match. 
3756 return b; 3757 } 3758 } 3759 3760 return nullptr; 3761 } 3762 3763 Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, 3764 Node* m, Node* n) { 3765 CHECK_EQ(((m->type_string() == csinfo_.bias_add && 3766 n->type_string() == csinfo_.conv2d)) || 3767 ((n->type_string() == csinfo_.bias_add && 3768 m->type_string() == csinfo_.conv2d)), 3769 true); 3770 3771 // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd, 3772 // BiasAdd is successor node, and Conv2D predecessor node. 3773 Node* pred = m->type_string() == csinfo_.bias_add ? n : m; 3774 Node* succ = m->type_string() == csinfo_.bias_add ? m : n; 3775 3776 // 1. Get all attributes from input nodes. 3777 DataType T_pred, T_succ; 3778 string padding; 3779 std::vector<int32> strides; 3780 string data_format_pred, data_format_succ; 3781 bool use_cudnn_on_gnu; 3782 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred)); 3783 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ)); 3784 TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding)); 3785 TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides)); 3786 TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred)); 3787 TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ)); 3788 TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu)); 3789 // We check to ensure that data formats of both succ and pred are same. 3790 // We expect them to be same, so we can enforce this as assert. 3791 // But assert can be too strict, so we enforce this as a check. 3792 // If the check fails, then we do not merge two nodes. 3793 // We also do same check for devices. 3794 if (data_format_pred != data_format_succ || T_pred != T_succ || 3795 pred->assigned_device_name() != succ->assigned_device_name() || 3796 pred->def().device() != succ->def().device()) { 3797 return Status(error::Code::INVALID_ARGUMENT, 3798 "data_format or T attribute or devices of Conv2D and " 3799 "BiasAdd do not match. Will skip node merge optimization"); 3800 } 3801 3802 const int succ_num = succ->num_inputs(); 3803 gtl::InlinedVector<Node*, 4> succ_control_edges; 3804 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num); 3805 FillInputs(succ, &succ_control_edges, &succ_in); 3806 3807 const int pred_num = pred->num_inputs(); 3808 gtl::InlinedVector<Node*, 4> pred_control_edges; 3809 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num); 3810 FillInputs(pred, &pred_control_edges, &pred_in); 3811 3812 // We need to ensure that Conv2D only feeds to BiasAdd (some other operator is 3813 // not expecting output of Conv2D). If this is not the case, then we cannot 3814 // merge Conv2D with BiasAdd. 3815 const int kFirstOutputSlot = 0; 3816 for (const Edge* e : pred->out_edges()) { 3817 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) { 3818 return Status(error::Code::INVALID_ARGUMENT, 3819 "Conv2D does not feed to BiasAdd, or " 3820 "it feeds BiasAdd but has multiple outputs. " 3821 "Will skip node merge optimization"); 3822 } 3823 } 3824 3825 // 2. Get inputs from both the nodes. 3826 // Find the 2 inputs from the conv and the bias from the add Bias. 3827 // Get operand 0, 1 of conv2D. 3828 CHECK_EQ(pred->in_edges().size(), 2); // Conv2D must have 2 inputs. 3829 // Get operand 1 of add_bias 3830 // BiasAdd must have 2 inputs: Conv, bias 3831 CHECK_EQ(succ->in_edges().size(), 2); 3832 3833 // We will use the node name of BiasAdd as the name of new node 3834 // Build new node. 
We use same name as original node, but change the op 3835 // name. 3836 NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias); 3837 nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D 3838 // pred_in[1] will be 2nd Tensorflow tensor for Conv2D. 3839 nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D 3840 // In1 of BiasAdd is same as output of Conv2D. 3841 nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd 3842 3843 // Copy attributes from Conv2D to Conv2DWithBias. 3844 CopyAttrsConv2D(const_cast<const Node*>(pred), &nb); 3845 3846 // Copy the device assigned to old node to new node. 3847 nb.Device(succ->def().device()); 3848 3849 // Create node. 3850 Node* new_node; 3851 nb.Finalize(&**g, &new_node); 3852 CHECK_NOTNULL(new_node); 3853 3854 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' 3855 // node are already copied in BuildNode. We handle control edges now. 3856 for (const Edge* e : pred->in_edges()) { 3857 if (e->IsControlEdge()) { 3858 // Allow duplicate while adding control edge as it would fail (return 3859 // NULL) if we try to add duplicate edge. 3860 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); 3861 } 3862 } 3863 for (const Edge* e : succ->in_edges()) { 3864 if (e->IsControlEdge()) { 3865 // Allow duplicate while adding control edge as it would fail (return 3866 // NULL) if we try to add duplicate edge. 3867 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); 3868 } 3869 } 3870 3871 // Incoming edges are fixed, we will fix the outgoing edges now. 3872 // First, we will fix outgoing control edges from 'pred' node. 3873 for (const Edge* e : pred->out_edges()) { 3874 if (e->IsControlEdge()) { 3875 // Allow duplicate while adding control edge as it would fail (return 3876 // NULL) if we try to add duplicate edge. 3877 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); 3878 } 3879 } 3880 3881 // Second, we will fix outgoing control and data edges from 'succ' node. 3882 for (const Edge* e : succ->out_edges()) { 3883 if (e->IsControlEdge()) { 3884 // Allow duplicate while adding control edge as it would fail (return 3885 // NULL) if we try to add duplicate edge. 3886 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); 3887 } else { 3888 // BiasAdd has only 1 output (at slot 0) and merged node also has only 1 3889 // output (at slot 0). 3890 const int kConv2DWithBiasOutputSlot = 0; 3891 CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(), 3892 e->dst_input())); 3893 } 3894 } 3895 3896 // Copy device assigned to old node to new node. 3897 // It's ok to use pred or succ as we have enforced a check that 3898 // both have same device assigned. 3899 new_node->set_assigned_device_name(pred->assigned_device_name()); 3900 3901 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString() 3902 << ", and node: " << succ->DebugString() 3903 << ", into node:" << new_node->DebugString(); 3904 3905 (*g)->RemoveNode(succ); 3906 (*g)->RemoveNode(pred); 3907 3908 return Status::OK(); 3909 } 3910 3911 Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( 3912 std::unique_ptr<Graph>* g, Node* m, Node* n) { 3913 CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad && 3914 n->type_string() == csinfo_.conv2d_grad_filter)) || 3915 ((n->type_string() == csinfo_.bias_add_grad && 3916 m->type_string() == csinfo_.conv2d_grad_filter)), 3917 true); 3918 3919 // If 'm' is BiasAddGrad, then 'n' is BackpropFilter. 
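// (Either order is possible because the caller passes the pair in whatever
// order the merge candidate was found; the two lines below normalize the
// pair into 'badd' and 'fltr'.)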
3920 Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n; 3921 Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m; 3922 3923 // Sanity check for attributes from input nodes. 3924 DataType T_b, T_f; 3925 string data_format_b, data_format_f; 3926 TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b)); 3927 TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f)); 3928 TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b)); 3929 TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f)); 3930 if (data_format_b != data_format_f || T_b != T_f || 3931 badd->assigned_device_name() != fltr->assigned_device_name() || 3932 badd->def().device() != fltr->def().device()) { 3933 return Status(error::Code::INVALID_ARGUMENT, 3934 "data_format or T attribute or devices of " 3935 "Conv2DBackpropFilter and BiasAddGrad do not match. " 3936 "Will skip node merge optimization"); 3937 } 3938 3939 // We will use the node name of Conv2DBackpropFilter as the name of new node. 3940 // This is because BackpropFilterWithBias is going to emit bias output also. 3941 NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias); 3942 // Since Conv2DBackpropFilterWithBias has same number of inputs as 3943 // Conv2DBackpropFilter, we can just copy input edges directly. We dont need 3944 // to copy any data input of BiasAddGrad because that input also goes to 3945 // Conv2DBackpropFilter. 3946 const int fltr_ins = fltr->num_inputs(); 3947 gtl::InlinedVector<Node*, 4> fltr_control_edges; 3948 gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins); 3949 FillInputs(fltr, &fltr_control_edges, &fltr_in_edges); 3950 for (int idx = 0; idx < fltr_ins; idx++) { 3951 nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second); 3952 } 3953 3954 // Copy attributes from Conv2DBackpropFilter. 3955 CopyAttrsConv2D(const_cast<const Node*>(fltr), &nb); 3956 3957 // Copy the device assigned to old node to new node. 3958 nb.Device(fltr->def().device()); 3959 3960 // Create node. 3961 Node* new_node; 3962 nb.Finalize(&**g, &new_node); 3963 CHECK_NOTNULL(new_node); 3964 3965 // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to 3966 // new 'new_node' node are already copied in BuildNode. We handle control 3967 // edges now. 3968 for (const Edge* e : badd->in_edges()) { 3969 if (e->IsControlEdge()) { 3970 // Allow duplicate while adding control edge as it would fail (return 3971 // NULL) if we try to add duplicate edge. 3972 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); 3973 } 3974 } 3975 for (const Edge* e : fltr->in_edges()) { 3976 if (e->IsControlEdge()) { 3977 // Allow duplicate while adding control edge as it would fail (return 3978 // NULL) if we try to add duplicate edge. 3979 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); 3980 } 3981 } 3982 3983 // Incoming edges are fixed, we will fix the outgoing edges now. 3984 // First, we will fix outgoing control edges from 'badd' node. 3985 // Conv2DBackpropFilter has 1 output -- filter_grad. 3986 // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and 3987 // bias_grad. But filter_grad is at same slot number (0) in both the 3988 // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while 3989 // it is at slot number 0 in BiasAddGrad. 
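// So, when rewiring the outgoing data edges below, edges that left
// BiasAddGrad:0 are reattached to new_node:1 (bias_grad), and edges that left
// Conv2DBackpropFilter:0 are reattached to new_node:0 (filter_grad).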
3990 const int kMergedNodeFilterGradOutputIdx = 0; 3991 const int kMergedNodeBiasGradOutputIdx = 1; 3992 3993 for (const Edge* e : badd->out_edges()) { 3994 if (e->IsControlEdge()) { 3995 // Allow duplicate while adding control edge as it would fail (return 3996 // NULL) if we try to add duplicate edge. 3997 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); 3998 } else { 3999 CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx, 4000 e->dst(), e->dst_input())); 4001 } 4002 } 4003 4004 // Second, we will fix outgoing control and data edges from 'fltr' node. 4005 for (const Edge* e : fltr->out_edges()) { 4006 if (e->IsControlEdge()) { 4007 // We allow duplicate edge for this case since we already add control 4008 // edge from new_node in line 3990. Line below could be adding same 4009 // edge to same destination again. In such case, if we do not allow 4010 // duplicate edge, then this call will fail. 4011 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); 4012 } else { 4013 CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx, 4014 e->dst(), e->dst_input())); 4015 } 4016 } 4017 4018 // Copy device assigned to old node to new node. 4019 // It's ok to use badd or fltr as we have enforced a check that 4020 // both have same device assigned. 4021 new_node->set_assigned_device_name(badd->assigned_device_name()); 4022 4023 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString() 4024 << ", and node: " << fltr->DebugString() 4025 << ", into node:" << new_node->DebugString(); 4026 4027 (*g)->RemoveNode(badd); 4028 (*g)->RemoveNode(fltr); 4029 4030 return Status::OK(); 4031 } 4032 4033 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m, 4034 Node* n) { 4035 CHECK_NOTNULL(m); 4036 CHECK_NOTNULL(n); 4037 4038 if (((m->type_string() == csinfo_.bias_add && 4039 n->type_string() == csinfo_.conv2d)) || 4040 ((n->type_string() == csinfo_.bias_add && 4041 m->type_string() == csinfo_.conv2d))) { 4042 return this->MergeConv2DWithBiasAdd(g, m, n); 4043 } 4044 4045 if (((m->type_string() == csinfo_.bias_add_grad && 4046 n->type_string() == csinfo_.conv2d_grad_filter)) || 4047 ((n->type_string() == csinfo_.bias_add_grad && 4048 m->type_string() == csinfo_.conv2d_grad_filter))) { 4049 return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n); 4050 } 4051 4052 return Status(error::Code::UNIMPLEMENTED, 4053 "Unimplemented case for node merge optimization."); 4054 } 4055 4056 ////////////////////////////////////////////////////////////////////////// 4057 // Helper functions for node rewrite 4058 ////////////////////////////////////////////////////////////////////////// 4059 4060 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, 4061 Node* orig_node, 4062 const RewriteInfo* ri) { 4063 CHECK_NOTNULL(ri); 4064 CHECK_NOTNULL(orig_node); 4065 4066 VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString(); 4067 4068 // Get all inputs. 4069 int num_inputs = orig_node->in_edges().size(); 4070 4071 // Drop count for control edges from inputs 4072 for (const Edge* e : orig_node->in_edges()) { 4073 if (e->IsControlEdge()) { 4074 num_inputs--; 4075 } 4076 } 4077 4078 gtl::InlinedVector<Node*, 4> control_edges; 4079 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs); 4080 FillInputs(orig_node, &control_edges, &inputs); 4081 4082 // Build new node. We use same name as original node, but change the op name. 
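// For example, a "Conv2D" node named "conv1" keeps the name "conv1" but gets
// the op type recorded in ri->new_name (typically the "_Mkl"-prefixed name,
// e.g. "_MklConv2D").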
4083 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str()); 4084 // Copy user-specified device assigned to original node to new node. 4085 nb.Device(orig_node->def().device()); 4086 // Set up new inputs to the rewritten node. 4087 Status s = SetUpInputs(g, inputs, &nb, orig_node); 4088 if (s != Status::OK()) { 4089 return s; 4090 } 4091 4092 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb); 4093 // Set the Mkl layer label for this op. 4094 nb.Attr("_kernel", mkl_op_registry::kMklOpLabel); 4095 4096 // Finalize graph and get new node. 4097 Node* new_node = nullptr; 4098 TF_CHECK_OK(nb.Finalize(&**g, &new_node)); 4099 CHECK_NOTNULL(new_node); 4100 4101 // Incoming data edges from 'orig_node' node to new 'new_node' node are 4102 // already copied in BuildNode. We need to handle control edges now. 4103 for (const Edge* e : orig_node->in_edges()) { 4104 if (e->IsControlEdge()) { 4105 // Allow duplicate while adding control edge as it would fail (return 4106 // NULL) if we try to add duplicate edge. 4107 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); 4108 } 4109 } 4110 4111 // Copy outgoing edges from 'orig_node' node to new 4112 // 'new_node' node, since the output also follows same ordering among 4113 // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow 4114 // tensors appropriately. Specifically, nth output of the original node 4115 // will become 2*nth output of the Mkl node for the interleaved ordering 4116 // of the tensors. For the contiguous ordering of the tensors, it will be n. 4117 // GetTensorDataIndex provides this mapping function. 4118 for (const Edge* e : orig_node->out_edges()) { 4119 if (e->IsControlEdge()) { 4120 // Allow duplicate while adding control edge as it would fail (return 4121 // NULL) if we try to add duplicate edge. 4122 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); 4123 } else { 4124 CHECK_NOTNULL((*g)->AddEdge( 4125 new_node, 4126 GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), 4127 e->dst(), e->dst_input())); 4128 } 4129 } 4130 4131 // Copy the runtime device assigned from original code to new node. 4132 new_node->set_assigned_device_name(orig_node->assigned_device_name()); 4133 4134 // Delete original node and mark new node as rewritten. 4135 (*g)->RemoveNode(orig_node); 4136 4137 VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString(); 4138 return Status::OK(); 4139 } 4140 4141 const MklLayoutRewritePass::RewriteInfo* 4142 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { 4143 CHECK_NOTNULL(n); 4144 4145 // First check if node along with its type is supported by MKL layer. 4146 // We do not want to rewrite an op into Mkl op if types are not supported. 4147 // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to 4148 // MklRelu if type is INT32. 4149 DataType T; 4150 if (!GetNodeAttr(n->def(), "T", &T).ok()) { 4151 return nullptr; 4152 } 4153 4154 // We make an exception for __MklDummyConv2DWithBias and 4155 // __MklConv2DBackpropFilterWithBias since their names do not match Mkl node 4156 // names. 4157 if (n->type_string() != csinfo_.conv2d_with_bias && 4158 n->type_string() != csinfo_.conv2d_grad_filter_with_bias && 4159 !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), 4160 T)) { 4161 return nullptr; 4162 } 4163 4164 // For elementwise node, we reuse the Eigen implementation and pass the MKL 4165 // metadata tensor through so we can avoid conversions. 
const MklLayoutRewritePass::RewriteInfo*
MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
  CHECK_NOTNULL(n);

  // First check if the node along with its type is supported by the MKL
  // layer. We do not want to rewrite an op into an Mkl op if its type is not
  // supported. E.g., MklRelu does not support INT32, so we cannot rewrite
  // Relu to MklRelu if the type is INT32.
  DataType T;
  if (!GetNodeAttr(n->def(), "T", &T).ok()) {
    return nullptr;
  }

  // We make an exception for __MklDummyConv2DWithBias and
  // __MklConv2DBackpropFilterWithBias since their names do not match Mkl node
  // names.
  if (n->type_string() != csinfo_.conv2d_with_bias &&
      n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
      !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
                                T)) {
    return nullptr;
  }

  // For an elementwise node, we reuse the Eigen implementation and pass the
  // MKL metadata tensor through, so we can avoid conversions. However, if all
  // incoming edges are in TF format, we do not need this overhead, so we
  // replace the elementwise node only if at least one of its parents is an
  // MKL node.
  //
  // Identity nodes can also skip replacement if they are not fed by any MKL
  // nodes.
  //
  // TODO(vrane): Add an implementation for element-wise ops that does not
  // reuse Eigen code, to reduce the cross-library dependency.
  VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string();
  if (mkl_op_registry::IsMklElementWiseOp(
          mkl_op_registry::GetMklOpName(n->type_string()), T) ||
      n->type_string().find("Identity") != string::npos) {
    VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string();
    bool incoming_mkl_edge = false;
    int num_parent = 0;
    for (auto parent : n->in_edges()) {
      if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) {
        VLOG(1) << "ELEMENTWISE: parent " << num_parent++
                << " is MKL op: " << parent->src()->type_string();
        incoming_mkl_edge = true;
        break;
      } else {
        VLOG(1) << "ELEMENTWISE: parent " << num_parent++
                << " is NON-MKL op: " << parent->src()->type_string();
      }
    }
    if (incoming_mkl_edge == false) {
      VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which "
                 "has no MKL parents.";
      return nullptr;
    } else {
      VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string()
              << " which has MKL parents";
    }
  }

  // We now check whether a rewrite rule applies to this op. If it does, we
  // rewrite the op to an Mkl op.
  // Find the matching RewriteInfo and then check that the rewrite rule
  // applies.
  for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
    if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
      return &*ri;
    }
  }

  // Else return not found.
  return nullptr;
}
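// For reference, a conceptual sketch (field names follow the uses of 'ri'
// above; the concrete values are illustrative) of what one rinfo_ entry looks
// like for an op such as Relu:
//
//   RewriteInfo{
//     name         = "Relu",      // op type CheckForNodeRewrite matches on
//     new_name     = "_MklRelu",  // op type RewriteNode builds instead
//     copy_attrs   = <function copying attrs such as T to the new node>,
//     rewrite_rule = <predicate that must return true for the rewrite>
//   };
//
// CheckForNodeRewrite therefore returns the first entry whose 'name' equals
// the node's type string and whose 'rewrite_rule' accepts the node.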
///////////////////////////////////////////////////////////////////////////////
//              Run function for the pass
///////////////////////////////////////////////////////////////////////////////

bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
  bool result = false;
  CHECK_NOTNULL(g);

  DumpGraph("Before running MklLayoutRewritePass", &**g);

  std::vector<Node*> order;
  GetReversePostOrder(**g, &order);  // This gives us a topological sort.
  for (Node* n : order) {
    // If the node is not an op or it cannot run on the CPU device, skip it.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    Node* m = nullptr;
    if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
      // Check if node 'n' can be merged with any other node. If it can be,
      // then 'm' contains the node with which it can be merged.
      string n1_name = n->name();
      string n2_name = m->name();

      VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
              << n2_name << " for merging";

      if (MergeNode(g, n, m) == Status::OK()) {
        VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
                << n2_name;
        result = true;
      }
    }
  }

  DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);

  order.clear();
  GetReversePostOrder(**g, &order);  // This gives us a topological sort.
  for (Node* n : order) {
    // If the node is not an op or it cannot run on the CPU device, skip it.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    const RewriteInfo* ri = nullptr;
    // We first check whether the node is to be rewritten.
    if ((ri = CheckForNodeRewrite(n)) != nullptr) {
      string node_name = n->name();
      string op_name = n->type_string();

      VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
              << " with op " << op_name << " for rewrite using"
              << " layout optimization.";

      if (RewriteNode(g, n, ri) == Status::OK()) {
        VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
                << " with op " << op_name << " for Mkl layout optimization.";
        result = true;
      }
    }
  }

  DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);

  return result;
}

bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
  return MklLayoutRewritePass().RunPass(g);
}

Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
  if (options.graph == nullptr && options.partition_graphs == nullptr) {
    return Status::OK();
  }

  auto process_graph = [&](std::unique_ptr<Graph>* g) {
    // Take ownership of the graph.
    std::unique_ptr<Graph>* ng = std::move(g);
    RunPass(ng);
    // Return ownership of the graph.
    g->reset(ng->release());
  };

  if (kMklLayoutRewritePassGroup !=
      OptimizationPassRegistry::POST_PARTITIONING) {
    // For any pre-partitioning phase, the graph is stored in options.graph.
    process_graph(options.graph);
  } else {
    // For the post-partitioning phase, graphs are stored in
    // options.partition_graphs.
    for (auto& pg : *options.partition_graphs) {
      process_graph(&pg.second);
    }
  }

  return Status::OK();
}
#endif  // INTEL_MKL_ML
}  // namespace tensorflow

#endif  // INTEL_MKL
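// Usage note (a sketch, not code from this file): besides being driven by the
// optimization-pass registry through MklLayoutRewritePass::Run(), which
// receives either options.graph or options.partition_graphs depending on the
// registered phase, the pass can be exercised directly on a graph via the
// RunMklLayoutRewritePass helper defined above, e.g. in a test:
//
//   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
//   // ... populate 'graph' ...
//   bool changed = RunMklLayoutRewritePass(&graph);
//
// where 'changed' reports whether any node was merged or rewritten.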