Home | History | Annotate | Download | only in graph
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifdef INTEL_MKL
     17 
     18 #include <algorithm>
     19 #include <functional>
     20 #include <memory>
     21 #include <queue>
     22 #include <set>
     23 #include <string>
     24 #include <unordered_set>
     25 #include <utility>
     26 #include <vector>
     27 #include "tensorflow/core/common_runtime/function.h"
     28 #include "tensorflow/core/common_runtime/optimization_registry.h"
     29 #include "tensorflow/core/framework/node_def_util.h"
     30 #include "tensorflow/core/graph/algorithm.h"
     31 #include "tensorflow/core/graph/graph.h"
     32 #include "tensorflow/core/graph/node_builder.h"
     33 #include "tensorflow/core/lib/core/status.h"
     34 #include "tensorflow/core/lib/gtl/array_slice.h"
     35 #include "tensorflow/core/lib/gtl/map_util.h"
     36 #include "tensorflow/core/lib/hash/hash.h"
     37 #include "tensorflow/core/platform/logging.h"
     38 #include "tensorflow/core/util/tensor_format.h"
     39 
     40 #include "tensorflow/core/graph/mkl_graph_util.h"
     41 #include "tensorflow/core/graph/mkl_layout_pass.h"
     42 
     43 namespace tensorflow {
     44 
     45 #ifdef INTEL_MKL_ML
     46 
     47 // This pass implements rewriting of graph to support following scenarios:
     48 // (A) Merging nodes in the graph
     49 // (B) Rewriting a node in the graph to a new node
     50 //     Rewrite happens under following 2 scenarios:
     51 //     1) Propagating Mkl layout as an additional output tensor
     52 //        (we will loosely call a tensor that carries Mkl layout as Mkl tensor
     53 //         henceforth.) from every Mkl supported NN layer.
     54 //     2) Context-based rewrite: This is needed in order to optimize
     55 //        gradient ops of Conv2D+AddBias. Gradient op of both the Conv2D and
     56 //        MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into
     57 //        Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad.
     58 //        This is context-specific optimization, where the context is the
     59 //        forward operator that the BiasAddGrad corresponds to.
     60 //
     61 // Example of A : Merging nodes in the graph
     62 // -----------------------------------------
     63 // Currently, we merge Conv2D+AddBias together. Consider Conv2D and BiasAdd as:
     64 //
     65 //           O = Conv2D(A, B)
     66 //           P = BiasAdd(O, C)
     67 //
     68 // We merge them into Conv2DWithBias as:
     69 //           P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
     70 //
     71 // The meaning of A_m, B_m and C_m is explained in B.1.
     72 //
     73 // Merge rules:
     74 //  - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
     75 //    goes to BiasAdd.
     76 //  - Also, the intersection of attributes of both the nodes must have same
     77 //    values.
     78 //  - Both the nodes must have been assigned to same device (if any).
     79 //
     80 // Example of B.1 : Rewriting nodes to Mkl nodes
     81 // ---------------------------------------------
     82 // Consider a Relu node. Current definition of Relu node looks like:
     83 //
     84 //           O = Relu(A)
     85 //
     86 // Relu has 1 input (A), and 1 output (O).
     87 //
     88 // This rewrite pass will generate a new graph node for Relu (new node is
     89 // called MklRelu) as:
     90 //
     91 //          O, O_m = MklRelu(A, A_m)
     92 //
     93 // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
     94 // same as input A of Relu; output O is same as output O of Relu. O_m is the
     95 // additional output tensor that will be set by MklRelu, and it represents
     96 // Mkl tensor corresponding to O -- in other words, O_m is some kind of
     97 // metadata for O. A_m is additional input of Relu, and it represents metadata
     98 // for A - as O_m is metadata for O, A_m is metadata for A. MklRelu receives
     99 // this metadata from previous node in the graph.
    100 //
    101 // When a previous node in the graph is an Mkl node, A_m will represent a valid
    102 // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
    103 // a dummy Mkl tensor.
    104 //
    105 // Rewriting rules:
    106 //  - Selection of a node for rewriting happens by registering the op type of
    107 //    the node with the rewriting pass. If the op type is not registered, then
    108 //    all nodes of this op type will not be rewritten.
    109 //  - Number of inputs after rewriting:
    110 //      Since for every input Tensorflow tensor, the rewritten node gets Mkl
    111 //      tensor(s), rewritten node gets 2*N inputs, where N is the number of
    112 //      inputs for the original node.
    113 //  - Number of outputs after rewriting:
    114 //      Since for every output Tensorflow tensor, the rewritten node generates
    115 //      Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
    116 //      number of outputs of the original node.
    117 //  - Ordering of Tensorflow tensors and Mkl tensors:
    118 //      Since every rewritten node generates twice the number of inputs and
    119 //      outputs, one could imagine various orderings among Tensorflow tensors
    120 //      and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
    121 //      inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
    122 //      in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
    123 //      order. Among N inputs one can get N! permutations.
    124 //
    125 //      So the question is: which order do we follow? We support 2 types of
    126 //      orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
    127 //      follows an intuitive order where an Mkl tensor follows the
    128 //      corresponding Tensorflow tensor immediately. In the context of the
    129 //      above example, it will be: A, A_m, B, B_m. Note that the ordering rule
    130 //      applies to both the inputs and outputs. Contiguous ordering means
    131 //      all the Tensorflow tensors are contiguous followed by all the Mkl
    132 //      tensors. We use contiguous ordering as default.
    133 //
    134 // Graph rewrite algorithm:
    135 //      Algorithm: Graph Rewrite
    136 //      Input: Graph G, Names of the nodes to rewrite and their new names
    137 //      Output: Modified Graph G' if the nodes are modified, G otherwise.
    138 //      Start:
    139 //        N = Topological_Sort(G) // N is a set of nodes in toposort order.
    140 //        foreach node n in N
    141 //        do
    142 //          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input.
    143 //          then
    144 //            E = set of <incoming edge and its src_output slot> of n
    145 //            E' = {}   // a new set of edges for rewritten node
    146 //            foreach <e,s> in E
    147 //            do
    148 //              E' U {<e,s>}  // First copy edge which generates Tensorflow
    149 //                            // tensor as it is
    150 //              m = Source node of edge e
    151 //              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
    152 //              then
    153 //                E' U {<m,s+1>}    // If yes, then m will generate an Mkl
    154 //                                  // tensor as an additional output.
    155 //              else
    156 //                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
    157 //                                                 // Mkl tensor.
    158 //                E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
    159 //              fi
    160 //            done
    161 //            n' = Build_New_Node(G,new_name,E')
    162 //            Mark_Rewritten(n')  // Mark the new node as being rewritten.
    163 //          fi
    164 //        done
    165 //
    166 //      Explanation:
    167 //        For graph rewrite, we visit nodes of the input graph in the
    168 //        topological sort order. With this ordering, we visit nodes in the
    169 //        top-to-bottom fashion. We need this order because while visiting a
    170 //        node we want that all of its input nodes are visited and rewritten if
    171 //        applicable. This is because if we need to rewrite a given node
    172 //        then all of its input nodes need to be fixed (in other words they
    173 //        cannot be deleted later.)
    174 //
    175 //        While visiting a node, we first check if the op type of the node is
    176 //        an Mkl op. If it is, then we rewrite that node after constructing
    177 //        new inputs to the node. If the op type of the node is not Mkl op,
    178 //        then we do not rewrite that node.
    179 //
    180 // Handling workspace propagation for certain ops:
    181 //
    182 //        Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
    183 //        passing of a workspace from their respective forward ops. Workspace
    184 //        tensors provide memory for storing results of intermediate operations
    185 //        which are helpful in backward propagation. TensorFlow does not have
    186 //        a notion of a workspace and as a result does not allow producing
    187 //        additional outputs from these forward ops. For these ops, we need
    188 //        to add 2 extra edges between forward ops and their corresponding
    189 //        backward ops - the first extra edge carries a workspace tensor and
    190 //        the second one carries an Mkl tensor for the workspace tensor.
    191 //
    192 //        Example:
    193 //
    194 //        Typical graph for MaxPool and its gradient looks like:
    195 //
    196 //        A = MaxPool(T)
    197 //        B = MaxPoolGrad(X, A, Y)
    198 //
    199 //        We will transform this graph to propagate the workspace as:
    200 //        (with the contiguous ordering)
    201 //
    202 //        A, W, A_m, W_m = MklMaxPool(T, T_m)
    203 //        B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
    204 //
    205 //        Here W is the workspace tensor. Transformed tensor names with the
    206 //        suffix _m are Mkl tensors, and this transformation has been done
    207 //        using the algorithm discussed earlier. The transformation for
    208 //        workspace propagation only adds extra outputs (W, W_m) for a forward
    209 //        op and connects them to the corresponding backward ops.
    210 //
    211 //        Terms:
    212 //
    213 //        Forward op name = name of the op in the forward pass
    214 //          where a workspace tensor originates (MaxPool in this example)
    215 //        Backward op name = name of the op in the backward pass that receives
    216 //          a workspace tensor from the forward op (MaxPoolGrad in the example)
    217 //        Slot = Position of the output or input slot that will be
    218 //               used by the workspace tensor (1 for MklMaxPool as W is the 2nd
    219 //               output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
    220 //
    221 //        Question:
    222 //
    223 //        How do we associate a backward op to a forward op? There can be more
    224 //        than one op with the exact same name.
    225 //
    226 //        In this example, we associate MaxPoolGrad with MaxPool. But there
    227 //        could be more than one MaxPool ops. To solve this problem, we look
    228 //        for _direct_ edge between a forward op and a backward op (tensor A is
    229 //        flowing along this edge in the example).
    230 //
    231 //        How do we transform forward and backward ops when there is no direct
    232 //        edge between them? In such a case, we generate dummy tensors for
    233 //        workspace tensors. For the example, transformation of MaxPool will
    234 //        be exactly same as it would be when there is a direct edge between
    235 //        the forward and the backward op --- it is just that MaxPool won't
    236 //        generate any workspace tensor. For MaxPoolGrad, the transformation
    237 //        will also be same, but instead of connecting W and W_m with the
    238 //        outputs of MaxPool, we will produce dummy tensors for them, and we
    239 //        will set workspace_enabled attribute to false.
    240 //
    241 // Example of B.2 : Context-based node rewrite
    242 // -------------------------------------------
    243 // Consider BiasAddGrad op as:
    244 //
    245 //           O = _MklConv2D(A, B, C, A_m, B_m, C_m)
    246 //           P = BiasAddGrad(O)
    247 //
    248 // Then we rewrite it as:
    249 //
    250 //           P = Conv2DWithBiasBackpropBias(O, O_m)
    251 //
    252 // Rewrite of BiasAddGrad into Conv2DWithBiasBackpropBias takes place depending
    253 // on the matching 'context'. The term context is loosely related to which
    254 // forward op is _associated_ to BiasAddGrad. If it is _MklConv2DWithBias then
    255 // we consider it Conv2D context; if it is MatMul, then it is MatMul context.
    256 
    257 class MklLayoutRewritePass : public GraphOptimizationPass {
    258  public:
  // Populates the rewrite/merge/workspace/context registries that drive the
  // pass. All tables are built once at construction time; RunPass consults
  // them for every node it visits.
  MklLayoutRewritePass() {
    // Constant op-name strings, so the rest of the pass never repeats
    // string literals.
    // NOTE: names are alphabetically sorted.
    csinfo_.addn = "AddN";
    csinfo_.avg_pool = "AvgPool";
    csinfo_.avg_pool_grad = "AvgPoolGrad";
    csinfo_.bias_add = "BiasAdd";
    csinfo_.bias_add_grad = "BiasAddGrad";
    csinfo_.concat = "Concat";
    csinfo_.concatv2 = "ConcatV2";
    csinfo_.conv2d = "Conv2D";
    csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
    csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
    csinfo_.fused_batch_norm = "FusedBatchNorm";
    csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
    csinfo_.identity = "Identity";
    csinfo_.lrn = "LRN";
    csinfo_.lrn_grad = "LRNGrad";
    csinfo_.matmul = "MatMul";
    csinfo_.max_pool = "MaxPool";
    csinfo_.max_pool_grad = "MaxPoolGrad";
    csinfo_.mkl_conv2d = "_MklConv2D";
    csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
    csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
    csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
    csinfo_.mkl_conv2d_with_bias_backprop_bias =
        "_MklConv2DWithBiasBackpropBias";
    csinfo_.relu = "Relu";
    csinfo_.relu_grad = "ReluGrad";
    csinfo_.reshape = "Reshape";
    csinfo_.split = "Split";
    // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
    // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
    // MklInputConversion op is added before it.
    csinfo_.add = "Add";
    csinfo_.maximum = "Maximum";
    csinfo_.mul = "Mul";
    csinfo_.squared_difference = "SquaredDifference";
    csinfo_.sub = "Sub";
    // End - element-wise ops. See note above.

    // Rewrite rules: {original op name, rewritten op name, attr-copy
    // function, rewrite predicate, optional ContextInfo}.
    // NOTE: names are alphabetically sorted.
    // NOTE(review): BiasAddGrad is registered twice (Conv2D context and
    // MatMul context) — presumably the rules are tried in registration
    // order and the first matching context wins; confirm against
    // CheckForNodeRewrite.
    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
                      CopyAttrsAddN, AddNRewrite, nullptr});
    rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                      CopyAttrsDataType, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
                      CopyAttrsPooling, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.avg_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
                      CopyAttrsPooling, AlwaysRewrite, nullptr});
    // BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending
    // on if context contains Conv2D.
    rinfo_.push_back({csinfo_.bias_add_grad,
                      csinfo_.mkl_conv2d_with_bias_backprop_bias,
                      CopyAttrsBiasAddGrad, ContextMatchRewrite,
                      &biasaddgrad_conv2dwithbias_context_});
    // BiasAddGrad gets written into BiasAddGrad depending on if context
    // contains MatMul.
    rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul,
                      CopyAttrsBiasAddGrad, ContextMatchRewrite,
                      &biasaddgrad_matmul_context_});
    rinfo_.push_back({csinfo_.concat,
                      mkl_op_registry::GetMklOpName(csinfo_.concat),
                      CopyAttrsConcat, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.concatv2),
                      CopyAttrsConcatV2, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d),
                      CopyAttrsConv2D, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.conv2d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
                      CopyAttrsConv2D, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.conv2d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
                      CopyAttrsConv2D, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.fused_batch_norm,
                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                      CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
         CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.identity,
                      mkl_op_registry::GetMklOpName(csinfo_.identity),
                      CopyAttrsIdentity, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                      CopyAttrsLRN, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.lrn_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
                      CopyAttrsLRN, AlwaysRewrite, nullptr});
    // MaxPool is rewritten only for purely spatial pooling; see
    // NonDepthBatchWisePoolRewrite.
    rinfo_.push_back({csinfo_.max_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool),
                      CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr});
    rinfo_.push_back({csinfo_.max_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
                      CopyAttrsPooling, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.maximum,
                      mkl_op_registry::GetMklOpName(csinfo_.maximum),
                      CopyAttrsDataType, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
                      CopyAttrsDataType, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                      CopyAttrsDataType, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.relu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
                      CopyAttrsDataType, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.reshape,
                      mkl_op_registry::GetMklOpName(csinfo_.reshape),
                      CopyAttrsReshape, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.squared_difference,
                      mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
                      CopyAttrsDataType, AlwaysRewrite, nullptr});
    rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
                      CopyAttrsDataType, AlwaysRewrite, nullptr});

    // Add info about which ops to add workspace edge to and the slots.
    wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
    wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});

    // Add a rule for merging nodes
    minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0,
                      csinfo_.mkl_conv2d_with_bias});

    // Initialize the static ContextInfos AFTER the rinfo_ entries above
    // captured their addresses; the addresses are stable, only the
    // contents are assigned here.
    biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
                                   IsBiasAddGradInMatMulContext};

    biasaddgrad_conv2dwithbias_context_ = {
        csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
        IsBiasAddGradInConv2DWithBiasContext};

    cinfo_.push_back(&biasaddgrad_matmul_context_);
    cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
  }
    394 
    395   // Standard interface to run pass
    396   Status Run(const GraphOptimizationPassOptions& options);
    397 
    398   // Helper function which does most of heavy lifting for rewriting
    399   // Mkl nodes to propagate Mkl tensor as additional output
    400   //
    401   // Extracts common functionality between Run public interface and
    402   // test interface.
    403   //
    404   // @return true, if and only if graph is mutated; false otherwise.
    405   bool RunPass(std::unique_ptr<Graph>* g);
    406 
    407   /// Structure to specify the context information used in a node rewrite rule
    408   typedef struct {
    409     string node;  // Name of the node to be rewritten
    410     string fwd;   // Name of the node in the forward pass that this node
    411                   // corresponds to
    412     std::function<bool(const Node*, const Node**, void* c)> context_match_fn;
    413   } ContextInfo;
    414 
    415   /// Structure to specify the name of an original node, its new name after
    416   /// rewrite, the number of inputs to the original node, the function to
    417   /// be used to copy attributes for the op, and the rule (if any) which
    418   /// must hold for rewriting the node
    419   typedef struct {
    420     string name;      // Original name of op of the node in the graph
    421     string new_name;  // New name of the op of the node in the graph
    422     // A function handler to copy attributes from an old node to a new node.
    423     std::function<void(const Node*, NodeBuilder*)> copy_attrs;
    424     // A rule under which to rewrite this node
    425     std::function<bool(const Node*, const ContextInfo* c)> rewrite_rule;
    426     // ContextInfo, if any, to be used for rewrite
    427     ContextInfo* context;
    428   } RewriteInfo;
    429 
    430   /// Structure to specify a forward op, a backward op, and the slot numbers
    431   /// in the forward and backward ops where we will add a workspace edge.
    432   typedef struct {
    433     string fwd_op;    // Name of a forward op in the graph
    434     string bwd_op;    // Name of a backward op in the graph
    435     int fwd_slot;     // Output slot in the forward op node where actual
    436                       // output tensor resides
    437     int bwd_slot;     // Input slot in the backward op node where actual
    438                       // input tensor resides
    439     int ws_fwd_slot;  // Output slot in the forward op node where workspace
    440                       // edge is added
    441     int ws_bwd_slot;  // Input slot in the backward op node where workspace
    442                       // edge is added
    443   } WorkSpaceInfo;
    444 
    445   /// Structure to specify information used in node merge
    446   typedef struct {
    447     string pred;      // Predecessor node string
    448     string succ;      // Successor node string
    449     int op;           // The operand no the predecessor node corresponds
    450                       // to the successor node
    451     string new_node;  // Name of the node after merge
    452   } MergeInfo;
    453 
    454   /// Structure to store all constant strings
    455   /// NOTE: names are alphabetically sorted.
    456   typedef struct {
    457     string addn;
    458     string add;
    459     string avg_pool;
    460     string avg_pool_grad;
    461     string bias_add;
    462     string bias_add_grad;
    463     string concat;
    464     string concatv2;
    465     string conv2d;
    466     string conv2d_grad_input;
    467     string conv2d_grad_filter;
    468     string fused_batch_norm;
    469     string fused_batch_norm_grad;
    470     string identity;
    471     string lrn;
    472     string lrn_grad;
    473     string matmul;
    474     string max_pool;
    475     string max_pool_grad;
    476     string maximum;
    477     string mkl_conv2d;
    478     string mkl_conv2d_grad_input;
    479     string mkl_conv2d_grad_filter;
    480     string mkl_conv2d_with_bias;
    481     string mkl_conv2d_with_bias_backprop_bias;
    482     string mul;
    483     string relu;
    484     string relu_grad;
    485     string reshape;
    486     string split;
    487     string squared_difference;
    488     string sub;
    489   } ConstStringsInfo;
    490 
 private:
  /// Rewrite rules: which ops to rewrite into Mkl ops, and how.
  std::vector<RewriteInfo> rinfo_;

  /// Forward/backward op pairs that need a workspace edge added.
  std::vector<WorkSpaceInfo> wsinfo_;

  /// Node pairs (e.g. Conv2D + BiasAdd) to be merged into a single op.
  std::vector<MergeInfo> minfo_;

  /// Contexts used by context-based (scenario 2) rewrites; points at the
  /// static ContextInfo objects below.
  static std::vector<ContextInfo*> cinfo_;

  /// Maintain structure of constant strings
  static ConstStringsInfo csinfo_;

  /// Context variables used in referencing rules
  static ContextInfo biasaddgrad_matmul_context_;
  static ContextInfo biasaddgrad_conv2dwithbias_context_;
    510 
    511  private:
    512   // Is OpDef::ArgDef a list type? It could be N * T or list(type).
    513   // Refer to opdef.proto for details of list type.
    514   inline bool ArgIsList(const OpDef::ArgDef& arg) const {
    515     return !arg.type_list_attr().empty() || !arg.number_attr().empty();
    516   }
    517 
    518   // Get length of a list in 'n' if 'arg' is of list type. Refer to
    519   // description of ArgIsList for definition of list type.
    520   inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) {
    521     CHECK_EQ(ArgIsList(arg), true);
    522     int N = 0;
    523     const string attr_name = !arg.type_list_attr().empty()
    524                                  ? arg.type_list_attr()
    525                                  : arg.number_attr();
    526     if (!arg.type_list_attr().empty()) {
    527       std::vector<DataType> value;
    528       TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
    529       N = value.size();
    530     } else {
    531       TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
    532     }
    533     return N;
    534   }
    535 
    536   // Can op represented by node 'n' run on DEVICE_CPU?
    537   // Op can run on CPU with MKL if the runtime assigned device or the
    538   // user requested device contains device CPU, or both are empty.
    539   bool CanOpRunOnCPUDevice(const Node* n) {
    540     bool result = true;
    541     string reason;
    542 
    543     // Substring that should be checked for in device name for CPU device.
    544     const char* const kCPUDeviceSubStr = "CPU";
    545 
    546     // If Op has been specifically assigned to a non-CPU device, then No.
    547     if (!n->assigned_device_name().empty() &&
    548         !StringPiece(n->assigned_device_name()).contains(kCPUDeviceSubStr)) {
    549       result = false;
    550       reason = "Op has been assigned a runtime device that is not CPU.";
    551     }
    552 
    553     // If user has specifically assigned this op to a non-CPU device, then No.
    554     if (!n->def().device().empty() &&
    555         !StringPiece(n->def().device()).contains(kCPUDeviceSubStr)) {
    556       result = false;
    557       reason = "User has assigned a device that is not CPU.";
    558     }
    559 
    560     if (result == false) {
    561       VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
    562               << n->type_string() << ", reason: " << reason;
    563     }
    564 
    565     // Otherwise Yes.
    566     return result;
    567   }
    568 
    569   // Return a node that can be merged with input node 'n'
    570   //
    571   // @return pointer to the node if we can find such a
    572   // node. Otherwise, it returns nullptr.
    573   Node* CheckForNodeMerge(const Node* n) const;
    574 
    575   // Merge predecessor node with its successor.
    576   // Currently, we merge Conv2D with BiasAdd only.
    577   //
    578   // Input nodes succ and pred may be deleted if the call to
    579   // this function is successful. Attempt to use the pointers
    580   // after the call to function may result in undefined behaviors.
    581   //
    582   // @input g - input graph, succ - successor node, pred - predecessor node
    583   // @return Status::OK(), if merging is successful and supported.
    584   //         Returns appropriate Status error code otherwise.
    585   //         Graph is updated in case nodes are merged. Otherwise, it is
    586   //         not updated.
    587   Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred);
    588 
    589   // Check if the node 'n' has any applicable rewrite rule
    590   // We check for 2 scenarios for rewrite.
    591   //
    592   // @return RewriteInfo* for the applicable rewrite rule
    593   const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
    594 
    595   // Default rewrite rule to be used in scenario 1 for rewrite.
    596   // @return - true (since we want to always rewrite)
    597   static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) {
    598     return true;
    599   }
    600 
    601   // Check if we are performing pooling on depth or batch. If it is, then we
    602   // do not rewrite MaxPool node to Mkl version.
    603   // @return - true (if it is not a depth/batch wise pooling case);
    604   //           false otherwise.
    605   static bool NonDepthBatchWisePoolRewrite(const Node* n,
    606                                            const ContextInfo* c) {
    607     CHECK_NOTNULL(n);
    608 
    609     string data_format_str;
    610     TensorFormat data_format;
    611     std::vector<int32> ksize, strides;
    612     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
    613     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
    614     CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
    615     CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
    616 
    617     // Condition that specifies non-batch-wise and non-depth-wise pooling.
    618     if (GetTensorDim(ksize, data_format, 'N') == 1 &&
    619         GetTensorDim(strides, data_format, 'N') == 1 &&
    620         GetTensorDim(ksize, data_format, 'C') == 1 &&
    621         GetTensorDim(strides, data_format, 'C') == 1) {
    622       return true;
    623     }
    624 
    625     return false;
    626   }
    627 
    628   static bool AddNRewrite(const Node* n, const ContextInfo* c) {
    629     CHECK_NOTNULL(n);
    630 
    631     int num;
    632     CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true);
    633 
    634     // Condition that specifies non-batch-wise and non-depth-wise pooling.
    635     if (num == 2) {
    636       return true;
    637     }
    638 
    639     return false;
    640   }
    641   // Is BiasAddGrad node in 'n' is associated with Conv2DWithBias node
    642   // specified in contextinfo 'ci'. Function updates fwd_node to point
    643   // to Conv2DWithBias node if 'n' is associated with Conv2DWithBias.
    644   //
    645   // Association checks for one of the following graphs:
    646   //
    647   // Graph A:
    648   //
    649   // _ = Conv2DWithBias(F, I, _)
    650   // ..
    651   // _ = Conv2DBackpropFilter(F, _, G)
    652   // _ = Conv2DBackpropInput(_, I, G)
    653   // _ = BiasAddGrad(G)
    654   //
    655   // OR
    656   //
    657   // Graph B:
    658   //
    659   // _ = Conv2DWithBias(F, _, _)
    660   // ..
    661   // _ = Conv2DBackpropFilter(F, _, G)
    662   // _ = BiasAddGrad(G)
    663   //
    664   // Here F, G, and I are graph nodes; _ represents graph nodes that we
    665   // don't care here.
    666   //
    667   // @return - true (if BiasAddGrad is associated with Conv2DWithBias);
    668   //           false otherwise.
  static bool IsBiasAddGradInConv2DWithBiasContext(const Node* n,
                                                   const Node** fwd_node,
                                                   void* ci) {
    CHECK_NOTNULL(n);
    CHECK_NOTNULL(fwd_node);
    CHECK_NOTNULL(ci);
    // Reset the output; it is set only on a successful match below.
    *fwd_node = nullptr;

    // This predicate is only meaningful for BiasAddGrad nodes.
    CHECK_EQ(n->type_string(), csinfo_.bias_add_grad);

    // Get the only 1 input of BiasAddGrad.
    CHECK_EQ(n->num_inputs(), 1);
    const Node* bias_add_grad_inp = nullptr;
    TF_CHECK_OK(n->input_node(0, &bias_add_grad_inp));
    CHECK_NOTNULL(bias_add_grad_inp);

    // Check if this input also goes to BackpropFilter and BackpropInput
    // as 3rd input. (This input is 'G' in graphs A and B of the comment
    // above.)
    bool found_backprop_input = false;
    bool found_backprop_filter = false;
    Node* backprop_filter_node = nullptr;
    Node* backprop_input_node = nullptr;

    for (const Edge* e : bias_add_grad_inp->out_edges()) {
      Node* third_input = nullptr;
      if (e->dst()->type_string() == csinfo_.conv2d_grad_input ||
          e->dst()->type_string() == csinfo_.mkl_conv2d_grad_input) {
        // Third input (index 2) of BackpropInput
        TF_CHECK_OK(e->dst()->input_node(2, &third_input));
        // Third input (index 2) of BackpropInput must be same as the input
        // of BiasAddGrad.
        if (third_input == bias_add_grad_inp) {
          found_backprop_input = true;
          backprop_input_node = e->dst();
        }
      }

      if (e->dst()->type_string() == csinfo_.conv2d_grad_filter ||
          e->dst()->type_string() == csinfo_.mkl_conv2d_grad_filter) {
        // Third input (index 2) of BackpropFilter
        TF_CHECK_OK(e->dst()->input_node(2, &third_input));
        // Third input (index 2) of BackpropFilter must be same as the input
        // of BiasAddGrad.
        if (third_input == bias_add_grad_inp) {
          found_backprop_filter = true;
          backprop_filter_node = e->dst();
        }
      }

      // If we found both the nodes, then we can stop the search.
      if (found_backprop_input && found_backprop_filter) {
        break;
      }
    }

    // If BackpropFilter node is not found, then this is not
    // Conv2DWithBias context. For 2nd graph in the example above, only
    // BackpropFilter would be present.
    if (!found_backprop_filter) {
      return false;
    }

    // Otherwise, we found the nodes.
    CHECK_NOTNULL(backprop_filter_node);
    if (found_backprop_input) {
      CHECK_NOTNULL(backprop_input_node);
    }

    // Now that we confirmed that this is Conv2DWithBias context, we need to
    // get access to the forward node (Conv2DWithBias). 2nd input of
    // Conv2DWithBias is same as the 2nd input of Conv2DBackpropInput; 1st
    // input of Conv2DWithBias is same as the 1st input of Conv2DBackpropFilter
    // (This comes from definition of gradient computation for Conv2D).
    if (found_backprop_input) {
      // Graph A in the example.
      Node* second_inp_of_input = nullptr;
      Node* first_inp_of_filter = nullptr;
      TF_CHECK_OK(backprop_input_node->input_node(1, &second_inp_of_input));
      TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
      CHECK_NOTNULL(second_inp_of_input);
      CHECK_NOTNULL(first_inp_of_filter);

      // Now we need to find out Conv2DWithBias node from these input nodes.
      // Conv2DWithBias node is the node that accepts both the nodes
      // second_inp_of_input and first_inp_of_filter in 2nd and 1st input slots.
      for (const Edge* fe : first_inp_of_filter->out_edges()) {
        if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
            fe->dst_input() == 0) {
          for (const Edge* ie : second_inp_of_input->out_edges()) {
            if (ie->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
                ie->dst_input() == 1 && fe->dst() == ie->dst()) {
              VLOG(1) << "MklLayoutRewritePass: found "
                      << fe->dst()->DebugString()
                      << " as the forward node for matching context, backward"
                      << " node is: " << n->DebugString();
              *fwd_node = fe->dst();
              return true;
            }
          }
        }
      }
    } else {
      // We did not find BackpropInput, so we work with BackpropFilter only.
      // Graph B in the example.
      Node* first_inp_of_filter = nullptr;
      TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
      CHECK_NOTNULL(first_inp_of_filter);

      // Now we need to find out Conv2DWithBias node from first input of
      // BackpropFilter. Conv2DWithBias node is the node that accepts
      // first_inp_of_filter in 1st input slot.
      for (const Edge* fe : first_inp_of_filter->out_edges()) {
        if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
            fe->dst_input() == 0) {
          VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString()
                  << " as the forward node for matching context, backward"
                  << " node is: " << n->DebugString();
          *fwd_node = fe->dst();
          return true;
        }
      }
    }

    // Forward Conv2DWithBias node not found: not a Conv2DWithBias context.
    return false;
  }
    794 
  // Checks if BiasAddGrad node 'n' is associated with a MatMul node
  // per contextinfo 'ci'.
  //
  // Implemented as the complement of the Conv2DWithBias context check: a
  // BiasAddGrad that is NOT in a Conv2DWithBias context is treated as being
  // in a MatMul context. Note the delegated call still writes to *fwd_node
  // (it resets it to nullptr, and may set it if the Conv2D context matches).
  //
  // @return - true (if BiasAddGrad is associated with MatMul);
  //           false otherwise.
  static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node,
                                           void* ci) {
    return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci));
  }
    804 
    805   // Rewrite rule that uses context-information for matching,
    806   // used in scenario 2.
    807   //
    808   // @input - Node 'n' for which to search for matching context
    809   // @input - The context 'c' under which to rewrite
    810   // @return - true if we can rewrite node under context 'c';
    811   //           false otherwise.
    812   static bool ContextMatchRewrite(const Node* n, const ContextInfo* c);
    813 
    814   // Helper function that searches the matching contextinfo for the node.
    815   //
    816   // @input n - Node (gradient op) whose contextinfo is to be searched,
    817   //        fwd_node - pointer to node from the forward pass that this node
    818   //        belongs to. fwd_node cannot be NULL.
    819   // @return Matching contextinfo in case a match is found; null otherwise.
    820   //         Also updates *fwd_node with pointer to forward node that this
    821   //         context matches.
    822   static const ContextInfo* SearchMatchingContext(const Node* n,
    823                                                   const Node** fwd_node);
    824 
    825   // Rewrites input node to a new node specified by its matching rewrite info.
    826   //
    827   // Method first searches matching rewrite info for input node and then
    828   // uses that info to rewrite.
    829   //
    830   // Input node may be deleted in case of rewrite. Attempt to use the node
    831   // after the call can result in undefined behaviors.
    832   //
    833   // @input  g - input graph, n - Node to be rewritten,
    834   //         ri - matching rewriteinfo
    835   // @return Status::OK(), if the input node is rewritten;
    836   //         Returns appropriate Status error code otherwise.
    837   //         Graph is updated in case the input node is rewritten.
    838   //         Otherwise, it is not updated.
    839   Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
    840 
    841   // Get nodes that will feed a list of TF tensors to the new
    842   // node that we are constructing.
    843   //
    844   // @input g - input graph,
    845   // @input inputs - inputs to old node that we are using for constructing
    846   //                 new inputs,
    847   // @input input_idx - the index in the 'inputs' vector pointing to the
    848   //                    current input that we have processed so far
    849   // @output input_idx - index will be incremented by the number of nodes
    850   //                     from 'inputs' that are processed
    851   // @input list_length - The expected length of list of TF tensors
    852   // @output output_nodes - the list of new nodes creating TF tensors
    853   //
    854   // @return None
    855   void GetNodesProducingTFTensorList(
    856       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
    857       int* input_idx, int list_length,
    858       std::vector<NodeBuilder::NodeOut>* output_nodes);
    859 
    860   // Get nodes that will feed a list of Mkl tensors to the new
    861   // node that we are constructing.
    862   //
    863   // @input g - input graph,
    864   // @input orig_node - Original node that we are rewriting
    865   // @input inputs - inputs to old node that we are using for constructing
    866   //                 new inputs,
    867   // @input input_idx - the index in the 'inputs' vector pointing to the
    868   //                    current input that we have processed so far
    869   // @output input_idx - index will be incremented by the number of nodes
    870   //                     from 'inputs' that are processed
    871   // @input list_length - The expected length of list of Mkl tensors
    872   // @output output_nodes - the list of new nodes creating Mkl tensors
    873   //
    874   // @return None
    875   void GetNodesProducingMklTensorList(
    876       std::unique_ptr<Graph>* g, Node* orig_node,
    877       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
    878       int* input_idx, int list_length,
    879       std::vector<NodeBuilder::NodeOut>* output_nodes);
    880 
    881   // Get a node that will feed an Mkl tensor to the new
    882   // node that we are constructing. The output node could be (1) 'n'
    883   // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
    884   // if 'n' is not an Mkl layer.
    885   //
    886   // @input g - input graph,
    887   // @input orig_node - Original node that we are rewriting,
    888   // @input n - Node based on which we are creating Mkl node,
    889   // @input n_output_slot - the output slot of node 'n'
    890   //            which is feeding to the node that we are constructing
    891   // @output mkl_node - the new node that will feed Mkl tensor
    892   // @output mkl_node_output_slot - the slot number of mkl_node that
    893   //                                will feed the tensor
    894   // @return None
    895   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
    896                                  Node* n, int n_output_slot, Node** mkl_node,
    897                                  int* mkl_node_output_slot);
    898 
  // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
  // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
  // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
  // producing workspace edges if 'are_workspace_tensors_available' is true.
  // Otherwise, 'workspace_tensors' is empty vector.
  //
  // For details, refer to 'Ordering of inputs after rewriting' section in the
  // documentation above.
  //
  // Returns the number of input slots set up on 'nb'. (Note: unlike most
  // helpers in this class, this returns an int slot count, not a Status.)
  int SetUpContiguousInputs(
      std::unique_ptr<Graph>* g,
      const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
      NodeBuilder* nb, Node* old_node,
      std::vector<NodeBuilder::NodeOut>* workspace_tensors,
      bool are_workspace_tensors_available);
    916 
    917   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
    918   // in graph 'g'. Original node is input in 'orig_node'.
    919   //
    920   // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
    921   // section in the documentation above.
    922   //
    923   // Returns Status::OK() if setting up inputs is successful, otherwise
    924   // returns appropriate status code.
    925   Status SetUpInputs(std::unique_ptr<Graph>* g,
    926                      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
    927                      NodeBuilder* nb, Node* orig_node);
    928 
    929   // Add workspace edge on the input or output side of Node 'orig_node' by using
    930   // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
    931   // adding workspace edge then do not add it. Workspace Tensorflow and Mkl
    932   // tensors, if they need to be added, will be set into these tensors.
    933   // If we set workspace tensors, then are_ws_tensors_added should be true.
    934   void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
    935                                 NodeBuilder* nb,
    936                                 std::vector<NodeBuilder::NodeOut>* ws_tensors,
    937                                 bool* are_ws_tensors_added);
    938 
    939   // Functions specific to operators to copy attributes
    940   // We need operator-specific function to copy attributes because the framework
    941   // does not provide any generic function for it.
    942   // NOTE: names are alphabetically sorted.
    943   static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb);
    944   static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
    945   static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
    946   static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
    947   static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
    948   static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb);
    949   static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
    950   static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb);
    951   static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
    952   static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
    953   static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
    954   static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
    955 
    956   // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
    957   // using node for original node 'orig_node' and return it in '*out'.
    958   // TODO(nhasabni) We should move this to mkl_util.h
    959   void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
    960                              Node* orig_node);
    961   void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
    962                                    Node* orig_node);
    963 };
    964 
// Definitions of the static data members declared in MklLayoutRewritePass.
MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
MklLayoutRewritePass::ContextInfo
    MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
MklLayoutRewritePass::ContextInfo
    MklLayoutRewritePass::biasaddgrad_matmul_context_;
std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;

// We register Mkl rewrite pass for phase 1 in post partitioning group.
// We register it here so that we get a complete picture of all users of Mkl
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
    978 
    979 //////////////////////////////////////////////////////////////////////////
    980 //           Helper functions for creating new node
    981 //////////////////////////////////////////////////////////////////////////
    982 
    983 static void FillInputs(const Node* n,
    984                        gtl::InlinedVector<Node*, 4>* control_edges,
    985                        gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
    986   control_edges->clear();
    987   for (const Edge* e : n->in_edges()) {
    988     if (e->IsControlEdge()) {
    989       control_edges->push_back(e->src());
    990     } else {
    991       (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
    992     }
    993   }
    994   std::sort(control_edges->begin(), control_edges->end());
    995   if (n->op_def().is_commutative()) {
    996     // For commutative inputs, we sort the input by the input Node*
    997     // to get a canonical ordering (so that add(a,b) and add(b, a) will
    998     // hash to the same value if is_commutative is true for 'add').
    999     std::sort(in->begin(), in->end());
   1000   }
   1001 }
   1002 
   1003 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
   1004     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
   1005     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
   1006   CHECK_LT(*input_idx, inputs.size());
   1007   CHECK_GT(list_length, 0);
   1008   CHECK_NOTNULL(output_nodes);
   1009   output_nodes->reserve(list_length);
   1010 
   1011   while (list_length != 0) {
   1012     CHECK_GT(list_length, 0);
   1013     CHECK_LT(*input_idx, inputs.size());
   1014     Node* n = inputs[*input_idx].first;
   1015     int slot = inputs[*input_idx].second;
   1016     // If input node 'n' is just producing a single tensor at
   1017     // output slot 'slot' then we just add that single node.
   1018     output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
   1019     (*input_idx)++;
   1020     list_length--;
   1021   }
   1022 }
   1023 
// Creates a Const node in '*g' producing a "dummy" Mkl metadata tensor
// (shape {8}, eight zero bytes) and returns it in '*out'. It is used to
// feed the Mkl-tensor input of a rewritten node when the corresponding
// data input does not come from an Mkl op.
// TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
                                                 Node** out, Node* orig_node) {
  // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
  // dummy Mkl tensor. 8 = 2*size_t.
  const DataType dt = DataTypeToEnum<uint8>::v();
  TensorProto proto;
  proto.set_dtype(dt);
  uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
                           8);
  TensorShape dummy_shape({8});
  dummy_shape.AsProto(proto.mutable_tensor_shape());
  TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
                  .Attr("value", proto)
                  .Attr("dtype", dt)
                  .Device(orig_node->def().device())  // We place this node on
                                                      // the same device as the
                                                      // device of the original
                                                      // node.
                  .Finalize(&**g, out));

  // If number of inputs to the original node is > 0, then we add
  // control dependency between 1st input (index 0) of the original node and
  // the dummy Mkl node. This is needed because control-flow ops such as Enter,
  // Merge, etc, require frame_name of the dummy Mkl node to be same as the
  // rewritten node. Adding control edge between 1st input of the original node
  // and the dummy Mkl node ensures that the dummy node is in the same frame
  // as the original node. Choosing 1st input is not necessary - any input of
  // the original node is fine because all the inputs of a node are always in
  // the same frame.
  if (orig_node->num_inputs() > 0) {
    Node* orig_input0 = nullptr;
    TF_CHECK_OK(
        orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
    CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
  }

  // Run the dummy node on the same assigned device as the original node.
  (*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
   1064 
   1065 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
   1066     std::unique_ptr<Graph>* g, Node* orig_node,
   1067     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
   1068     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
   1069   CHECK_LT(*input_idx, inputs.size());
   1070   CHECK_GT(list_length, 0);
   1071   CHECK_NOTNULL(output_nodes);
   1072   output_nodes->reserve(list_length);
   1073 
   1074   while (list_length != 0) {
   1075     CHECK_GT(list_length, 0);
   1076     CHECK_LT(*input_idx, inputs.size());
   1077     Node* n = inputs[*input_idx].first;
   1078     int slot = inputs[*input_idx].second;
   1079     // If 'n' is producing a single tensor, then create a single Mkl tensor
   1080     // node.
   1081     Node* mkl_node = nullptr;
   1082     int mkl_node_output_slot = 0;
   1083     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
   1084                               &mkl_node_output_slot);
   1085     output_nodes->push_back(
   1086         NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
   1087     (*input_idx)++;
   1088     list_length--;
   1089   }
   1090 }
   1091 
   1092 // Get an input node that will feed Mkl tensor to the new
   1093 // node that we are constructing. An input node could be (1) 'n'
   1094 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
   1095 // if 'n' is not an Mkl layer.
   1096 void MklLayoutRewritePass::GetNodeProducingMklTensor(
   1097     std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
   1098     Node** mkl_node, int* mkl_node_output_slot) {
   1099   CHECK_NOTNULL(n);
   1100   CHECK_NOTNULL(mkl_node);
   1101   CHECK_NOTNULL(mkl_node_output_slot);
   1102 
   1103   // If this is an MKL op, then it will create extra output for MKL layout.
   1104   DataType T;
   1105   if (GetNodeAttr(n->def(), "T", &T).ok() &&
   1106       mkl_op_registry::IsMklOp(n->type_string(), T)) {
   1107     // If this is an MKL op, then it will generate an edge that will receive
   1108     // Mkl tensor from a node.
   1109     // output slot number for Mkl tensor would be N+slot number of TensorFlow
   1110     // tensor, where N is total number of TensorFlow tensors.
   1111     *mkl_node = n;
   1112     *mkl_node_output_slot =
   1113         GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
   1114   } else {
   1115     // If we have not visited the node and rewritten it, then we need
   1116     // to create a dummy node that will feed a dummy Mkl tensor to this node.
   1117     // DummyMklTensor node has no input and generates only 1 output
   1118     // (dummy Mkl tensor) as output slot number 0.
   1119     GetDummyMklTensorNode(g, mkl_node, orig_node);
   1120     CHECK_NOTNULL(*mkl_node);
   1121     *mkl_node_output_slot = 0;
   1122   }
   1123 }
   1124 
// Sets up, on NodeBuilder 'nb', all inputs of the rewritten node using the
// contiguous ordering: first every TF tensor input (plus the TF workspace
// tensor, if any), then every Mkl metadata tensor (plus the Mkl workspace
// tensor, if any).
//
// Returns the number of input slots set up on 'nb' (an int slot count, not
// a Status).
int MklLayoutRewritePass::SetUpContiguousInputs(
    std::unique_ptr<Graph>* g,
    const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
    NodeBuilder* nb, Node* old_node,
    std::vector<NodeBuilder::NodeOut>* workspace_tensors,
    bool are_workspace_tensors_available) {
  CHECK_NOTNULL(workspace_tensors);
  // This helper only implements the contiguous ordering scheme.
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);

  // TODO(nhasabni): Temporary solution to connect filter input of
  // BackpropInput with the converted filter from Conv2D.
  bool do_connect_conv2d_backprop_input_filter = false;
  Node* conv2d_node = nullptr;
  // Filter node is 2nd input (slot index 1) of Conv2D.
  int kConv2DFilterInputSlotIdx = 1;
  int kConv2DBackpropInputFilterInputSlotIdx = 1;
  int kConv2DFilterOutputSlotIdx = 1;
  if (old_node->type_string() == csinfo_.conv2d_grad_input) {
    // We need to find Conv2D node from Conv2DBackpropInput.
    // For that let's first find filter node that is 2nd input (slot 1)
    // of BackpropInput.
    Node* filter_node = nullptr;
    // NOTE(review): the Status returned by input_node() is ignored here;
    // only the null check below guards the result. Confirm intended.
    old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, &filter_node);
    CHECK_NOTNULL(filter_node);

    // Now check which nodes receive from filter_node. Filter feeds as
    // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias.
    for (const Edge* e : filter_node->out_edges()) {
      if (e->dst()->type_string() == csinfo_.mkl_conv2d &&
          e->dst_input() == kConv2DFilterInputSlotIdx
          /* filter is 2nd input of Conv2D and _MklConv2D. */) {
        if (conv2d_node != nullptr) {
          VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
                  << " feeding multiple Conv2D nodes: "
                  << filter_node->DebugString();
          // We will not connect filter input of Conv2DBackpropInput
          // to be safe here.
          do_connect_conv2d_backprop_input_filter = false;
          break;
        } else {
          conv2d_node = e->dst();
          do_connect_conv2d_backprop_input_filter = true;
        }
      }
    }
  }

  // Number of input slots to original op
  // Input slots are represented by .Input() calls in REGISTER_OP.
  int old_node_input_slots = old_node->op_def().input_arg_size();
  // Actual number of inputs can be greater than or equal to number
  // of Input slots because inputs of type list could be unfolded.
  CHECK_GE(old_node_inputs.size(), old_node_input_slots);
  int nn_slot_idx = 0;  // slot index for inputs of new node

  // Let's copy all inputs (TF tensors) of original node to new node.
  int iidx = 0;
  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
    // An input slot could be a single tensor or a list. We need
    // to handle this case accordingly.
    CHECK_LT(iidx, old_node_inputs.size());
    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
    if (ArgIsList(arg)) {
      // List-typed slot: consume N entries from old_node_inputs at once.
      std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int N = GetTensorListLength(arg, old_node);
      GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
                                    &new_node_inputs);
      nb->Input(new_node_inputs);
      nn_slot_idx++;
    } else {
      // Special case for connecting filter input of Conv2DBackpropInput
      if (do_connect_conv2d_backprop_input_filter &&
          iidx == kConv2DBackpropInputFilterInputSlotIdx) {
        nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
      } else {
        nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
      }
      iidx++;
      nn_slot_idx++;
    }
  }

  // If workspace tensors are available for this op and we are using
  // contiguous ordering then we need to add Tensorflow tensor for
  // workspace here because Tensorflow tensor for workspace is the
  // last tensor in the list of Tensorflow tensors.
  if (are_workspace_tensors_available) {
    CHECK_EQ(workspace_tensors->size(), 2);
    // Tensorflow tensor
    nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
    nn_slot_idx++;
  }

  // Let's now setup all Mkl inputs to new node.
  // Number of Mkl inputs must be same as number of TF inputs.
  iidx = 0;
  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
    // An input slot could be a single tensor or a list. We need
    // to handle this case accordingly.
    CHECK_LT(iidx, old_node_inputs.size());
    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
    if (ArgIsList(arg)) {
      // List-typed slot: one Mkl tensor per unfolded list element.
      std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int N = GetTensorListLength(arg, old_node);
      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
                                     &new_node_inputs);
      nb->Input(new_node_inputs);
      nn_slot_idx++;
    } else {
      Node* mkl_node = nullptr;
      int mkl_node_output_slot = 0;
      // Special case for connecting filter input of Conv2DBackpropInput
      if (do_connect_conv2d_backprop_input_filter &&
          iidx == kConv2DBackpropInputFilterInputSlotIdx) {
        GetNodeProducingMklTensor(g, old_node, conv2d_node,
                                  kConv2DFilterOutputSlotIdx, &mkl_node,
                                  &mkl_node_output_slot);
      } else {
        GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
                                  old_node_inputs[iidx].second, &mkl_node,
                                  &mkl_node_output_slot);
      }
      nb->Input(mkl_node, mkl_node_output_slot);
      iidx++;
      nn_slot_idx++;
    }
  }

  // If workspace tensors are available for this op and we are using
  // contiguous ordering then we need to add Mkl tensor for
  // workspace here because Mkl tensor for workspace is the
  // last tensor in the list of Mkl tensors.
  if (are_workspace_tensors_available) {
    CHECK_EQ(workspace_tensors->size(), 2);
    // Mkl tensor
    nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
    nn_slot_idx++;
  }

  // Total number of input slots wired onto 'nb'.
  return nn_slot_idx;
}
   1266 
   1267 Status MklLayoutRewritePass::SetUpInputs(
   1268     std::unique_ptr<Graph>* g,
   1269     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
   1270     NodeBuilder* nb, Node* old_node) {
   1271   // Let's check if we need to add workspace tensors for this node.
   1272   // We add workspace edge only for MaxPool, LRN and BatchNorm.
   1273   std::vector<NodeBuilder::NodeOut> workspace_tensors;
   1274   bool are_workspace_tensors_available = false;
   1275   AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
   1276                            &are_workspace_tensors_available);
   1277 
   1278   int new_node_input_slots = 0;
   1279   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
   1280     // TODO(nhasabni): implement this function just for same of completion.
   1281     // We do not use interleaved ordering right now.
   1282     return Status(
   1283         error::Code::UNIMPLEMENTED,
   1284         "Interleaved ordering of tensors is currently not supported.");
   1285   } else {
   1286     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
   1287     new_node_input_slots = SetUpContiguousInputs(
   1288         g, old_node_inputs, nb, old_node, &workspace_tensors,
   1289         are_workspace_tensors_available);
   1290   }
   1291 
   1292   // Sanity check
   1293   int old_node_input_slots = old_node->op_def().input_arg_size();
   1294   if (!are_workspace_tensors_available) {
   1295     // If we are not adding workspace tensors for this op, then the total
   1296     // number of input slots to the new node _must_ be 2 times the number
   1297     // of input slots to the original node: N original Tensorflow tensors and
   1298     // N for Mkl tensors corresponding to each Tensorflow tensors.
   1299     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
   1300   } else {
   1301     // If we are adding workspace tensors for this op, then the total
   1302     // The total number of input slots to new node _must_ be 2 times the number
   1303     // of input slots to the original node: N original Tensorflow tensors and
   1304     // N for Mkl tensors corresponding to each Tensorflow tensors plus 2
   1305     // (for workspace Tensorflow tensor and workspace Mkl tensor).
   1306     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
   1307   }
   1308 
   1309   return Status::OK();
   1310 }
   1311 
   1312 //////////////////////////////////////////////////////////////////////////
   1313 //           Helper functions related to workspace pass
   1314 //////////////////////////////////////////////////////////////////////////
   1315 
   1316 // TODO(nhasabni) We should move this to mkl_util.h.
   1317 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
   1318     std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
   1319   // We use a tensor of shape {1} and value 0 to represent
   1320   // dummy float tensor. We need this as a dummy workspace tensor.
   1321   // Workspace tensor has type float.
   1322   const DataType dt = DataTypeToEnum<float>::v();
   1323   TensorProto proto;
   1324   proto.set_dtype(dt);
   1325   float zero[1] = {0};
   1326   proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
   1327                            4);
   1328   TensorShape dummy_shape({1});
   1329   dummy_shape.AsProto(proto.mutable_tensor_shape());
   1330   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
   1331                   .Attr("value", proto)
   1332                   .Attr("dtype", dt)
   1333                   .Device(orig_node->def().device())  // We place this node on
   1334                                                       // same the device as the
   1335                                                       // device of the original
   1336                                                       // node.
   1337                   .Finalize(&**g, out));
   1338 
   1339   // If number of inputs to the original node is > 0, then we add
   1340   // control dependency between 1st input (index 0) of the original node and
   1341   // the dummy Mkl node. This is needed because control-flow ops such as Enter,
   1342   // Merge, etc, require frame_name of the dummy Mkl node to be same as the
   1343   // rewritten node. Adding control edge between 1st input of the original node
   1344   // and the dummy Mkl node ensures that the dummy node is in the same frame
   1345   // as the original node. Choosing 1st input is not necessary - any input of
   1346   // the original node is fine because all the inputs of a node are always in
   1347   // the same frame.
   1348   if (orig_node->num_inputs() > 0) {
   1349     Node* orig_input0 = nullptr;
   1350     TF_CHECK_OK(
   1351         orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
   1352     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
   1353   }
   1354 
   1355   (*out)->set_assigned_device_name(orig_node->assigned_device_name());
   1356 }
   1357 
// Decides whether 'orig_node' needs workspace handling and, for backward
// ops, appends the workspace tensor pair (Tensorflow tensor + its Mkl
// metadata tensor) to 'ws_tensors'. Forward ops only get a
// "workspace_enabled" attribute; the actual workspace edge is added when
// the corresponding backward op is processed. '*are_ws_tensors_added' is
// set to true iff exactly two entries were pushed onto 'ws_tensors'.
void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
    std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
    std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
  bool workspace_edge_added = false;  // Default initializer
  CHECK_NOTNULL(are_ws_tensors_added);
  *are_ws_tensors_added = false;  // Default initializer

  DataType T;
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
  // NOTE(review): 'workspace_edge_added' is not reset between wsinfo_
  // entries; this is only safe if a node type matches at most one workspace
  // rule — verify if new rules are ever added.
  for (auto ws : wsinfo_) {
    if (orig_node->type_string() == ws.fwd_op &&
        mkl_op_registry::IsMklOp(
            mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
      // If this op is a fwd op, then we need to check if there is an
      // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
      // an edge, then we just add an attribute on this node for setting
      // workspace_passed to true. We don't add actual workspace edge
      // in this node. Actual workspace edge gets added in the backward
      // op for this node.
      for (const Edge* e : orig_node->out_edges()) {
        if (e->src_output() == ws.fwd_slot &&
            e->dst()->type_string() == ws.bwd_op &&
            e->dst_input() == ws.bwd_slot) {
          nb->Attr("workspace_enabled", true);
          VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
                  << orig_node->type_string();
          workspace_edge_added = true;
          // We found the edge that we were looking for, so break.
          break;
        }
      }

      if (!workspace_edge_added) {
        // If we are here, then we did not find backward operator for this
        // node.
        nb->Attr("workspace_enabled", false);
      }
    } else if (orig_node->type_string() == ws.bwd_op &&
               mkl_op_registry::IsMklOp(
                   mkl_op_registry::GetMklOpName(orig_node->type_string()),
                   T)) {
      // If this op is a bwd op, then we need to add workspace edge and
      // it's Mkl tensor edge between its corresponding fwd op and this
      // op. Corresponding fwd op is specified in 'fwd_op' field of
      // workspace info. fwd_slot and bwd_slot in workspace info specify
      // an edge between which slots connect forward and backward op.
      // Once all these criteria match, we add a workspace edge between
      // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
      // determined by interleaved/contiguous ordering. Function
      // DataIndexToMetaDataIndex tells us the location of Mkl tensor
      // from the location of the Tensorflow tensor.
      for (const Edge* e : orig_node->in_edges()) {
        if (e->src_output() == ws.fwd_slot &&
            // We would have rewritten the forward op, so we need to use
            // GetMklOpName call to get its Mkl name.
            e->src()->type_string() ==
                mkl_op_registry::GetMklOpName(ws.fwd_op) &&
            e->dst_input() == ws.bwd_slot) {
          nb->Attr("workspace_enabled", true);
          CHECK_NOTNULL(ws_tensors);
          // Add workspace edge between fwd op and bwd op.
          ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
          // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
          ws_tensors->push_back(NodeBuilder::NodeOut(
              e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
                                                 e->src()->num_outputs())));
          *are_ws_tensors_added = true;
          // In terms of input ordering, we add these calls to add Input
          // here because workspace edge (and its Mkl tensor) is the last
          // edge in the fwdop and bwdop. So all inputs before workspace
          // tensor have been added by SetUpInputs function.
          VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
                  << orig_node->type_string();
          workspace_edge_added = true;
          // We found the edge that we were looking for, so break.
          break;
        }
      }

      // If we are here means we did not find fwd op that feeds to this
      // bwd op. So in this case, we need to generate dummy tensors for
      // workspace input and Mkl tensor for workspace, and set
      // workspace_enabled to false.
      if (!workspace_edge_added) {
        nb->Attr("workspace_enabled", false);
        Node* dmt_ws = nullptr;      // Dummy tensor for workspace
        Node* dmt_mkl_ws = nullptr;  // Dummy Mkl tensor for workspace
        GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
        GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
        CHECK_NOTNULL(dmt_ws);
        CHECK_NOTNULL(dmt_mkl_ws);
        CHECK_NOTNULL(ws_tensors);
        // We add dummy tensor as workspace tensor.
        ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
        // We add dummy tensor as Mkl tensor for workspace tensor.
        ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
        *are_ws_tensors_added = true;
        VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
                << orig_node->type_string();
      }
    } else {
      // If this node does not match any workspace info, then we do not
      // do anything special for workspace propagation for it.
    }
  }
}
   1464 
   1465 //////////////////////////////////////////////////////////////////////////
   1466 // Op-specific functions to copy attributes from old node to new node
   1467 //////////////////////////////////////////////////////////////////////////
   1468 
   1469 void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
   1470                                            NodeBuilder* nb) {
   1471   DataType T;
   1472   string data_format;
   1473   string padding;
   1474   std::vector<int32> strides;
   1475   bool use_cudnn_on_gpu;
   1476 
   1477   // Get all attributes from old node.
   1478   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1479   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   1480   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   1481   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   1482   TF_CHECK_OK(
   1483       GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
   1484 
   1485   // Add attributes to new node.
   1486   nb->Attr("T", T);
   1487   nb->Attr("strides", strides);
   1488   nb->Attr("padding", padding);
   1489   nb->Attr("data_format", data_format);
   1490   nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
   1491 }
   1492 
   1493 void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node,
   1494                                          NodeBuilder* nb) {
   1495   DataType T;
   1496   int N;
   1497 
   1498   // Get all attributes from old node.
   1499   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1500   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
   1501 
   1502   // Add attributes to new node.
   1503   nb->Attr("T", T);
   1504   nb->Attr("N", N);
   1505 }
   1506 
   1507 void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
   1508                                                 NodeBuilder* nb) {
   1509   DataType T;
   1510   string data_format;
   1511   std::vector<int32> strides;
   1512 
   1513   // Get all attributes from old node.
   1514   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1515   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   1516   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   1517 
   1518   // Add attributes to new node.
   1519   nb->Attr("T", T);
   1520   nb->Attr("strides", strides);
   1521   nb->Attr("data_format", data_format);
   1522 }
   1523 
   1524 void MklLayoutRewritePass::CopyAttrsIdentity(const Node* orig_node,
   1525                                              NodeBuilder* nb) {
   1526   DataType T;
   1527 
   1528   // Get all attributes from old node.
   1529   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1530   // Add attributes to new node.
   1531   nb->Attr("T", T);
   1532 }
   1533 
   1534 void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node,
   1535                                         NodeBuilder* nb) {
   1536   DataType T;
   1537   int depth_radius;
   1538   float bias;
   1539   float alpha;
   1540   float beta;
   1541 
   1542   // Get all attributes from old node.
   1543   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1544   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius));
   1545   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias));
   1546   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
   1547   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta));
   1548 
   1549   // Add attributes to new node.
   1550   nb->Attr("T", T);
   1551   nb->Attr("depth_radius", depth_radius);
   1552   nb->Attr("bias", bias);
   1553   nb->Attr("alpha", alpha);
   1554   nb->Attr("beta", beta);
   1555 }
   1556 
   1557 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
   1558                                             NodeBuilder* nb) {
   1559   DataType T;
   1560   string data_format;
   1561   string padding;
   1562   std::vector<int32> ksize, strides;
   1563 
   1564   // Get all attributes from old node.
   1565   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1566   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
   1567   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   1568   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   1569   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   1570 
   1571   // Add attributes to new node.
   1572   nb->Attr("T", T);
   1573   nb->Attr("ksize", ksize);
   1574   nb->Attr("strides", strides);
   1575   nb->Attr("padding", padding);
   1576   nb->Attr("data_format", data_format);
   1577 }
   1578 
   1579 void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
   1580                                              NodeBuilder* nb) {
   1581   DataType T;
   1582 
   1583   // Get all attributes from old node.
   1584   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1585 
   1586   // Add attributes to new node.
   1587   nb->Attr("T", T);
   1588 }
   1589 
   1590 void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
   1591                                             NodeBuilder* nb) {
   1592   DataType T;
   1593   DataType Tshape;
   1594 
   1595   // Get all attributes from old node.
   1596   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1597   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
   1598   // Add attributes to new node.
   1599   nb->Attr("T", T);
   1600   nb->Attr("Tshape", Tshape);
   1601 }
   1602 
   1603 void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
   1604                                           NodeBuilder* nb) {
   1605   DataType T;
   1606   string data_format;
   1607   int num_split;
   1608 
   1609   // Get all attributes from old node.
   1610   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1611   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split));
   1612   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   1613 
   1614   // Add attributes to new node.
   1615   nb->Attr("T", T);
   1616   nb->Attr("num_split", num_split);
   1617   nb->Attr("data_format", data_format);
   1618 }
   1619 
   1620 void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node,
   1621                                            NodeBuilder* nb) {
   1622   DataType T;
   1623   int N;
   1624 
   1625   // Get all attributes from old node.
   1626   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1627   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
   1628 
   1629   // Add attributes to new node.
   1630   nb->Attr("T", T);
   1631   nb->Attr("N", N);
   1632 }
   1633 
   1634 void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node,
   1635                                              NodeBuilder* nb) {
   1636   DataType T;
   1637   int N;
   1638   DataType tidx;
   1639 
   1640   // Get all attributes from old node.
   1641   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1642   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
   1643   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx));
   1644 
   1645   // Add attributes to new node.
   1646   nb->Attr("T", T);
   1647   nb->Attr("N", N);
   1648   nb->Attr("Tidx", tidx);
   1649 }
   1650 
   1651 void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
   1652                                                    NodeBuilder* nb) {
   1653   DataType T;
   1654   float epsilon;
   1655   string data_format;
   1656   bool is_training;
   1657 
   1658   // Get all attributes from old node.
   1659   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1660   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
   1661   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   1662   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training));
   1663 
   1664   // Add attributes to new node.
   1665   nb->Attr("T", T);
   1666   nb->Attr("epsilon", epsilon);
   1667   nb->Attr("data_format", data_format);
   1668   nb->Attr("is_training", is_training);
   1669 }
   1670 
   1671 //////////////////////////////////////////////////////////////////////////
   1672 //           Helper functions related to node merge pass
   1673 //////////////////////////////////////////////////////////////////////////
   1674 
   1675 Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
   1676   // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
   1677   // once we support BiasAddGrad as Mkl layer.
   1678 
   1679   // Search for all matching mergeinfo.
   1680   // We allow more than one match for extensibility.
   1681   std::vector<const MergeInfo*> matching_mi;
   1682   for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
   1683     if (a->type_string() == mi->succ) {
   1684       matching_mi.push_back(&*mi);
   1685     }
   1686   }
   1687 
   1688   for (const MergeInfo* mi : matching_mi) {
   1689     const int N_in = a->num_inputs();
   1690     if (mi->op >= N_in) {
   1691       continue;
   1692     }
   1693 
   1694     // Get the control edges and input of node
   1695     gtl::InlinedVector<Node*, 4> a_control_edges;
   1696     gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
   1697     FillInputs(a, &a_control_edges, &a_in);
   1698 
   1699     // Get operand op of the operator
   1700     Node* b = nullptr;
   1701     b = a_in[mi->op].first;
   1702     if (b == nullptr || (b->type_string() != mi->pred)) {
   1703       // NOTE: Should the first check be assert?
   1704       continue;
   1705     }
   1706 
   1707     const int B_in = b->num_inputs();
   1708     gtl::InlinedVector<Node*, 4> b_control_edges;
   1709     gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
   1710     FillInputs(b, &b_control_edges, &b_in);
   1711 
   1712     // Shouldn't merge if a and b have different control edges.
   1713     if (a_control_edges != b_control_edges) {
   1714       continue;
   1715     } else {
   1716       // We found a match.
   1717       return b;
   1718     }
   1719   }
   1720 
   1721   return nullptr;
   1722 }
   1723 
// Fuses a (_MklConv2D -> BiasAdd) pair into a single Conv2DWithBias node.
// 'pred' is the already-rewritten Mkl Conv2D node and 'succ' is the BiasAdd
// consuming its output. On success, both original nodes are removed from
// the graph and Status::OK() is returned; if any safety check fails
// (mismatched attrs/devices, Conv2D with multiple consumers), the graph is
// left untouched and a non-OK status is returned.
Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
                                       Node* pred) {
  CHECK_NOTNULL(succ);
  CHECK_NOTNULL(pred);

  if (succ->type_string() == csinfo_.bias_add &&
      pred->type_string() == csinfo_.mkl_conv2d) {
    // 1. Get all attributes from input nodes.
    DataType T_pred, T_succ;
    string padding;
    std::vector<int32> strides;
    string data_format_pred, data_format_succ;
    bool use_cudnn_on_gnu;
    TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
    TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
    TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
    // NOTE(review): 'use_cudnn_on_gnu' ("gnu" is a typo for "gpu") is
    // fetched but never read afterwards; the TF_CHECK_OK only verifies the
    // attribute exists on 'pred'.
    TF_CHECK_OK(
        GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
    // We check to ensure that data formats of both succ and pred are same.
    // We expect them to be same, so we can enforce this as assert.
    // But assert can be too strict, so we enforce this as a check.
    // If the check fails, then we do not merge two nodes.
    // We also do same check for devices.
    if (data_format_pred != data_format_succ || T_pred != T_succ ||
        pred->assigned_device_name() != succ->assigned_device_name() ||
        pred->def().device() != succ->def().device()) {
      return Status(error::Code::INVALID_ARGUMENT,
                    "data_format or T attribute or devices of Conv2D and "
                    "BiasAdd do not match. Will skip node merge optimization");
    }

    // Collect data inputs and control edges of both nodes up front; they
    // are re-wired onto the merged node below.
    const int succ_num = succ->num_inputs();
    gtl::InlinedVector<Node*, 4> succ_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
    FillInputs(succ, &succ_control_edges, &succ_in);

    const int pred_num = pred->num_inputs();
    gtl::InlinedVector<Node*, 4> pred_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
    FillInputs(pred, &pred_control_edges, &pred_in);

    // We need to ensure that there is only 1 edge between Conv2D and AddBias.
    // Otherwise, merging is semantically incorrect.
    if (pred->out_edges().size() != 1) {
      return Status(error::Code::INVALID_ARGUMENT,
                    "Conv2D has multiple outputs."
                    "Will skip node merge optimization");
    }

    for (const Edge* e : pred->out_edges()) {
      if (e->dst() != succ) {
        return Status(error::Code::INVALID_ARGUMENT,
                      "Conv2D does not feed to BiasAdd."
                      "Will skip node merge optimization");
      }
    }

    // 2. Get inputs from both the nodes.
    // Find the 2 inputs from the conv and the bias from the add Bias.
    // Get operand 0, 1 of conv2D and their Mkl tensors.
    CHECK_EQ(pred->in_edges().size(), 4);  // _MklConv2D must have 4 inputs.
    // Get operand 1 of add_bias
    // BiasAdd must have 2 inputs: Conv, bias
    CHECK_EQ(succ->in_edges().size(), 2);
    Node* oper3_mkl = nullptr;  // Mkl tensor corresponding to oper3
    int oper3_mkl_slot = 0;     // For dummy MKL tensor node, output slot is 0.
    GetDummyMklTensorNode(g, &oper3_mkl, pred);  // Get dummy Mkl tensor node
    // as BiasAdd does not have Mkl tensor as input.
    CHECK_NOTNULL(oper3_mkl);

    // We will use the node name of BiasAdd as the name of new node
    // Build new node. We use same name as original node, but change the op
    // name.
    NodeBuilder nb(succ->name(), csinfo_.mkl_conv2d_with_bias);
    if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
      nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
      // pred_in[1] will be Mkl tensor for In1 if we follow interleaved
      // ordering, and it will be 2nd Tensorflow tensor for Conv2D if
      // we follow contiguous ordering.
      nb.Input(pred_in[1].first, pred_in[1].second);  // Mkl for In1
      nb.Input(pred_in[2].first, pred_in[2].second);  // In2 of Conv2D
      nb.Input(pred_in[3].first, pred_in[3].second);  // Mkl for In2
      nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
      nb.Input(oper3_mkl, oper3_mkl_slot);            // Mkl for In2 of BiasAdd
    } else {
      CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
      nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
      // pred_in[1] will be Mkl tensor for In1 if we follow interleaved
      // ordering, and it will be 2nd Tensorflow tensor for Conv2D if
      // we follow contiguous ordering.
      nb.Input(pred_in[1].first, pred_in[1].second);  // In2 of Conv2D
      nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
      nb.Input(pred_in[2].first, pred_in[2].second);  // Mkl for In1 of Conv2D
      nb.Input(pred_in[3].first, pred_in[3].second);  // Mkl for In2 of Conv2D
      nb.Input(oper3_mkl, oper3_mkl_slot);            // Mkl for In2 of BiasAdd
    }

    // Copy attributes from Conv2D to Conv2DWithBias.
    CopyAttrsConv2D(const_cast<const Node*>(pred), &nb);

    // Copy the device assigned to old node to new node.
    nb.Device(succ->def().device());

    // Create node.
    Node* new_node;
    nb.Finalize(&**g, &new_node);
    CHECK_NOTNULL(new_node);

    // Set the Mkl layer label for this op.
    new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);

    // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
    // node are already copied in BuildNode. We handle control edges now.
    for (const Edge* e : pred->in_edges()) {
      if (e->IsControlEdge()) {
        CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
      }
    }
    for (const Edge* e : succ->in_edges()) {
      if (e->IsControlEdge()) {
        CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
      }
    }

    // Incoming edges are fixed, we will fix the outgoing edges now.
    // First, we will fix outgoing control edges from 'pred' node.
    // We don't need to handle outgoing data edges from 'pred' node
    // because pred has only 1 output going to succ node (we enforced
    // this check for merge already).
    for (const Edge* e : pred->out_edges()) {
      if (e->IsControlEdge()) {
        CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
      }
    }

    // Second, we will fix outgoing control and data edges from 'succ' node.
    for (const Edge* e : succ->out_edges()) {
      if (e->IsControlEdge()) {
        CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
      } else {
        CHECK_NOTNULL(
            (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()));
      }
    }

    // Copy device assigned to old node to new node.
    // It's ok to use pred or succ as we have enforced a check that
    // both have same device assigned.
    new_node->set_assigned_device_name(pred->assigned_device_name());

    VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
            << ", and node: " << succ->DebugString()
            << ", into node:" << new_node->DebugString();

    (*g)->RemoveNode(succ);
    (*g)->RemoveNode(pred);

    return Status::OK();
  }

  return Status(error::Code::UNIMPLEMENTED,
                "Unimplemented case for node merge optimization.");
}
   1890 
   1891 //////////////////////////////////////////////////////////////////////////
   1892 //           Helper functions for node rewrite
   1893 //////////////////////////////////////////////////////////////////////////
   1894 
// Rewrites 'orig_node' in graph '*g' into a new node whose op type is
// 'ri->new_name', transfers all data/control edges and device placement to
// the new node, and deletes 'orig_node'. Returns a non-OK status (leaving
// 'orig_node' in the graph) if the rewrite cannot be performed.
Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
                                         Node* orig_node,
                                         const RewriteInfo* ri) {
  CHECK_NOTNULL(ri);
  CHECK_NOTNULL(orig_node);

  VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();

  // Check if this is scenario 2 (context-based rewrite).
  // Get the matching ContextInfo if it is.
  const Node* fwd_node = nullptr;
  const ContextInfo* ci = nullptr;
  bool is_context_based_rewrite = false;
  if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) {
    is_context_based_rewrite = true;

    // Sanity checks for context-based rewrite (if any)
    if (orig_node->type_string() == csinfo_.bias_add_grad &&
        ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
      CHECK_NOTNULL(fwd_node);
      DataType orig_T, ctx_T;
      string orig_data_format, ctx_data_format;
      TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T));
      TF_CHECK_OK(
          GetNodeAttr(orig_node->def(), "data_format", &orig_data_format));
      TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T));
      TF_CHECK_OK(
          GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format));

      // BiasAddGrad may only be paired with its forward Conv2D context when
      // both nodes agree on datatype, data_format, and device placement
      // (both the runtime-assigned device and the user-requested device).
      if (orig_data_format != ctx_data_format || orig_T != ctx_T ||
          orig_node->assigned_device_name() !=
              fwd_node->assigned_device_name() ||
          orig_node->def().device() != fwd_node->def().device()) {
        return Status(
            error::Code::INVALID_ARGUMENT,
            "data_format or T attribute or devices of BiasAddGrad and "
            "Conv2D do not match. Will skip node rewrite optimization");
      }
    } else if (orig_node->type_string() == csinfo_.bias_add_grad &&
               ri->new_name == csinfo_.matmul) {
      // When BiasAddGrad has MatMul in context, we do not do any rewrite
      // and leave BiasAddGrad as it is. But we check for this condition
      // when we check for node rewrite rule. So we should not even come
      // here for MatMul. So we will fail now.
      return Status(
          error::Code::INVALID_ARGUMENT,
          "No rewrite is required for BiasAddGrad for MatMul context.");
    }
  }

  // Get all inputs.
  // Start from the total in-edge count, then subtract control edges so that
  // 'num_inputs' counts only the data inputs of the original node.
  int num_inputs = orig_node->in_edges().size();

  // Drop count for control edges from inputs
  for (const Edge* e : orig_node->in_edges()) {
    if (e->IsControlEdge()) {
      num_inputs--;
    }
  }

  gtl::InlinedVector<Node*, 4> control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
  FillInputs(orig_node, &control_edges, &inputs);

  // Build new node. We use same name as original node, but change the op name.
  NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
  // Copy user-specified device assigned to original node to new node.
  nb.Device(orig_node->def().device());
  // Set up new inputs to the rewritten node.
  // NOTE(review): failures from here on return without undoing anything
  // SetUpInputs may already have added to the graph — presumably any such
  // additions are harmless dangling helpers; confirm against SetUpInputs.
  Status s = SetUpInputs(g, inputs, &nb, orig_node);
  if (s != Status::OK()) {
    return s;
  }

  // Copy attributes from original node to new node (for scenario 1).
  // For context-based rewrite, we use context to copy the attributes.
  if (is_context_based_rewrite) {
    if (orig_node->type_string() == csinfo_.bias_add_grad &&
        ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
      CHECK_NOTNULL(fwd_node);
      // Attributes come from the forward (context) node, not orig_node.
      ri->copy_attrs(fwd_node, &nb);
    } else {
      return Status(error::Code::UNIMPLEMENTED,
                    "Unimplemented case for node rewrite optimization.");
    }
  } else {
    ri->copy_attrs(const_cast<const Node*>(orig_node), &nb);
  }
  // Set the Mkl layer label for this op.
  nb.Attr("_kernel", mkl_op_registry::kMklOpLabel);

  // Finalize graph and get new node.
  Node* new_node = nullptr;
  TF_CHECK_OK(nb.Finalize(&**g, &new_node));
  CHECK_NOTNULL(new_node);

  // Incoming data edges from 'orig_node' node to new 'new_node' node are
  // already copied in BuildNode. We need to handle control edges now.
  for (const Edge* e : orig_node->in_edges()) {
    if (e->IsControlEdge()) {
      CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
    }
  }

  // Copy outgoing edges from 'orig_node' node to new
  // 'new_node' node, since the output also follows same ordering among
  // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
  // tensors appropriately. Specifically, nth output of the original node
  // will become 2*nth output of the Mkl node for the interleaved ordering
  // of the tensors. For the contiguous ordering of the tensors, it will be n.
  // GetTensorDataIndex provides this mapping function.
  for (const Edge* e : orig_node->out_edges()) {
    if (e->IsControlEdge()) {
      CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
    } else {
      CHECK_NOTNULL((*g)->AddEdge(
          new_node,
          GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
          e->dst(), e->dst_input()));
    }
  }

  // Copy the runtime device assigned from original code to new node.
  new_node->set_assigned_device_name(orig_node->assigned_device_name());

  // Delete original node and mark new node as rewritten.
  // Edge rewiring above must be complete before this point.
  (*g)->RemoveNode(orig_node);

  VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
  return Status::OK();
}
   2026 
   2027 const MklLayoutRewritePass::ContextInfo*
   2028 MklLayoutRewritePass::SearchMatchingContext(const Node* n,
   2029                                             const Node** fwd_node) {
   2030   CHECK_NOTNULL(n);
   2031   CHECK_NOTNULL(fwd_node);
   2032   *fwd_node = nullptr;
   2033 
   2034   // Search for matching contextinfo based on node name and call
   2035   // callback function using matching contextinfo.
   2036   // There could be more than one matching contextinfos but whichever
   2037   // matches first is returned.
   2038   for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
   2039     if (n->type_string() == (*ci)->node &&
   2040         (*ci)->context_match_fn(n, fwd_node, *ci)) {
   2041       VLOG(1) << "Found context as matching: " << (*ci)->fwd;
   2042       return *ci;
   2043     }
   2044   }
   2045   return nullptr;
   2046 }
   2047 
   2048 bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n,
   2049                                                const ContextInfo* c) {
   2050   const Node* fwd_node = nullptr;
   2051   return SearchMatchingContext(n, &fwd_node) == c;
   2052 }
   2053 
   2054 const MklLayoutRewritePass::RewriteInfo*
   2055 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
   2056   CHECK_NOTNULL(n);
   2057 
   2058   // First check if node along with its type is supported by MKL layer.
   2059   // We do not want to rewrite an op into Mkl op if types are not supported.
   2060   // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
   2061   // MklRelu if type is INT32.
   2062   DataType T;
   2063   if (!GetNodeAttr(n->def(), "T", &T).ok()) {
   2064     return nullptr;
   2065   }
   2066 
   2067   // BiasAddGrad is not an Mkl layer, so we make an exception for it.
   2068   if (n->type_string() != csinfo_.bias_add_grad) {
   2069     if (!mkl_op_registry::IsMklOp(
   2070             mkl_op_registry::GetMklOpName(n->type_string()), T)) {
   2071       return nullptr;
   2072     }
   2073   }
   2074 
   2075   // For elementwise node, we reuse the Eigen implementation and pass the MKL
   2076   // metadata tensor through so we can avoid conversions. However, if all
   2077   // incoming edges are in TF format, we don't need all this overhead, so
   2078   // replace the elementwise node only if at least one of its parents is a MKL
   2079   // node.
   2080   //
   2081   // TODO(vrane): Add implementation for element-wise ops that doesn't reuse
   2082   // eigen code to reduce cross-library dependency.
   2083   if (mkl_op_registry::IsMklElementWiseOp(
   2084           mkl_op_registry::GetMklOpName(n->type_string()), T)) {
   2085     bool incoming_mkl_edge = false;
   2086     for (auto parent : n->in_edges()) {
   2087       if (mkl_op_registry::IsMklOp(
   2088               mkl_op_registry::GetMklOpName(parent->src()->type_string()), T)) {
   2089         incoming_mkl_edge = true;
   2090         break;
   2091       } else {
   2092         VLOG(1) << "Non-MKL parent is: " << parent->src()->type_string();
   2093       }
   2094     }
   2095     if (incoming_mkl_edge == false) {
   2096       VLOG(1) << "Skipping replacement of elementwise node which has no MKL "
   2097                  "parents.";
   2098       return nullptr;
   2099     }
   2100   }
   2101 
   2102   // We support 2 types of node rewrites:
   2103   // 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context.
   2104   // 2. Rewriting an op to Mkl op always
   2105   // We return true if any of these 2 conditions is met.
   2106 
   2107   // Find matching RewriteInfo and then check that rewrite rule applies.
   2108   for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
   2109     if (n->type_string().compare(ri->name) == 0 &&
   2110         ri->rewrite_rule(n, ri->context)) {
   2111       // If we are rewriting BiasAddGrad into BiasAddGrad for MatMul context,
   2112       // then we just return directly.
   2113       if (n->type_string() == csinfo_.bias_add_grad &&
   2114           ri->context->fwd == csinfo_.matmul &&
   2115           ri->new_name == csinfo_.bias_add_grad) {
   2116         return nullptr;
   2117       }
   2118       return &*ri;
   2119     }
   2120   }
   2121 
   2122   // Else return not found.
   2123   return nullptr;
   2124 }
   2125 
   2126 ///////////////////////////////////////////////////////////////////////////////
   2127 //              Run function for the pass
   2128 ///////////////////////////////////////////////////////////////////////////////
   2129 
   2130 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
   2131   bool result = false;
   2132   CHECK_NOTNULL(g);
   2133 
   2134   DumpGraph("Before running MklLayoutRewritePass", &**g);
   2135 
   2136   std::vector<Node*> order;
   2137   GetReversePostOrder(**g, &order);  // This will give us topological sort.
   2138 
   2139   for (Node* n : order) {
   2140     // If node is not an op or it cannot run on CPU device, then skip.
   2141     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
   2142       continue;
   2143     }
   2144 
   2145     const RewriteInfo* ri = nullptr;
   2146     Node* predn = nullptr;
   2147     // We will first search if node is to be rewritten
   2148     if ((ri = CheckForNodeRewrite(n)) != nullptr) {
   2149       string node_name = n->name();
   2150       string op_name = n->type_string();
   2151 
   2152       VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
   2153               << " with op " << op_name << " for rewrite using"
   2154               << " layout optimization.";
   2155 
   2156       if (RewriteNode(g, n, ri) == Status::OK()) {
   2157         VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
   2158                 << " with op " << op_name << " for Mkl layout optimization.";
   2159         result = true;
   2160       }
   2161     } else if ((predn = CheckForNodeMerge(n)) != nullptr) {
   2162       // Otherwise, we will check if the node is to be merged.
   2163       string n1_name = n->name();
   2164       string n2_name = predn->name();
   2165 
   2166       VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
   2167               << n2_name << " for merging";
   2168 
   2169       if (MergeNode(g, n, predn) == Status::OK()) {
   2170         VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
   2171                 << n2_name;
   2172         result = true;
   2173       }
   2174     }
   2175   }
   2176 
   2177   DumpGraph("After running MklLayoutRewritePass", &**g);
   2178 
   2179   return result;
   2180 }
   2181 
   2182 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
   2183   return MklLayoutRewritePass().RunPass(g);
   2184 }
   2185 
   2186 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
   2187   if (options.graph == nullptr && options.partition_graphs == nullptr) {
   2188     return Status::OK();
   2189   }
   2190 
   2191   auto process_graph = [&](std::unique_ptr<Graph>* g) {
   2192     // Get the ownership of a graph
   2193     std::unique_ptr<Graph>* ng = std::move(g);
   2194     RunPass(ng);
   2195     // Return the ownership of a graph back
   2196     g->reset(ng->release());
   2197   };
   2198 
   2199   if (kMklLayoutRewritePassGroup !=
   2200       OptimizationPassRegistry::POST_PARTITIONING) {
   2201     // For any pre-partitioning phase, a graph is stored in options.graph.
   2202     process_graph(options.graph);
   2203   } else {
   2204     // For post partitioning phase, graphs are stored in
   2205     // options.partition_graphs.
   2206     for (auto& pg : *options.partition_graphs) {
   2207       process_graph(&pg.second);
   2208     }
   2209   }
   2210 
   2211   return Status::OK();
   2212 }
   2213 
   2214 #else   // INTEL_MKL_ML
   2215 
   2216 // This pass implements rewriting of graph to support following scenarios:
   2217 // (A) Merging nodes in the graph
   2218 // (B) Rewriting a node in the graph to a new node
   2219 //     Rewrite happens under following scenario:
   2220 //     - Propagating Mkl layout as an additional output tensor
   2221 //        (we will loosely call a tensor that carries Mkl layout as Mkl tensor
   2222 //         henceforth.) from every Mkl supported NN layer.
   2223 //
   2224 // Example of A : Merging nodes in the graph
   2225 // -----------------------------------------
// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
   2227 //
   2228 //           O = Conv2D(A, B)
   2229 //           P = BiasAdd(O, C)
   2230 //
   2231 // We merge them into Conv2DWithBias as:
   2232 //           P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
   2233 //
   2234 // The meaning of A_m, B_m and C_m is explained in B.1.
   2235 //
   2236 // Merge rules:
   2237 //  - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
   2238 //    goes to BiasAdd.
   2239 //  - Also, the intersection of attributes of both the nodes must have same
   2240 //    values.
   2241 //  - Both the nodes must have been assigned to same device (if any).
   2242 //
   2243 // Example of B.1 : Rewriting nodes to Mkl nodes
   2244 // ---------------------------------------------
   2245 // Consider a Relu node. Current definition of Relu node looks like:
   2246 //
   2247 //           O = Relu(A)
   2248 //
   2249 // Relu has 1 input (A), and 1 output (O).
   2250 //
   2251 // This rewrite pass will generate a new graph node for Relu (new node is
   2252 // called MklRelu) as:
   2253 //
   2254 //          O, O_m = MklRelu(A, A_m)
   2255 //
   2256 // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
   2257 // same as input A of Relu; output O is same as output O of Relu. O_m is the
   2258 // additional output tensor that will be set by MklRelu, and it represents
   2259 // Mkl tensor corresponding to O -- in other words, O_m is some kind of
// metadata for O. A_m is an additional input of MklRelu, and it represents
   2261 // for A - as O_m is metadata for O, A_m is metadata for A. MklRelu receives
   2262 // this metadata from previous node in the graph.
   2263 //
   2264 // When a previous node in the graph is an Mkl node, A_m will represent a valid
   2265 // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
   2266 // a dummy Mkl tensor.
   2267 //
   2268 // Rewriting rules:
   2269 //  - Selection of a node for rewriting happens by registering the op type of
   2270 //    the node with the rewriting pass. If the op type is not registered, then
   2271 //    all nodes of this op type will not be rewritten.
   2272 //  - Number of inputs after rewriting:
   2273 //      Since for every input Tensorflow tensor, the rewritten node gets Mkl
   2274 //      tensor(s), rewritten node gets 2*N inputs, where N is the number of
   2275 //      inputs for the original node.
   2276 //  - Number of outputs after rewriting:
   2277 //      Since for every output Tensorflow tensor, the rewritten node generates
   2278 //      Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
   2279 //      number of outputs of the original node.
   2280 //  - Ordering of Tensorflow tensors and Mkl tensors:
   2281 //      Since every rewritten node generates twice the number of inputs and
   2282 //      outputs, one could imagine various orderings among Tensorflow tensors
   2283 //      and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
   2284 //      inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
   2285 //      in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
   2286 //      order. Among N inputs one can get N! permutations.
   2287 //
   2288 //      So the question is: which order do we follow? We support 2 types of
   2289 //      orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
   2290 //      follows an intuitive order where an Mkl tensor follows the
   2291 //      corresponding Tensorflow tensor immediately. In the context of the
   2292 //      above example, it will be: A, A_m, B, B_m. Note that the ordering rule
   2293 //      applies to both the inputs and outputs. Contiguous ordering means
   2294 //      all the Tensorflow tensors are contiguous followed by all the Mkl
   2295 //      tensors. We use contiguous ordering as default.
   2296 //
   2297 // Graph rewrite algorithm:
   2298 //      Algorithm: Graph Rewrite
   2299 //      Input: Graph G, Names of the nodes to rewrite and their new names
   2300 //      Output: Modified Graph G' if the nodes are modified, G otherwise.
   2301 //      Start:
   2302 //        N = Topological_Sort(G) // N is a set of nodes in toposort order.
   2303 //        foreach node n in N
   2304 //        do
   2305 //          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input.
   2306 //          then
   2307 //            E = set of <incoming edge and its src_output slot> of n
   2308 //            E' = {}   // a new set of edges for rewritten node
   2309 //            foreach <e,s> in E
   2310 //            do
   2311 //              E' U {<e,s>}  // First copy edge which generates Tensorflow
   2312 //                            // tensor as it is
   2313 //              m = Source node of edge e
   2314 //              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
   2315 //              then
   2316 //                E' U {<m,s+1>}    // If yes, then m will generate an Mkl
   2317 //                                  // tensor as an additional output.
   2318 //              else
   2319 //                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
   2320 //                                                 // Mkl tensor.
   2321 //                E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
   2322 //              fi
   2323 //            done
   2324 //            n' = Build_New_Node(G,new_name,E')
   2325 //            Mark_Rewritten(n')  // Mark the new node as being rewritten.
   2326 //          fi
   2327 //        done
   2328 //
   2329 //      Explanation:
   2330 //        For graph rewrite, we visit nodes of the input graph in the
   2331 //        topological sort order. With this ordering, we visit nodes in the
   2332 //        top-to-bottom fashion. We need this order because while visiting a
   2333 //        node we want that all of its input nodes are visited and rewritten if
   2334 //        applicable. This is because if we need to rewrite a given node
   2335 //        then all of its input nodes need to be fixed (in other words they
   2336 //        cannot be deleted later.)
   2337 //
   2338 //        While visiting a node, we first check if the op type of the node is
   2339 //        an Mkl op. If it is, then we rewrite that node after constructing
   2340 //        new inputs to the node. If the op type of the node is not Mkl op,
   2341 //        then we do not rewrite that node.
   2342 //
   2343 // Handling workspace propagation for certain ops:
   2344 //
   2345 //        Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
   2346 //        passing of a workspace from their respective forward ops. Workspace
   2347 //        tensors provide memory for storing results of intermediate operations
   2348 //        which are helpful in backward propagation. TensorFlow does not have
   2349 //        a notion of a workspace and as a result does not allow producing
   2350 //        additional outputs from these forward ops. For these ops, we need
   2351 //        to add 2 extra edges between forward ops and their corresponding
   2352 //        backward ops - the first extra edge carries a workspace tensor and
   2353 //        the second one carries an Mkl tensor for the workspace tensor.
   2354 //
   2355 //        Example:
   2356 //
   2357 //        Typical graph for MaxPool and its gradient looks like:
   2358 //
   2359 //        A = MaxPool(T)
   2360 //        B = MaxPoolGrad(X, A, Y)
   2361 //
   2362 //        We will transform this graph to propagate the workspace as:
   2363 //        (with the contiguous ordering)
   2364 //
   2365 //        A, W, A_m, W_m = MklMaxPool(T, T_m)
   2366 //        B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
   2367 //
   2368 //        Here W is the workspace tensor. Transformed tensor names with the
   2369 //        suffix _m are Mkl tensors, and this transformation has been done
   2370 //        using the algorithm discussed earlier. The transformation for
   2371 //        workspace propagation only adds extra outputs (W, W_m) for a forward
   2372 //        op and connects them to the corresponding backward ops.
   2373 //
   2374 //        Terms:
   2375 //
   2376 //        Forward op name = name of the op in the forward pass
   2377 //          where a workspace tensor originates (MaxPool in this example)
   2378 //        Backward op name = name of the op in the backward pass that receives
   2379 //          a workspace tensor from the forward op (MaxPoolGrad in the example)
   2380 //        Slot = Position of the output or input slot that will be
   2381 //               used by the workspace tensor (1 for MklMaxPool as W is the 2nd
   2382 //               output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
   2383 //
   2384 //        Question:
   2385 //
   2386 //        How do we associate a backward op to a forward op? There can be more
   2387 //        than one op with the exact same name.
   2388 //
   2389 //        In this example, we associate MaxPoolGrad with MaxPool. But there
   2390 //        could be more than one MaxPool ops. To solve this problem, we look
   2391 //        for _direct_ edge between a forward op and a backward op (tensor A is
   2392 //        flowing along this edge in the example).
   2393 //
   2394 //        How do we transform forward and backward ops when there is no direct
   2395 //        edge between them? In such a case, we generate dummy tensors for
   2396 //        workspace tensors. For the example, transformation of MaxPool will
   2397 //        be exactly same as it would be when there is a direct edge between
   2398 //        the forward and the backward op --- it is just that MaxPool won't
   2399 //        generate any workspace tensor. For MaxPoolGrad, the transformation
   2400 //        will also be same, but instead of connecting W and W_m with the
   2401 //        outputs of MaxPool, we will produce dummy tensors for them, and we
   2402 //        will set workspace_enabled attribute to false.
   2403 //
   2404 class MklLayoutRewritePass : public GraphOptimizationPass {
   2405  public:
  // The constructor only populates the lookup tables that drive the pass:
  //  - csinfo_: canonical TensorFlow/Mkl op-name strings used throughout,
  //  - rinfo_:  which ops are rewritten to Mkl ops, how their attributes are
  //             copied, and the predicate gating each rewrite,
  //  - wsinfo_: forward/backward op pairs that need workspace edges, with
  //             the output/input slot numbers to use,
  //  - minfo_:  node pairs that are merged into a single fused op.
  MklLayoutRewritePass() {
    // NOTE: names are alphabetically sorted.
    csinfo_.addn = "AddN";
    csinfo_.avg_pool = "AvgPool";
    csinfo_.avg_pool_grad = "AvgPoolGrad";
    csinfo_.bias_add = "BiasAdd";
    csinfo_.bias_add_grad = "BiasAddGrad";
    csinfo_.concat = "Concat";
    csinfo_.concatv2 = "ConcatV2";
    csinfo_.conv2d = "Conv2D";
    // "__MklDummy*" names are placeholders produced by the merge step; they
    // are themselves rewritten to the real "_Mkl*" ops below.
    csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
    csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
    csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
    csinfo_.conv2d_grad_filter_with_bias =
        "__MklDummyConv2DBackpropFilterWithBias";
    csinfo_.fused_batch_norm = "FusedBatchNorm";
    csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
    csinfo_.identity = "Identity";
    csinfo_.lrn = "LRN";
    csinfo_.lrn_grad = "LRNGrad";
    csinfo_.matmul = "MatMul";
    csinfo_.max_pool = "MaxPool";
    csinfo_.max_pool_grad = "MaxPoolGrad";
    csinfo_.mkl_conv2d = "_MklConv2D";
    csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
    csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
    csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
    csinfo_.mkl_conv2d_grad_filter_with_bias =
        "_MklConv2DBackpropFilterWithBias";
    csinfo_.relu = "Relu";
    csinfo_.relu_grad = "ReluGrad";
    csinfo_.tanh = "Tanh";
    csinfo_.tanh_grad = "TanhGrad";
    csinfo_.reshape = "Reshape";
    csinfo_.softmax = "Softmax";
    csinfo_.split = "Split";
    // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
    // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
    // MklInputConversion op is added before it.
    csinfo_.add = "Add";
    csinfo_.maximum = "Maximum";
    csinfo_.mul = "Mul";
    csinfo_.squared_difference = "SquaredDifference";
    csinfo_.sub = "Sub";
    // End - element-wise ops. See note above.

    // Rewrite rules: {original op, new op, attr-copy fn, rewrite predicate}.
    // Most ops map to mkl_op_registry::GetMklOpName(name) and are rewritten
    // unconditionally (AlwaysRewrite); exceptions carry custom predicates.
    // NOTE: names are alphabetically sorted.
    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
                      CopyAttrsAddN, AddNRewrite});
    rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.avg_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.concat,
                      mkl_op_registry::GetMklOpName(csinfo_.concat),
                      CopyAttrsConcat, AlwaysRewrite});
    rinfo_.push_back({csinfo_.concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.concatv2),
                      CopyAttrsConcatV2, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d),
                      CopyAttrsConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
                      CopyAttrsConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
                      CopyAttrsConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
                      csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D,
                      AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
                      CopyAttrsConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.fused_batch_norm,
                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                      CopyAttrsFusedBatchNorm, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
         CopyAttrsFusedBatchNorm, AlwaysRewrite});
    rinfo_.push_back({csinfo_.identity,
                      mkl_op_registry::GetMklOpName(csinfo_.identity),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                      CopyAttrsLRN, AlwaysRewrite});
    rinfo_.push_back({csinfo_.lrn_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
                      CopyAttrsLRN, AlwaysRewrite});
    rinfo_.push_back({csinfo_.max_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool),
                      CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
    rinfo_.push_back({csinfo_.max_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
                      CopyAttrsPooling, AlwaysRewrite});

    rinfo_.push_back({csinfo_.maximum,
                      mkl_op_registry::GetMklOpName(csinfo_.maximum),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.mul,
                      mkl_op_registry::GetMklOpName(csinfo_.mul),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.relu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
                      CopyAttrsDataType, AlwaysRewrite});
    // Tanh/TanhGrad rewrites are intentionally disabled (left commented out).
    /*
    rinfo_.push_back({csinfo_.tanh,
                      mkl_op_registry::GetMklOpName(csinfo_.tanh),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.tanh_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
                      CopyAttrsDataType, AlwaysRewrite});
    */
    rinfo_.push_back({csinfo_.reshape,
                      mkl_op_registry::GetMklOpName(csinfo_.reshape),
                      CopyAttrsReshape, AlwaysRewrite});
    rinfo_.push_back({csinfo_.softmax,
                      mkl_op_registry::GetMklOpName(csinfo_.softmax),
                      CopyAttrsDataType, AlwaysRewrite});

    rinfo_.push_back({csinfo_.squared_difference,
                      mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.sub,
                      mkl_op_registry::GetMklOpName(csinfo_.sub),
                      CopyAttrsDataType, AlwaysRewrite});

    // Add info about which ops to add workspace edge to and the slots.
    wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
    wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});

    // Add a rule for merging nodes
    minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
                      csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});

    minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
                      csinfo_.conv2d_grad_filter_with_bias,
                      GetConv2DBackpropFilterOrBiasAddGrad});
  }
   2550 
   2551   // Standard interface to run pass
   2552   Status Run(const GraphOptimizationPassOptions& options);
   2553 
   2554   // Helper function which does most of heavy lifting for rewriting
   2555   // Mkl nodes to propagate Mkl tensor as additional output
   2556   //
   2557   // Extracts common functionality between Run public interface and
   2558   // test interface.
   2559   //
   2560   // @return true, if and only if graph is mutated; false otherwise.
   2561   bool RunPass(std::unique_ptr<Graph>* g);
   2562 
  /// Structure to specify the name of an original node, its new name after
  /// rewrite, the function to be used to copy attributes for the op, and
  /// the rule (if any) which must hold for rewriting the node.
  typedef struct {
    string name;      // Original name of op of the node in the graph
    string new_name;  // New name of the op of the node in the graph
    // A function handler to copy attributes from an old node to a new node.
    std::function<void(const Node*, NodeBuilder*)> copy_attrs;
    // A rule under which to rewrite this node
    std::function<bool(const Node*)> rewrite_rule;
  } RewriteInfo;

  /// Structure to specify a forward op, a backward op, and the slot numbers
  /// in the forward and backward ops where we will add a workspace edge.
  /// (Workspace edges forward intermediate results, e.g. of LRN/MaxPool,
  /// from the forward op to its gradient op.)
  typedef struct {
    string fwd_op;    // Name of a forward op in the graph
    string bwd_op;    // Name of a backward op in the graph
    int fwd_slot;     // Output slot in the forward op node where actual
                      // output tensor resides
    int bwd_slot;     // Input slot in the backward op node where actual
                      // input tensor resides
    int ws_fwd_slot;  // Output slot in the forward op node where workspace
                      // edge is added
    int ws_bwd_slot;  // Input slot in the backward op node where workspace
                      // edge is added
  } WorkSpaceInfo;

  /// Structure to specify information used in node merge of 2 operators
  typedef struct {
    string op1;       // Node string for one operator.
    string op2;       // Node string for second operator.
    string new_node;  // Name of the node after merge
    // Function that enables user of the node merger to specify how to find
    // second operator given the first operator. Returns nullptr when no
    // mergeable counterpart exists.
    std::function<Node*(const Node*)> get_node_to_be_merged;
  } MergeInfo;
   2600 
  /// Structure to store all constant strings
  /// NOTE: names are kept roughly alphabetically sorted (a few entries,
  /// e.g. 'addn' before 'add' and 'tanh*' before 'reshape', deviate).
  typedef struct {
    string addn;
    string add;
    string avg_pool;
    string avg_pool_grad;
    string bias_add;
    string bias_add_grad;
    string concat;
    string concatv2;
    string conv2d;
    string conv2d_with_bias;
    string conv2d_grad_input;
    string conv2d_grad_filter;
    string conv2d_grad_filter_with_bias;
    string fused_batch_norm;
    string fused_batch_norm_grad;
    string identity;
    string lrn;
    string lrn_grad;
    string matmul;
    string max_pool;
    string max_pool_grad;
    string maximum;
    string mkl_conv2d;
    string mkl_conv2d_grad_input;
    string mkl_conv2d_grad_filter;
    string mkl_conv2d_grad_filter_with_bias;
    string mkl_conv2d_with_bias;
    string mul;
    string relu;
    string relu_grad;
    string tanh;
    string tanh_grad;
    string reshape;
    string softmax;
    string split;
    string squared_difference;
    string sub;
  } ConstStringsInfo;
   2642 
   2643  private:
   2644   /// Maintain info about nodes to rewrite
   2645   std::vector<RewriteInfo> rinfo_;
   2646 
   2647   /// Maintain info about nodes to add workspace edge
   2648   std::vector<WorkSpaceInfo> wsinfo_;
   2649 
   2650   /// Maintain info about nodes to be merged
   2651   std::vector<MergeInfo> minfo_;
   2652 
   2653   /// Maintain structure of constant strings
   2654   static ConstStringsInfo csinfo_;
   2655 
   2656  private:
   2657   // Is OpDef::ArgDef a list type? It could be N * T or list(type).
   2658   // Refer to opdef.proto for details of list type.
   2659   inline bool ArgIsList(const OpDef::ArgDef& arg) const {
   2660     return !arg.type_list_attr().empty() || !arg.number_attr().empty();
   2661   }
   2662 
   2663   // Get length of a list in 'n' if 'arg' is of list type. Refer to
   2664   // description of ArgIsList for definition of list type.
   2665   inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) {
   2666     CHECK_EQ(ArgIsList(arg), true);
   2667     int N = 0;
   2668     const string attr_name = !arg.type_list_attr().empty()
   2669                                  ? arg.type_list_attr()
   2670                                  : arg.number_attr();
   2671     if (!arg.type_list_attr().empty()) {
   2672       std::vector<DataType> value;
   2673       TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
   2674       N = value.size();
   2675     } else {
   2676       TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
   2677     }
   2678     return N;
   2679   }
   2680 
   2681   // Can op represented by node 'n' run on DEVICE_CPU?
   2682   // Op can run on CPU with MKL if the runtime assigned device or the
   2683   // user requested device contains device CPU, or both are empty.
   2684   bool CanOpRunOnCPUDevice(const Node* n) {
   2685     bool result = true;
   2686     string reason;
   2687 
   2688     // Substring that should be checked for in device name for CPU device.
   2689     const char* const kCPUDeviceSubStr = "CPU";
   2690 
   2691     // If Op has been specifically assigned to a non-CPU device, then No.
   2692     if (!n->assigned_device_name().empty() &&
   2693         !StringPiece(n->assigned_device_name()).contains(kCPUDeviceSubStr)) {
   2694       result = false;
   2695       reason = "Op has been assigned a runtime device that is not CPU.";
   2696     }
   2697 
   2698     // If user has specifically assigned this op to a non-CPU device, then No.
   2699     if (!n->def().device().empty() &&
   2700         !StringPiece(n->def().device()).contains(kCPUDeviceSubStr)) {
   2701       result = false;
   2702       reason = "User has assigned a device that is not CPU.";
   2703     }
   2704 
   2705     if (result == false) {
   2706       VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
   2707               << n->type_string() << ", reason: " << reason;
   2708     }
   2709 
   2710     // Otherwise Yes.
   2711     return result;
   2712   }
   2713 
   2714   // Return a node that can be merged with input node 'n'
   2715   //
   2716   // @return pointer to the node if we can find such a
   2717   // node. Otherwise, it returns nullptr.
   2718   Node* CheckForNodeMerge(const Node* n) const;
   2719 
   2720   // Merge node 'm' with node 'n'.
   2721   // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with
   2722   // Conv2DBackpropFilter.
   2723   //
   2724   // Input nodes m and n may be deleted if the call to
   2725   // this function is successful. Attempt to use the pointers
   2726   // after the call to function may result in undefined behaviors.
   2727   //
   2728   // @input g - input graph, m - graph node, n - graph node to be merged with m
   2729   // @return Status::OK(), if merging is successful and supported.
   2730   //         Returns appropriate Status error code otherwise.
   2731   //         Graph is updated in case nodes are merged. Otherwise, it is
   2732   //         not updated.
   2733   Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);
   2734 
   2735   // Helper function to merge different nodes
   2736   Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
   2737   Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
   2738                                                   Node* m, Node* n);
   2739 
   2740   // Find BiasAdd or Conv2D node that can be merged with input node 'm'.
   2741   // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be
   2742   // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd
   2743   // node that can be merged with 'm'.
   2744   static Node* GetConv2DOrBiasAdd(const Node* m) {
   2745     CHECK_NOTNULL(m);
   2746     Node* n = nullptr;
   2747 
   2748     if (m->type_string() == csinfo_.bias_add) {
   2749       // If a is BiasAdd, then Conv2D is 0th input of BiasAdd.
   2750       TF_CHECK_OK(m->input_node(0, &n));
   2751     } else {
   2752       CHECK_EQ(m->type_string(), csinfo_.conv2d);
   2753       // Go over all output edges and search for BiasAdd Node.
   2754       // 0th input of BiasAdd is Conv2D.
   2755       for (const Edge* e : m->out_edges()) {
   2756         if (!e->IsControlEdge() &&
   2757             e->dst()->type_string() == csinfo_.bias_add &&
   2758             e->dst_input() == 0) {
   2759           n = e->dst();
   2760           break;
   2761         }
   2762       }
   2763     }
   2764 
   2765     if (n == nullptr) {
   2766       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
   2767               << "Conv2D and BiasAdd node for merging. Input node: "
   2768               << m->DebugString();
   2769     }
   2770 
   2771     return n;
   2772   }
   2773 
   2774   // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
   2775   // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
   2776   // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
   2777   // then check if there exists Conv2DBackpropFilter node that can be merged
   2778   // with 'm'.
   2779   //
   2780   // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
   2781   // would look like:
   2782   //
   2783   // _ = Conv2DBackpropFilter(F, _, G)
   2784   // _ = BiasAddGrad(G)
   2785   //
   2786   // So 1st input of BiasAddGrad connects with 3rd input of
   2787   // Conv2DBackpropFilter and vice versa.
   2788   static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
   2789     CHECK_NOTNULL(m);
   2790     Node* n = nullptr;
   2791 
   2792     if (m->type_string() == csinfo_.bias_add_grad) {
   2793       // Get 1st input 'g' of BiasAddGrad.
   2794       Node* g = nullptr;
   2795       TF_CHECK_OK(m->input_node(0, &g));
   2796       // Now traverse all outgoing edges from g that have destination node as
   2797       // Conv2DBackpropFilter.
   2798       for (const Edge* e : g->out_edges()) {
   2799         if (!e->IsControlEdge() &&
   2800             e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
   2801             e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
   2802           n = e->dst();
   2803           break;
   2804         }
   2805       }
   2806     } else {
   2807       CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
   2808       // Get 3rd input 'g' of Conv2DBackpropFilter.
   2809       Node* g = nullptr;
   2810       TF_CHECK_OK(m->input_node(2, &g));
   2811       // Now traverse all outgoing edges from g that have destination node as
   2812       // BiasAddGrad.
   2813       for (const Edge* e : g->out_edges()) {
   2814         if (!e->IsControlEdge() &&
   2815             e->dst()->type_string() == csinfo_.bias_add_grad &&
   2816             e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
   2817           n = e->dst();
   2818           break;
   2819         }
   2820       }
   2821     }
   2822 
   2823     if (n == nullptr) {
   2824       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
   2825               << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
   2826               << "Input node: " << m->DebugString();
   2827     }
   2828     return n;
   2829   }
   2830 
   2831   // Check if the node 'n' has any applicable rewrite rule
   2832   // We check for 2 scenarios for rewrite.
   2833   //
   2834   // @return RewriteInfo* for the applicable rewrite rule
   2835   const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
   2836 
  // Default rewrite rule to be used in scenario 1 for rewrite: rewrite
  // unconditionally. The node 'n' is ignored.
  // @return - true (since we want to always rewrite)
  static bool AlwaysRewrite(const Node* n) { return true; }
   2840 
   2841   // Check if we are performing pooling on depth or batch. If it is, then we
   2842   // do not rewrite MaxPool node to Mkl version.
   2843   // @return - true (if it is not a depth/batch wise pooling case);
   2844   //           false otherwise.
   2845   static bool NonDepthBatchWisePoolRewrite(const Node* n) {
   2846     CHECK_NOTNULL(n);
   2847 
   2848     string data_format_str;
   2849     TensorFormat data_format;
   2850     std::vector<int32> ksize, strides;
   2851     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
   2852     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
   2853     CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
   2854     CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
   2855 
   2856     // Condition that specifies non-batch-wise and non-depth-wise pooling.
   2857     if (GetTensorDim(ksize, data_format, 'N') == 1 &&
   2858         GetTensorDim(strides, data_format, 'N') == 1 &&
   2859         GetTensorDim(ksize, data_format, 'C') == 1 &&
   2860         GetTensorDim(strides, data_format, 'C') == 1) {
   2861       return true;
   2862     }
   2863 
   2864     return false;
   2865   }
   2866 
   2867   static bool AddNRewrite(const Node* n) {
   2868     CHECK_NOTNULL(n);
   2869 
   2870     int num;
   2871     CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true);
   2872 
   2873     // Condition that specifies non-batch-wise and non-depth-wise pooling.
   2874     if (num == 2) {
   2875       return true;
   2876     }
   2877 
   2878     return false;
   2879   }
   2880 
   2881   // Rewrites input node to a new node specified by its matching rewrite info.
   2882   //
   2883   // Method first searches matching rewrite info for input node and then
   2884   // uses that info to rewrite.
   2885   //
   2886   // Input node may be deleted in case of rewrite. Attempt to use the node
   2887   // after the call can result in undefined behaviors.
   2888   //
   2889   // @input  g - input graph, n - Node to be rewritten,
   2890   //         ri - matching rewriteinfo
   2891   // @return Status::OK(), if the input node is rewritten;
   2892   //         Returns appropriate Status error code otherwise.
   2893   //         Graph is updated in case the input node is rewritten.
   2894   //         Otherwise, it is not updated.
   2895   Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
   2896 
   2897   // Get nodes that will feed a list of TF tensors to the new
   2898   // node that we are constructing.
   2899   //
   2900   // @input g - input graph,
   2901   // @input inputs - inputs to old node that we are using for constructing
   2902   //                 new inputs,
   2903   // @input input_idx - the index in the 'inputs' vector pointing to the
   2904   //                    current input that we have processed so far
   2905   // @output input_idx - index will be incremented by the number of nodes
   2906   //                     from 'inputs' that are processed
   2907   // @input list_length - The expected length of list of TF tensors
   2908   // @output output_nodes - the list of new nodes creating TF tensors
   2909   //
   2910   // @return None
   2911   void GetNodesProducingTFTensorList(
   2912       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
   2913       int* input_idx, int list_length,
   2914       std::vector<NodeBuilder::NodeOut>* output_nodes);
   2915 
   2916   // Get nodes that will feed a list of Mkl tensors to the new
   2917   // node that we are constructing.
   2918   //
   2919   // @input g - input graph,
   2920   // @input orig_node - Original node that we are rewriting
   2921   // @input inputs - inputs to old node that we are using for constructing
   2922   //                 new inputs,
   2923   // @input input_idx - the index in the 'inputs' vector pointing to the
   2924   //                    current input that we have processed so far
   2925   // @output input_idx - index will be incremented by the number of nodes
   2926   //                     from 'inputs' that are processed
   2927   // @input list_length - The expected length of list of Mkl tensors
   2928   // @output output_nodes - the list of new nodes creating Mkl tensors
   2929   //
   2930   // @return None
   2931   void GetNodesProducingMklTensorList(
   2932       std::unique_ptr<Graph>* g, Node* orig_node,
   2933       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
   2934       int* input_idx, int list_length,
   2935       std::vector<NodeBuilder::NodeOut>* output_nodes);
   2936 
   2937   // Get a node that will feed an Mkl tensor to the new
   2938   // node that we are constructing. The output node could be (1) 'n'
   2939   // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
   2940   // if 'n' is not an Mkl layer.
   2941   //
   2942   // @input g - input graph,
   2943   // @input orig_node - Original node that we are rewriting,
   2944   // @input n - Node based on which we are creating Mkl node,
   2945   // @input n_output_slot - the output slot of node 'n'
   2946   //            which is feeding to the node that we are constructing
   2947   // @output mkl_node - the new node that will feed Mkl tensor
   2948   // @output mkl_node_output_slot - the slot number of mkl_node that
   2949   //                                will feed the tensor
   2950   // @return None
   2951   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
   2952                                  Node* n, int n_output_slot, Node** mkl_node,
   2953                                  int* mkl_node_output_slot);
   2954 
   2955   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
   2956   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
   2957   // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
   2958   // producing workspace edges if 'are_workspace_tensors_available' is true.
   2959   // Otherwise, 'workspace_tensors' is empty vector.
   2960   //
   2961   // For details, refer to 'Ordering of inputs after rewriting' section in the
   2962   // documentation above.
   2963   //
   2964   // Returns Status::OK() if setting up inputs is successful, otherwise
   2965   // returns appropriate status code.
   2966   int SetUpContiguousInputs(
   2967       std::unique_ptr<Graph>* g,
   2968       const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
   2969       NodeBuilder* nb, Node* old_node,
   2970       std::vector<NodeBuilder::NodeOut>* workspace_tensors,
   2971       bool are_workspace_tensors_available);
   2972 
   2973   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
   2974   // in graph 'g'. Original node is input in 'orig_node'.
   2975   //
   2976   // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
   2977   // section in the documentation above.
   2978   //
   2979   // Returns Status::OK() if setting up inputs is successful, otherwise
   2980   // returns appropriate status code.
   2981   Status SetUpInputs(std::unique_ptr<Graph>* g,
   2982                      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
   2983                      NodeBuilder* nb, Node* orig_node);
   2984 
   2985   // Add workspace edge on the input or output side of Node 'orig_node' by using
   2986   // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
   2987   // adding workspace edge then do not add it. Workspace Tensorflow and Mkl
   2988   // tensors, if they need to be added, will be set into these tensors.
   2989   // If we set workspace tensors, then are_ws_tensors_added should be true.
   2990   void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
   2991                                 NodeBuilder* nb,
   2992                                 std::vector<NodeBuilder::NodeOut>* ws_tensors,
   2993                                 bool* are_ws_tensors_added);
   2994 
   2995   // Functions specific to operators to copy attributes
   2996   // We need operator-specific function to copy attributes because the framework
   2997   // does not provide any generic function for it.
   2998   // NOTE: names are alphabetically sorted.
   2999   static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb);
   3000   static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
   3001   static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
   3002   static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
   3003   static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
   3004   static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb);
   3005   static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
   3006   static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
   3007   static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
   3008   static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
   3009   static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
   3010 
   3011   // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
   3012   // using node for original node 'orig_node' and return it in '*out'.
   3013   // TODO(nhasabni) We should move this to mkl_util.h
   3014   void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
   3015                              Node* orig_node);
   3016   void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
   3017                                    Node* orig_node);
   3018 };
   3019 
   3020 MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
   3021 
   3022 // We register Mkl rewrite pass for phase 1 in post partitioning group.
   3023 // We register it here so that we get a complete picture of all users of Mkl
   3024 // nodes. Do not change the ordering of the Mkl passes.
   3025 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
   3026     OptimizationPassRegistry::POST_PARTITIONING;
   3027 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
   3028 
   3029 //////////////////////////////////////////////////////////////////////////
   3030 //           Helper functions for creating new node
   3031 //////////////////////////////////////////////////////////////////////////
   3032 
   3033 static void FillInputs(const Node* n,
   3034                        gtl::InlinedVector<Node*, 4>* control_edges,
   3035                        gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
   3036   control_edges->clear();
   3037   for (const Edge* e : n->in_edges()) {
   3038     if (e->IsControlEdge()) {
   3039       control_edges->push_back(e->src());
   3040     } else {
   3041       (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
   3042     }
   3043   }
   3044   std::sort(control_edges->begin(), control_edges->end());
   3045   if (n->op_def().is_commutative()) {
   3046     // For commutative inputs, we sort the input by the input Node*
   3047     // to get a canonical ordering (so that add(a,b) and add(b, a) will
   3048     // hash to the same value if is_commutative is true for 'add').
   3049     std::sort(in->begin(), in->end());
   3050   }
   3051 }
   3052 
   3053 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
   3054     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
   3055     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
   3056   CHECK_LT(*input_idx, inputs.size());
   3057   CHECK_GT(list_length, 0);
   3058   CHECK_NOTNULL(output_nodes);
   3059   output_nodes->reserve(list_length);
   3060 
   3061   while (list_length != 0) {
   3062     CHECK_GT(list_length, 0);
   3063     CHECK_LT(*input_idx, inputs.size());
   3064     Node* n = inputs[*input_idx].first;
   3065     int slot = inputs[*input_idx].second;
   3066     // If input node 'n' is just producing a single tensor at
   3067     // output slot 'slot' then we just add that single node.
   3068     output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
   3069     (*input_idx)++;
   3070     list_length--;
   3071   }
   3072 }
   3073 
   3074 // TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
                                                 Node** out, Node* orig_node) {
  // Creates a Const node (named "DMT*") in graph 'g' that produces a dummy
  // Mkl metadata tensor, and returns it in '*out'.
  // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
  // dummy Mkl tensor. 8 = 2*size_t.
  const DataType dt = DataTypeToEnum<uint8>::v();
  TensorProto proto;
  proto.set_dtype(dt);
  uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
                           8);
  TensorShape dummy_shape({8});
  dummy_shape.AsProto(proto.mutable_tensor_shape());
  TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
                  .Attr("value", proto)
                  .Attr("dtype", dt)
                  .Device(orig_node->def().device())  // We place this node on
                                                      // the same device as the
                                                      // device of the original
                                                      // node.
                  .Finalize(&**g, out));

  // If number of inputs to the original node is > 0, then we add
  // control dependency between 1st input (index 0) of the original node and
  // the dummy Mkl node. This is needed because control-flow ops such as Enter,
  // Merge, etc, require frame_name of the dummy Mkl node to be same as the
  // rewritten node. Adding control edge between 1st input of the original node
  // and the dummy Mkl node ensures that the dummy node is in the same frame
  // as the original node. Choosing 1st input is not necessary - any input of
  // the original node is fine because all the inputs of a node are always in
  // the same frame.
  if (orig_node->num_inputs() > 0) {
    Node* orig_input0 = nullptr;
    TF_CHECK_OK(
        orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
    // Allow duplicate while adding control edge as it would fail (return
    // NULL) if we try to add duplicate edge.
    CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
  }

  // Keep the dummy node's assigned (runtime) device in sync with the
  // original node's, in addition to the requested device set above.
  (*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
   3116 
   3117 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
   3118     std::unique_ptr<Graph>* g, Node* orig_node,
   3119     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
   3120     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
   3121   CHECK_LT(*input_idx, inputs.size());
   3122   CHECK_GT(list_length, 0);
   3123   CHECK_NOTNULL(output_nodes);
   3124   output_nodes->reserve(list_length);
   3125 
   3126   while (list_length != 0) {
   3127     CHECK_GT(list_length, 0);
   3128     CHECK_LT(*input_idx, inputs.size());
   3129     Node* n = inputs[*input_idx].first;
   3130     int slot = inputs[*input_idx].second;
   3131     // If 'n' is producing a single tensor, then create a single Mkl tensor
   3132     // node.
   3133     Node* mkl_node = nullptr;
   3134     int mkl_node_output_slot = 0;
   3135     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
   3136                               &mkl_node_output_slot);
   3137     output_nodes->push_back(
   3138         NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
   3139     (*input_idx)++;
   3140     list_length--;
   3141   }
   3142 }
   3143 
   3144 // Get an input node that will feed Mkl tensor to the new
   3145 // node that we are constructing. An input node could be (1) 'n'
   3146 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
   3147 // if 'n' is not an Mkl layer.
void MklLayoutRewritePass::GetNodeProducingMklTensor(
    std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
    Node** mkl_node, int* mkl_node_output_slot) {
  // Returns in (*mkl_node, *mkl_node_output_slot) the node/slot pair that
  // supplies the Mkl metadata tensor corresponding to output 'n_output_slot'
  // of node 'n'. 'orig_node' is only used when a dummy node must be created.
  CHECK_NOTNULL(n);
  CHECK_NOTNULL(mkl_node);
  CHECK_NOTNULL(mkl_node_output_slot);

  // If this is an MKL op, then it will create extra output for MKL layout.
  DataType T;
  if (GetNodeAttr(n->def(), "T", &T).ok() &&
      mkl_op_registry::IsMklOp(n->type_string(), T)) {
    // If this is an MKL op, then it will generate an edge that will receive
    // Mkl tensor from a node.
    // output slot number for Mkl tensor would be N+slot number of TensorFlow
    // tensor, where N is total number of TensorFlow tensors.
    *mkl_node = n;
    *mkl_node_output_slot =
        GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
  } else {
    // If we have not visited the node and rewritten it, then we need
    // to create a dummy node that will feed a dummy Mkl tensor to this node.
    // DummyMklTensor node has no input and generates only 1 output
    // (dummy Mkl tensor) as output slot number 0.
    GetDummyMklTensorNode(g, mkl_node, orig_node);
    CHECK_NOTNULL(*mkl_node);
    *mkl_node_output_slot = 0;
  }
}
   3176 
// Sets up all inputs of the new node being built via 'nb' using contiguous
// tensor ordering: first every TensorFlow tensor input (with the TF
// workspace tensor last, if any), then every corresponding Mkl metadata
// tensor input (with the Mkl workspace tensor last, if any).
//
// 'old_node_inputs' holds the (possibly unfolded) data inputs of 'old_node'.
// 'workspace_tensors' must hold {TF workspace tensor, Mkl workspace tensor}
// when 'are_workspace_tensors_available' is true.
//
// Returns the total number of input slots added to the new node.
int MklLayoutRewritePass::SetUpContiguousInputs(
    std::unique_ptr<Graph>* g,
    const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
    NodeBuilder* nb, Node* old_node,
    std::vector<NodeBuilder::NodeOut>* workspace_tensors,
    bool are_workspace_tensors_available) {
  CHECK_NOTNULL(workspace_tensors);
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);

  // TODO(nhasabni): Temporary solution to connect filter input of
  // BackpropInput with the converted filter from Conv2D.
  bool do_connect_conv2d_backprop_input_filter = false;
  Node* conv2d_node = nullptr;
  // Filter node is 2nd input (slot index 1) of Conv2D.
  int kConv2DFilterInputSlotIdx = 1;
  int kConv2DBackpropInputFilterInputSlotIdx = 1;
  int kConv2DFilterOutputSlotIdx = 1;
  if (old_node->type_string() == csinfo_.conv2d_grad_input) {
    // We need to find Conv2D node from Conv2DBackpropInput.
    // For that let's first find filter node that is 2nd input (slot 1)
    // of BackpropInput.
    Node* filter_node = nullptr;
    old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, &filter_node);
    CHECK_NOTNULL(filter_node);

    // Now check which nodes receive from filter_node. Filter feeds as
    // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias.
    for (const Edge* e : filter_node->out_edges()) {
      if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
           e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias) &&
          e->dst_input() == kConv2DFilterInputSlotIdx
          /* filter is 2nd input of Conv2D and _MklConv2D. */) {
        if (conv2d_node != nullptr) {
          // Same filter feeding more than one Conv2D: ambiguous, so we do
          // not connect the converted filter at all.
          VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
                  << " feeding multiple Conv2D nodes: "
                  << filter_node->DebugString();
          // We will not connect filter input of Conv2DBackpropInput
          // to be safe here.
          do_connect_conv2d_backprop_input_filter = false;
          break;
        } else {
          conv2d_node = e->dst();
          do_connect_conv2d_backprop_input_filter = true;
        }
      }
    }
  }

  // Number of input slots to original op
  // Input slots are represented by .Input() calls in REGISTER_OP.
  int old_node_input_slots = old_node->op_def().input_arg_size();
  // Actual number of inputs can be greater than or equal to number
  // of Input slots because inputs of type list could be unfolded.
  CHECK_GE(old_node_inputs.size(), old_node_input_slots);
  int nn_slot_idx = 0;  // slot index for inputs of new node

  // Let's copy all inputs (TF tensors) of original node to new node.
  // 'iidx' indexes into 'old_node_inputs' and may advance by more than one
  // per op-def slot when that slot is a tensor list.
  int iidx = 0;
  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
    // An input slot could be a single tensor or a list. We need
    // to handle this case accordingly.
    CHECK_LT(iidx, old_node_inputs.size());
    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
    if (ArgIsList(arg)) {
      // List-typed slot: gather all N unfolded inputs into a single
      // NodeBuilder::Input() call.
      std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int N = GetTensorListLength(arg, old_node);
      GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
                                    &new_node_inputs);
      nb->Input(new_node_inputs);
      nn_slot_idx++;
    } else {
      // Special case for connecting filter input of Conv2DBackpropInput
      if (do_connect_conv2d_backprop_input_filter &&
          iidx == kConv2DBackpropInputFilterInputSlotIdx) {
        nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
      } else {
        nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
      }
      iidx++;
      nn_slot_idx++;
    }
  }

  // If workspace tensors are available for this op and we are using
  // contiguous ordering then we need to add Tensorflow tensor for
  // workspace here because Tensorflow tensor for workspace is the
  // last tensor in the list of Tensorflow tensors.
  if (are_workspace_tensors_available) {
    CHECK_EQ(workspace_tensors->size(), 2);
    // Tensorflow tensor
    nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
    nn_slot_idx++;
  }

  // Let's now setup all Mkl inputs to a new node.
  // Number of Mkl inputs must be same as number of TF inputs.
  // This second pass mirrors the TF-tensor pass above, slot for slot.
  iidx = 0;
  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
    // An input slot could be a single tensor or a list. We need
    // to handle this case accordingly.
    CHECK_LT(iidx, old_node_inputs.size());
    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
    if (ArgIsList(arg)) {
      std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int N = GetTensorListLength(arg, old_node);
      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
                                     &new_node_inputs);
      nb->Input(new_node_inputs);
      nn_slot_idx++;
    } else {
      Node* mkl_node = nullptr;
      int mkl_node_output_slot = 0;
      // Special case for connecting filter input of Conv2DBackpropInput
      if (do_connect_conv2d_backprop_input_filter &&
          iidx == kConv2DBackpropInputFilterInputSlotIdx) {
        GetNodeProducingMklTensor(g, old_node, conv2d_node,
                                  kConv2DFilterOutputSlotIdx, &mkl_node,
                                  &mkl_node_output_slot);
      } else {
        GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
                                  old_node_inputs[iidx].second, &mkl_node,
                                  &mkl_node_output_slot);
      }
      nb->Input(mkl_node, mkl_node_output_slot);
      iidx++;
      nn_slot_idx++;
    }
  }

  // If workspace tensors are available for this op and we are using
  // contiguous ordering then we need to add Mkl tensor for
  // workspace here because Mkl tensor for workspace is the
  // last tensor in the list of Mkl tensors.
  if (are_workspace_tensors_available) {
    CHECK_EQ(workspace_tensors->size(), 2);
    // Mkl tensor
    nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
    nn_slot_idx++;
  }

  return nn_slot_idx;
}
   3319 
   3320 Status MklLayoutRewritePass::SetUpInputs(
   3321     std::unique_ptr<Graph>* g,
   3322     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
   3323     NodeBuilder* nb, Node* old_node) {
   3324   // Let's check if we need to add workspace tensors for this node.
   3325   // We add workspace edge only for MaxPool, LRN and BatchNorm.
   3326   std::vector<NodeBuilder::NodeOut> workspace_tensors;
   3327   bool are_workspace_tensors_available = false;
   3328   AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
   3329                            &are_workspace_tensors_available);
   3330 
   3331   int new_node_input_slots = 0;
   3332   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
   3333     // TODO(nhasabni): implement this function just for same of completion.
   3334     // We do not use interleaved ordering right now.
   3335     return Status(
   3336         error::Code::UNIMPLEMENTED,
   3337         "Interleaved ordering of tensors is currently not supported.");
   3338   } else {
   3339     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
   3340     new_node_input_slots = SetUpContiguousInputs(
   3341         g, old_node_inputs, nb, old_node, &workspace_tensors,
   3342         are_workspace_tensors_available);
   3343   }
   3344 
   3345   // Sanity check
   3346   int old_node_input_slots = old_node->op_def().input_arg_size();
   3347   if (!are_workspace_tensors_available) {
   3348     // If we are not adding workspace tensors for this op, then the total
   3349     // number of input slots to the new node _must_ be 2 times the number
   3350     // of input slots to the original node: N original Tensorflow tensors and
   3351     // N for Mkl tensors corresponding to each Tensorflow tensors.
   3352     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
   3353   } else {
   3354     // If we are adding workspace tensors for this op, then the total
   3355     // The total number of input slots to new node _must_ be 2 times the number
   3356     // of input slots to the original node: N original Tensorflow tensors and
   3357     // N for Mkl tensors corresponding to each Tensorflow tensors plus 2
   3358     // (for workspace Tensorflow tensor and workspace Mkl tensor).
   3359     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
   3360   }
   3361 
   3362   return Status::OK();
   3363 }
   3364 
   3365 //////////////////////////////////////////////////////////////////////////
   3366 //           Helper functions related to workspace pass
   3367 //////////////////////////////////////////////////////////////////////////
   3368 
   3369 // TODO(nhasabni) We should move this to mkl_util.h.
   3370 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
   3371     std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
   3372   // We use a tensor of shape {1} and value 0 to represent
   3373   // dummy float tensor. We need this as a dummy workspace tensor.
   3374   // Workspace tensor has type uint8.
   3375   const DataType dt = DataTypeToEnum<uint8>::v();
   3376   TensorProto proto;
   3377   proto.set_dtype(dt);
   3378   float zero[1] = {0};
   3379   proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
   3380                            4);
   3381   TensorShape dummy_shape({1});
   3382   dummy_shape.AsProto(proto.mutable_tensor_shape());
   3383   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
   3384                   .Attr("value", proto)
   3385                   .Attr("dtype", dt)
   3386                   .Device(orig_node->def().device())  // We place this node on
   3387                                                       // same the device as the
   3388                                                       // device of the original
   3389                                                       // node.
   3390                   .Finalize(&**g, out));
   3391 
   3392   // If number of inputs to the original node is > 0, then we add
   3393   // control dependency between 1st input (index 0) of the original node and
   3394   // the dummy Mkl node. This is needed because control-flow ops such as Enter,
   3395   // Merge, etc, require frame_name of the dummy Mkl node to be same as the
   3396   // rewritten node. Adding control edge between 1st input of the original node
   3397   // and the dummy Mkl node ensures that the dummy node is in the same frame
   3398   // as the original node. Choosing 1st input is not necessary - any input of
   3399   // the original node is fine because all the inputs of a node are always in
   3400   // the same frame.
   3401   if (orig_node->num_inputs() > 0) {
   3402     Node* orig_input0 = nullptr;
   3403     TF_CHECK_OK(
   3404         orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
   3405     // Allow duplicate while adding control edge as it would fail (return
   3406     // NULL) if we try to add duplicate edge.
   3407     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
   3408   }
   3409 
   3410   (*out)->set_assigned_device_name(orig_node->assigned_device_name());
   3411 }
   3412 
// Decides whether 'orig_node' participates in a workspace forward/backward
// pair (per wsinfo_) and sets the "workspace_enabled" attribute on the node
// being built via 'nb' accordingly. For a backward op, also appends
// {workspace tensor, Mkl workspace tensor} to 'ws_tensors' and sets
// '*are_ws_tensors_added' to true (using dummy tensors when the matching
// forward op is not found). Forward ops only get the attribute; the actual
// workspace edge is added when the backward op is rewritten.
void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
    std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
    std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
  bool workspace_edge_added = false;  // Default initializer
  CHECK_NOTNULL(are_ws_tensors_added);
  *are_ws_tensors_added = false;  // Default initializer

  DataType T;
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
  // Check 'orig_node' against every registered workspace-pair entry.
  for (auto ws : wsinfo_) {
    if (orig_node->type_string() == ws.fwd_op &&
        mkl_op_registry::IsMklOp(
            mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
      // If this op is a fwd op, then we need to check if there is an
      // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
      // an edge, then we just add an attribute on this node for setting
      // workspace_passed to true. We don't add actual workspace edge
      // in this node. Actual workspace edge gets added in the backward
      // op for this node.
      for (const Edge* e : orig_node->out_edges()) {
        if (e->src_output() == ws.fwd_slot &&
            e->dst()->type_string() == ws.bwd_op &&
            e->dst_input() == ws.bwd_slot) {
          nb->Attr("workspace_enabled", true);
          VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
                  << orig_node->type_string();
          workspace_edge_added = true;
          // We found the edge that we were looking for, so break.
          break;
        }
      }

      if (!workspace_edge_added) {
        // If we are here, then we did not find backward operator for this
        // node.
        nb->Attr("workspace_enabled", false);
      }
    } else if (orig_node->type_string() == ws.bwd_op &&
               mkl_op_registry::IsMklOp(
                   mkl_op_registry::GetMklOpName(orig_node->type_string()),
                   T)) {
      // If this op is a bwd op, then we need to add workspace edge and
      // it's Mkl tensor edge between its corresponding fwd op and this
      // op. Corresponding fwd op is specified in 'fwd_op' field of
      // workspace info. fwd_slot and bwd_slot in workspace info specify
      // an edge between which slots connect forward and backward op.
      // Once all these criteria match, we add a workspace edge between
      // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
      // determined by interleaved/contiguous ordering. Function
      // DataIndexToMetaDataIndex tells us the location of Mkl tensor
      // from the location of the Tensorflow tensor.
      for (const Edge* e : orig_node->in_edges()) {
        if (e->src_output() == ws.fwd_slot &&
            // We would have rewritten the forward op, so we need to use
            // GetMklOpName call to get its Mkl name.
            e->src()->type_string() ==
                mkl_op_registry::GetMklOpName(ws.fwd_op) &&
            e->dst_input() == ws.bwd_slot) {
          nb->Attr("workspace_enabled", true);
          CHECK_NOTNULL(ws_tensors);
          // Add workspace edge between fwd op and bwd op.
          ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
          // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
          ws_tensors->push_back(NodeBuilder::NodeOut(
              e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
                                                 e->src()->num_outputs())));
          *are_ws_tensors_added = true;
          // In terms of input ordering, we add these calls to add Input
          // here because workspace edge (and its Mkl tensor) is the last
          // edge in the fwdop and bwdop. So all inputs before workspace
          // tensor have been added by SetUpInputs function.
          VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
                  << orig_node->type_string();
          workspace_edge_added = true;
          // We found the edge that we were looking for, so break.
          break;
        }
      }

      // If we are here means we did not find fwd op that feeds to this
      // bwd op. So in this case, we need to generate dummy tensors for
      // workspace input and Mkl tensor for workspace, and set
      // workspace_enabled to false.
      if (!workspace_edge_added) {
        nb->Attr("workspace_enabled", false);
        Node* dmt_ws = nullptr;      // Dummy tensor for workspace
        Node* dmt_mkl_ws = nullptr;  // Dummy Mkl tensor for workspace
        GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
        GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
        CHECK_NOTNULL(dmt_ws);
        CHECK_NOTNULL(dmt_mkl_ws);
        CHECK_NOTNULL(ws_tensors);
        // We add dummy tensor as workspace tensor.
        ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
        // We add dummy tensor as Mkl tensor for workspace tensor.
        ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
        *are_ws_tensors_added = true;
        VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
                << orig_node->type_string();
      }
    } else {
      // If this node does not match any workspace info, then we do not
      // do anything special for workspace propagation for it.
    }
  }
}
   3519 
   3520 //////////////////////////////////////////////////////////////////////////
   3521 // Op-specific functions to copy attributes from old node to new node
   3522 //////////////////////////////////////////////////////////////////////////
   3523 
   3524 void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
   3525                                            NodeBuilder* nb) {
   3526   DataType T;
   3527   string data_format;
   3528   string padding;
   3529   std::vector<int32> strides;
   3530   bool use_cudnn_on_gpu;
   3531 
   3532   // Get all attributes from old node.
   3533   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3534   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   3535   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   3536   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   3537   TF_CHECK_OK(
   3538       GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
   3539 
   3540   // Add attributes to new node.
   3541   nb->Attr("T", T);
   3542   nb->Attr("strides", strides);
   3543   nb->Attr("padding", padding);
   3544   nb->Attr("data_format", data_format);
   3545   nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
   3546 }
   3547 
   3548 void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node,
   3549                                          NodeBuilder* nb) {
   3550   DataType T;
   3551   int N;
   3552 
   3553   // Get all attributes from old node.
   3554   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3555   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
   3556 
   3557   // Add attributes to new node.
   3558   nb->Attr("T", T);
   3559   nb->Attr("N", N);
   3560 }
   3561 
   3562 void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
   3563                                                 NodeBuilder* nb) {
   3564   DataType T;
   3565   string data_format;
   3566   std::vector<int32> strides;
   3567 
   3568   // Get all attributes from old node.
   3569   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3570   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   3571   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   3572 
   3573   // Add attributes to new node.
   3574   nb->Attr("T", T);
   3575   nb->Attr("strides", strides);
   3576   nb->Attr("data_format", data_format);
   3577 }
   3578 
   3579 void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node,
   3580                                         NodeBuilder* nb) {
   3581   DataType T;
   3582   int depth_radius;
   3583   float bias;
   3584   float alpha;
   3585   float beta;
   3586 
   3587   // Get all attributes from old node.
   3588   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3589   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius));
   3590   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias));
   3591   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
   3592   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta));
   3593 
   3594   // Add attributes to new node.
   3595   nb->Attr("T", T);
   3596   nb->Attr("depth_radius", depth_radius);
   3597   nb->Attr("bias", bias);
   3598   nb->Attr("alpha", alpha);
   3599   nb->Attr("beta", beta);
   3600 }
   3601 
   3602 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
   3603                                             NodeBuilder* nb) {
   3604   DataType T;
   3605   string data_format;
   3606   string padding;
   3607   std::vector<int32> ksize, strides;
   3608 
   3609   // Get all attributes from old node.
   3610   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3611   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
   3612   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   3613   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   3614   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   3615 
   3616   // Add attributes to new node.
   3617   nb->Attr("T", T);
   3618   nb->Attr("ksize", ksize);
   3619   nb->Attr("strides", strides);
   3620   nb->Attr("padding", padding);
   3621   nb->Attr("data_format", data_format);
   3622 }
   3623 
   3624 void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
   3625                                              NodeBuilder* nb) {
   3626   DataType T;
   3627 
   3628   // Get all attributes from old node.
   3629   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3630 
   3631   // Add attributes to new node.
   3632   nb->Attr("T", T);
   3633 }
   3634 
   3635 void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
   3636                                             NodeBuilder* nb) {
   3637   DataType T;
   3638   DataType Tshape;
   3639 
   3640   // Get all attributes from old node.
   3641   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3642   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
   3643   // Add attributes to new node.
   3644   nb->Attr("T", T);
   3645   nb->Attr("Tshape", Tshape);
   3646 }
   3647 
   3648 void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
   3649                                           NodeBuilder* nb) {
   3650   DataType T;
   3651   string data_format;
   3652   int num_split;
   3653 
   3654   // Get all attributes from old node.
   3655   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3656   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split));
   3657   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   3658 
   3659   // Add attributes to new node.
   3660   nb->Attr("T", T);
   3661   nb->Attr("num_split", num_split);
   3662   nb->Attr("data_format", data_format);
   3663 }
   3664 
   3665 void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node,
   3666                                            NodeBuilder* nb) {
   3667   DataType T;
   3668   int N;
   3669 
   3670   // Get all attributes from old node.
   3671   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3672   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
   3673 
   3674   // Add attributes to new node.
   3675   nb->Attr("T", T);
   3676   nb->Attr("N", N);
   3677 }
   3678 
   3679 void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node,
   3680                                              NodeBuilder* nb) {
   3681   DataType T;
   3682   int N;
   3683   DataType tidx;
   3684 
   3685   // Get all attributes from old node.
   3686   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3687   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
   3688   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx));
   3689 
   3690   // Add attributes to new node.
   3691   nb->Attr("T", T);
   3692   nb->Attr("N", N);
   3693   nb->Attr("Tidx", tidx);
   3694 }
   3695 
   3696 void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
   3697                                                    NodeBuilder* nb) {
   3698   DataType T;
   3699   float epsilon;
   3700   string data_format;
   3701   bool is_training;
   3702 
   3703   // Get all attributes from old node.
   3704   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   3705   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
   3706   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   3707   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training));
   3708 
   3709   // Add attributes to new node.
   3710   nb->Attr("T", T);
   3711   nb->Attr("epsilon", epsilon);
   3712   nb->Attr("data_format", data_format);
   3713   nb->Attr("is_training", is_training);
   3714 }
   3715 
   3716 //////////////////////////////////////////////////////////////////////////
   3717 //           Helper functions related to node merge pass
   3718 //////////////////////////////////////////////////////////////////////////
   3719 
   3720 Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
   3721   // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
   3722   // once we support BiasAddGrad as Mkl layer.
   3723 
   3724   // Search for all matching mergeinfo.
   3725   // We allow more than one match for extensibility.
   3726   std::vector<const MergeInfo*> matching_mi;
   3727   for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
   3728     if (a->type_string() == mi->op1 || a->type_string() == mi->op2) {
   3729       matching_mi.push_back(&*mi);
   3730     }
   3731   }
   3732 
   3733   for (const MergeInfo* mi : matching_mi) {
   3734     // Get the operand with which 'a' can be merged.
   3735     Node* b = nullptr;
   3736     if ((b = mi->get_node_to_be_merged(a)) == nullptr) {
   3737       continue;
   3738     }
   3739 
   3740     // Get the control edges and input of node
   3741     const int N_in = a->num_inputs();
   3742     gtl::InlinedVector<Node*, 4> a_control_edges;
   3743     gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
   3744     FillInputs(a, &a_control_edges, &a_in);
   3745 
   3746     const int B_in = b->num_inputs();
   3747     gtl::InlinedVector<Node*, 4> b_control_edges;
   3748     gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
   3749     FillInputs(b, &b_control_edges, &b_in);
   3750 
   3751     // Shouldn't merge if a and b have different control edges.
   3752     if (a_control_edges != b_control_edges) {
   3753       continue;
   3754     } else {
   3755       // We found a match.
   3756       return b;
   3757     }
   3758   }
   3759 
   3760   return nullptr;
   3761 }
   3762 
   3763 Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
   3764                                                     Node* m, Node* n) {
   3765   CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
   3766              n->type_string() == csinfo_.conv2d)) ||
   3767                ((n->type_string() == csinfo_.bias_add &&
   3768                  m->type_string() == csinfo_.conv2d)),
   3769            true);
   3770 
   3771   // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
   3772   // BiasAdd is successor node, and Conv2D predecessor node.
   3773   Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
   3774   Node* succ = m->type_string() == csinfo_.bias_add ? m : n;
   3775 
   3776   // 1. Get all attributes from input nodes.
   3777   DataType T_pred, T_succ;
   3778   string padding;
   3779   std::vector<int32> strides;
   3780   string data_format_pred, data_format_succ;
   3781   bool use_cudnn_on_gnu;
   3782   TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
   3783   TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
   3784   TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
   3785   TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
   3786   TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
   3787   TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
   3788   TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
   3789   // We check to ensure that data formats of both succ and pred are same.
   3790   // We expect them to be same, so we can enforce this as assert.
   3791   // But assert can be too strict, so we enforce this as a check.
   3792   // If the check fails, then we do not merge two nodes.
   3793   // We also do same check for devices.
   3794   if (data_format_pred != data_format_succ || T_pred != T_succ ||
   3795       pred->assigned_device_name() != succ->assigned_device_name() ||
   3796       pred->def().device() != succ->def().device()) {
   3797     return Status(error::Code::INVALID_ARGUMENT,
   3798                   "data_format or T attribute or devices of Conv2D and "
   3799                   "BiasAdd do not match. Will skip node merge optimization");
   3800   }
   3801 
   3802   const int succ_num = succ->num_inputs();
   3803   gtl::InlinedVector<Node*, 4> succ_control_edges;
   3804   gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
   3805   FillInputs(succ, &succ_control_edges, &succ_in);
   3806 
   3807   const int pred_num = pred->num_inputs();
   3808   gtl::InlinedVector<Node*, 4> pred_control_edges;
   3809   gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
   3810   FillInputs(pred, &pred_control_edges, &pred_in);
   3811 
   3812   // We need to ensure that Conv2D only feeds to BiasAdd (some other operator is
   3813   // not expecting output of Conv2D). If this is not the case, then we cannot
   3814   // merge Conv2D with BiasAdd.
   3815   const int kFirstOutputSlot = 0;
   3816   for (const Edge* e : pred->out_edges()) {
   3817     if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
   3818       return Status(error::Code::INVALID_ARGUMENT,
   3819                     "Conv2D does not feed to BiasAdd, or "
   3820                     "it feeds BiasAdd but has multiple outputs. "
   3821                     "Will skip node merge optimization");
   3822     }
   3823   }
   3824 
   3825   // 2. Get inputs from both the nodes.
   3826   // Find the 2 inputs from the conv and the bias from the add Bias.
   3827   // Get operand 0, 1 of conv2D.
   3828   CHECK_EQ(pred->in_edges().size(), 2);  // Conv2D must have 2 inputs.
   3829   // Get operand 1 of add_bias
   3830   // BiasAdd must have 2 inputs: Conv, bias
   3831   CHECK_EQ(succ->in_edges().size(), 2);
   3832 
   3833   // We will use the node name of BiasAdd as the name of new node
   3834   // Build new node. We use same name as original node, but change the op
   3835   // name.
   3836   NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
   3837   nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
   3838   // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
   3839   nb.Input(pred_in[1].first, pred_in[1].second);  // In2 of Conv2D
   3840   // In1 of BiasAdd is same as output of Conv2D.
   3841   nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
   3842 
   3843   // Copy attributes from Conv2D to Conv2DWithBias.
   3844   CopyAttrsConv2D(const_cast<const Node*>(pred), &nb);
   3845 
   3846   // Copy the device assigned to old node to new node.
   3847   nb.Device(succ->def().device());
   3848 
   3849   // Create node.
   3850   Node* new_node;
   3851   nb.Finalize(&**g, &new_node);
   3852   CHECK_NOTNULL(new_node);
   3853 
   3854   // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
   3855   // node are already copied in BuildNode. We handle control edges now.
   3856   for (const Edge* e : pred->in_edges()) {
   3857     if (e->IsControlEdge()) {
   3858       // Allow duplicate while adding control edge as it would fail (return
   3859       // NULL) if we try to add duplicate edge.
   3860       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
   3861     }
   3862   }
   3863   for (const Edge* e : succ->in_edges()) {
   3864     if (e->IsControlEdge()) {
   3865       // Allow duplicate while adding control edge as it would fail (return
   3866       // NULL) if we try to add duplicate edge.
   3867       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
   3868     }
   3869   }
   3870 
   3871   // Incoming edges are fixed, we will fix the outgoing edges now.
   3872   // First, we will fix outgoing control edges from 'pred' node.
   3873   for (const Edge* e : pred->out_edges()) {
   3874     if (e->IsControlEdge()) {
   3875       // Allow duplicate while adding control edge as it would fail (return
   3876       // NULL) if we try to add duplicate edge.
   3877       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
   3878     }
   3879   }
   3880 
   3881   // Second, we will fix outgoing control and data edges from 'succ' node.
   3882   for (const Edge* e : succ->out_edges()) {
   3883     if (e->IsControlEdge()) {
   3884       // Allow duplicate while adding control edge as it would fail (return
   3885       // NULL) if we try to add duplicate edge.
   3886       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
   3887     } else {
   3888       // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
   3889       // output (at slot 0).
   3890       const int kConv2DWithBiasOutputSlot = 0;
   3891       CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(),
   3892                                   e->dst_input()));
   3893     }
   3894   }
   3895 
   3896   // Copy device assigned to old node to new node.
   3897   // It's ok to use pred or succ as we have enforced a check that
   3898   // both have same device assigned.
   3899   new_node->set_assigned_device_name(pred->assigned_device_name());
   3900 
   3901   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
   3902           << ", and node: " << succ->DebugString()
   3903           << ", into node:" << new_node->DebugString();
   3904 
   3905   (*g)->RemoveNode(succ);
   3906   (*g)->RemoveNode(pred);
   3907 
   3908   return Status::OK();
   3909 }
   3910 
   3911 Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
   3912     std::unique_ptr<Graph>* g, Node* m, Node* n) {
   3913   CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
   3914              n->type_string() == csinfo_.conv2d_grad_filter)) ||
   3915                ((n->type_string() == csinfo_.bias_add_grad &&
   3916                  m->type_string() == csinfo_.conv2d_grad_filter)),
   3917            true);
   3918 
   3919   // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
   3920   Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
   3921   Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;
   3922 
   3923   // Sanity check for attributes from input nodes.
   3924   DataType T_b, T_f;
   3925   string data_format_b, data_format_f;
   3926   TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
   3927   TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
   3928   TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
   3929   TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
   3930   if (data_format_b != data_format_f || T_b != T_f ||
   3931       badd->assigned_device_name() != fltr->assigned_device_name() ||
   3932       badd->def().device() != fltr->def().device()) {
   3933     return Status(error::Code::INVALID_ARGUMENT,
   3934                   "data_format or T attribute or devices of "
   3935                   "Conv2DBackpropFilter and BiasAddGrad do not match. "
   3936                   "Will skip node merge optimization");
   3937   }
   3938 
   3939   // We will use the node name of Conv2DBackpropFilter as the name of new node.
   3940   // This is because BackpropFilterWithBias is going to emit bias output also.
   3941   NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
   3942   // Since Conv2DBackpropFilterWithBias has same number of inputs as
   3943   // Conv2DBackpropFilter, we can just copy input edges directly. We dont need
   3944   // to copy any data input of BiasAddGrad because that input also goes to
   3945   // Conv2DBackpropFilter.
   3946   const int fltr_ins = fltr->num_inputs();
   3947   gtl::InlinedVector<Node*, 4> fltr_control_edges;
   3948   gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
   3949   FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
   3950   for (int idx = 0; idx < fltr_ins; idx++) {
   3951     nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
   3952   }
   3953 
   3954   // Copy attributes from Conv2DBackpropFilter.
   3955   CopyAttrsConv2D(const_cast<const Node*>(fltr), &nb);
   3956 
   3957   // Copy the device assigned to old node to new node.
   3958   nb.Device(fltr->def().device());
   3959 
   3960   // Create node.
   3961   Node* new_node;
   3962   nb.Finalize(&**g, &new_node);
   3963   CHECK_NOTNULL(new_node);
   3964 
   3965   // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to
   3966   // new 'new_node' node are already copied in BuildNode. We handle control
   3967   // edges now.
   3968   for (const Edge* e : badd->in_edges()) {
   3969     if (e->IsControlEdge()) {
   3970       // Allow duplicate while adding control edge as it would fail (return
   3971       // NULL) if we try to add duplicate edge.
   3972       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
   3973     }
   3974   }
   3975   for (const Edge* e : fltr->in_edges()) {
   3976     if (e->IsControlEdge()) {
   3977       // Allow duplicate while adding control edge as it would fail (return
   3978       // NULL) if we try to add duplicate edge.
   3979       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
   3980     }
   3981   }
   3982 
   3983   // Incoming edges are fixed, we will fix the outgoing edges now.
   3984   // First, we will fix outgoing control edges from 'badd' node.
   3985   // Conv2DBackpropFilter has 1 output -- filter_grad.
   3986   // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
   3987   // bias_grad. But filter_grad is at same slot number (0) in both the
   3988   // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while
   3989   // it is at slot number 0 in BiasAddGrad.
   3990   const int kMergedNodeFilterGradOutputIdx = 0;
   3991   const int kMergedNodeBiasGradOutputIdx = 1;
   3992 
   3993   for (const Edge* e : badd->out_edges()) {
   3994     if (e->IsControlEdge()) {
   3995       // Allow duplicate while adding control edge as it would fail (return
   3996       // NULL) if we try to add duplicate edge.
   3997       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
   3998     } else {
   3999       CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
   4000                                   e->dst(), e->dst_input()));
   4001     }
   4002   }
   4003 
   4004   // Second, we will fix outgoing control and data edges from 'fltr' node.
   4005   for (const Edge* e : fltr->out_edges()) {
   4006     if (e->IsControlEdge()) {
   4007       // We allow duplicate edge for this case since we already add control
   4008       // edge from new_node in line 3990. Line below could be adding same
   4009       // edge to same destination again. In such case, if we do not allow
   4010       // duplicate edge, then this call will fail.
   4011       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
   4012     } else {
   4013       CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
   4014                                   e->dst(), e->dst_input()));
   4015     }
   4016   }
   4017 
   4018   // Copy device assigned to old node to new node.
   4019   // It's ok to use badd or fltr as we have enforced a check that
   4020   // both have same device assigned.
   4021   new_node->set_assigned_device_name(badd->assigned_device_name());
   4022 
   4023   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
   4024           << ", and node: " << fltr->DebugString()
   4025           << ", into node:" << new_node->DebugString();
   4026 
   4027   (*g)->RemoveNode(badd);
   4028   (*g)->RemoveNode(fltr);
   4029 
   4030   return Status::OK();
   4031 }
   4032 
   4033 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
   4034                                        Node* n) {
   4035   CHECK_NOTNULL(m);
   4036   CHECK_NOTNULL(n);
   4037 
   4038   if (((m->type_string() == csinfo_.bias_add &&
   4039         n->type_string() == csinfo_.conv2d)) ||
   4040       ((n->type_string() == csinfo_.bias_add &&
   4041         m->type_string() == csinfo_.conv2d))) {
   4042     return this->MergeConv2DWithBiasAdd(g, m, n);
   4043   }
   4044 
   4045   if (((m->type_string() == csinfo_.bias_add_grad &&
   4046         n->type_string() == csinfo_.conv2d_grad_filter)) ||
   4047       ((n->type_string() == csinfo_.bias_add_grad &&
   4048         m->type_string() == csinfo_.conv2d_grad_filter))) {
   4049     return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
   4050   }
   4051 
   4052   return Status(error::Code::UNIMPLEMENTED,
   4053                 "Unimplemented case for node merge optimization.");
   4054 }
   4055 
   4056 //////////////////////////////////////////////////////////////////////////
   4057 //           Helper functions for node rewrite
   4058 //////////////////////////////////////////////////////////////////////////
   4059 
   4060 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
   4061                                          Node* orig_node,
   4062                                          const RewriteInfo* ri) {
   4063   CHECK_NOTNULL(ri);
   4064   CHECK_NOTNULL(orig_node);
   4065 
   4066   VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
   4067 
   4068   // Get all inputs.
   4069   int num_inputs = orig_node->in_edges().size();
   4070 
   4071   // Drop count for control edges from inputs
   4072   for (const Edge* e : orig_node->in_edges()) {
   4073     if (e->IsControlEdge()) {
   4074       num_inputs--;
   4075     }
   4076   }
   4077 
   4078   gtl::InlinedVector<Node*, 4> control_edges;
   4079   gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
   4080   FillInputs(orig_node, &control_edges, &inputs);
   4081 
   4082   // Build new node. We use same name as original node, but change the op name.
   4083   NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
   4084   // Copy user-specified device assigned to original node to new node.
   4085   nb.Device(orig_node->def().device());
   4086   // Set up new inputs to the rewritten node.
   4087   Status s = SetUpInputs(g, inputs, &nb, orig_node);
   4088   if (s != Status::OK()) {
   4089     return s;
   4090   }
   4091 
   4092   ri->copy_attrs(const_cast<const Node*>(orig_node), &nb);
   4093   // Set the Mkl layer label for this op.
   4094   nb.Attr("_kernel", mkl_op_registry::kMklOpLabel);
   4095 
   4096   // Finalize graph and get new node.
   4097   Node* new_node = nullptr;
   4098   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
   4099   CHECK_NOTNULL(new_node);
   4100 
   4101   // Incoming data edges from 'orig_node' node to new 'new_node' node are
   4102   // already copied in BuildNode. We need to handle control edges now.
   4103   for (const Edge* e : orig_node->in_edges()) {
   4104     if (e->IsControlEdge()) {
   4105       // Allow duplicate while adding control edge as it would fail (return
   4106       // NULL) if we try to add duplicate edge.
   4107       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
   4108     }
   4109   }
   4110 
   4111   // Copy outgoing edges from 'orig_node' node to new
   4112   // 'new_node' node, since the output also follows same ordering among
   4113   // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
   4114   // tensors appropriately. Specifically, nth output of the original node
   4115   // will become 2*nth output of the Mkl node for the interleaved ordering
   4116   // of the tensors. For the contiguous ordering of the tensors, it will be n.
   4117   // GetTensorDataIndex provides this mapping function.
   4118   for (const Edge* e : orig_node->out_edges()) {
   4119     if (e->IsControlEdge()) {
   4120       // Allow duplicate while adding control edge as it would fail (return
   4121       // NULL) if we try to add duplicate edge.
   4122       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
   4123     } else {
   4124       CHECK_NOTNULL((*g)->AddEdge(
   4125           new_node,
   4126           GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
   4127           e->dst(), e->dst_input()));
   4128     }
   4129   }
   4130 
   4131   // Copy the runtime device assigned from original code to new node.
   4132   new_node->set_assigned_device_name(orig_node->assigned_device_name());
   4133 
   4134   // Delete original node and mark new node as rewritten.
   4135   (*g)->RemoveNode(orig_node);
   4136 
   4137   VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
   4138   return Status::OK();
   4139 }
   4140 
// Decides whether node 'n' should be rewritten into an MKL op.
// Returns the matching RewriteInfo (owned by rinfo_) if a rewrite rule
// applies, or nullptr if the node should be left alone. The checks are
// ordered from cheapest to most specific; each early return below names
// a reason the node is skipped.
const MklLayoutRewritePass::RewriteInfo*
MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
  CHECK_NOTNULL(n);

  // First check if node along with its type is supported by MKL layer.
  // We do not want to rewrite an op into Mkl op if types are not supported.
  // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
  // MklRelu if type is INT32.
  // A node with no "T" attribute cannot be type-checked, so it is skipped.
  DataType T;
  if (!GetNodeAttr(n->def(), "T", &T).ok()) {
    return nullptr;
  }

  // We make an exception for __MklDummyConv2DWithBias and
  // __MklConv2DBackpropFilterWithBias since their names do not match Mkl node
  // names.
  if (n->type_string() != csinfo_.conv2d_with_bias &&
      n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
      !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
                                T)) {
    return nullptr;
  }

  // For elementwise node, we reuse the Eigen implementation and pass the MKL
  // metadata tensor through so we can avoid conversions. However, if all
  // incoming edges are in TF format, we don't need all this overhead, so
  // replace the elementwise node only if at least one of its parents is a MKL
  // node.
  //
  // Identity nodes can also skip replacement if they are not being served by
  // any MKL nodes.
  //
  // TODO(vrane): Add implementation for element-wise ops that doesn't reuse
  // eigen code to reduce cross-library dependency.
  VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string();
  if (mkl_op_registry::IsMklElementWiseOp(
          mkl_op_registry::GetMklOpName(n->type_string()), T) ||
      n->type_string().find("Identity") != string::npos) {
    VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string();
    // Scan parents until we find one MKL op; one is enough to justify the
    // rewrite, so we break out of the loop on the first hit.
    bool incoming_mkl_edge = false;
    int num_parent = 0;
    for (auto parent : n->in_edges()) {
      if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) {
        VLOG(1) << "ELEMENTWISE: parent " << num_parent++
                << " is MKL op: " << parent->src()->type_string();
        incoming_mkl_edge = true;
        break;
      } else {
        VLOG(1) << "ELEMENTWISE: parent " << num_parent++
                << " is NON-MKL op: " << parent->src()->type_string();
      }
    }
    if (incoming_mkl_edge == false) {
      VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which "
                 "has no MKL "
                 "parents.";
      return nullptr;
    } else {
      VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string()
              << " which has MKL parents";
    }
  }

  // We now check if rewrite rule applies for this op. If rewrite rule passes
  // for this op, then we rewrite it to Mkl op.
  // Find matching RewriteInfo and then check that rewrite rule applies.
  for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
    if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
      return &*ri;
    }
  }

  // Else return not found.
  return nullptr;
}
   4216 
   4217 ///////////////////////////////////////////////////////////////////////////////
   4218 //              Run function for the pass
   4219 ///////////////////////////////////////////////////////////////////////////////
   4220 
// Runs the layout pass over graph 'g' in two phases:
//   1. Node merging (e.g. Conv2D + BiasAdd -> Conv2DWithBias).
//   2. Node rewriting (TF op -> MKL op).
// Each phase walks the graph in reverse post order (a topological sort).
// Returns true if either phase mutated the graph.
bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
  bool result = false;
  CHECK_NOTNULL(g);

  DumpGraph("Before running MklLayoutRewritePass", &**g);

  std::vector<Node*> order;
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  // NOTE(review): a successful MergeNode deletes both merged nodes from the
  // graph, but 'order' still holds raw Node* values gathered beforehand, so a
  // later iteration may touch a pointer to a deleted node. Presumably the
  // IsOp()/CheckForNodeMerge path never matches such a node again -- confirm.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    Node* m = nullptr;
    if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
      // Check if the node 'n' can be merged with any other node. If it can
      // be 'm' contains the node with which it can be merged.
      // Names are copied before MergeNode since a successful merge deletes
      // both nodes and we still want to log them.
      string n1_name = n->name();
      string n2_name = m->name();

      VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
              << n2_name << " for merging";

      if (MergeNode(g, n, m) == Status::OK()) {
        VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
                << n2_name;
        result = true;
      }
    }
  }

  DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);

  // Recompute the ordering: the merge phase above may have replaced nodes,
  // so the previous 'order' is stale for the rewrite phase.
  order.clear();
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    const RewriteInfo* ri = nullptr;
    // We will first search if node is to be rewritten.
    if ((ri = CheckForNodeRewrite(n)) != nullptr) {
      // Copy name/op before RewriteNode, which deletes the original node.
      string node_name = n->name();
      string op_name = n->type_string();

      VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
              << " with op " << op_name << " for rewrite using"
              << " layout optimization.";

      if (RewriteNode(g, n, ri) == Status::OK()) {
        VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
                << " with op " << op_name << " for Mkl layout optimization.";
        result = true;
      }
    }
  }

  DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);

  return result;
}
   4285 
   4286 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
   4287   return MklLayoutRewritePass().RunPass(g);
   4288 }
   4289 
   4290 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
   4291   if (options.graph == nullptr && options.partition_graphs == nullptr) {
   4292     return Status::OK();
   4293   }
   4294 
   4295   auto process_graph = [&](std::unique_ptr<Graph>* g) {
   4296     // Get the ownership of a graph
   4297     std::unique_ptr<Graph>* ng = std::move(g);
   4298     RunPass(ng);
   4299     // Return the ownership of a graph back
   4300     g->reset(ng->release());
   4301   };
   4302 
   4303   if (kMklLayoutRewritePassGroup !=
   4304       OptimizationPassRegistry::POST_PARTITIONING) {
   4305     // For any pre-partitioning phase, a graph is stored in options.graph.
   4306     process_graph(options.graph);
   4307   } else {
   4308     // For post partitioning phase, graphs are stored in
   4309     // options.partition_graphs.
   4310     for (auto& pg : *options.partition_graphs) {
   4311       process_graph(&pg.second);
   4312     }
   4313   }
   4314 
   4315   return Status::OK();
   4316 }
   4317 #endif  // INTEL_MKL_ML
   4318 }  // namespace tensorflow
   4319 
   4320 #endif
   4321