Home | History | Annotate | Download | only in graph
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // TODO(intel): Improve error handling in this file; instead of CHECK failing
     17 // all over the place, we should log an error and execute the original graph.
     18 #ifdef INTEL_MKL
     19 
     20 #include <algorithm>
     21 #include <functional>
     22 #include <memory>
     23 #include <queue>
     24 #include <set>
     25 #include <stack>
     26 #include <tuple>
     27 #include <unordered_set>
     28 #include <utility>
     29 #include <vector>
     30 
     31 #include "tensorflow/core/common_runtime/function.h"
     32 #include "tensorflow/core/common_runtime/optimization_registry.h"
     33 #include "tensorflow/core/framework/node_def_util.h"
     34 #include "tensorflow/core/framework/tensor.pb.h"
     35 #include "tensorflow/core/graph/algorithm.h"
     36 #include "tensorflow/core/graph/graph.h"
     37 #include "tensorflow/core/graph/node_builder.h"
     38 #include "tensorflow/core/lib/core/status.h"
     39 #include "tensorflow/core/lib/gtl/array_slice.h"
     40 #include "tensorflow/core/lib/gtl/map_util.h"
     41 #include "tensorflow/core/lib/hash/hash.h"
     42 #include "tensorflow/core/platform/logging.h"
     43 #include "tensorflow/core/util/tensor_format.h"
     44 #include "tensorflow/core/util/util.h"
     45 
     46 #include "tensorflow/core/graph/mkl_graph_util.h"
     47 #include "tensorflow/core/graph/mkl_layout_pass.h"
     48 
     49 namespace tensorflow {
     50 
     51 // This pass implements rewriting of graph to support following scenarios:
     52 // (A) Merging nodes in the graph
     53 // (B) Rewriting a node in the graph to a new node
     54 //     Rewrite happens under following scenario:
     55 //     - Propagating Mkl layout as an additional output tensor
     56 //        (we will loosely call a tensor that carries Mkl layout as Mkl tensor
     57 //         henceforth.) from every Mkl supported NN layer.
     58 //
     59 // Example of A : Merging nodes in the graph
     60 // -----------------------------------------
     61 // Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
     62 //
     63 //           O = Conv2D(A, B)
     64 //           P = BiasAdd(O, C)
     65 //
     66 // We merge them into Conv2DWithBias as:
     67 //           P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
     68 //
     69 // The meaning of A_m, B_m and C_m is explained in B.1.
     70 //
     71 // Merge rules:
     72 //  - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
     73 //    goes to BiasAdd.
     74 //  - Also, the attributes that both nodes have in common must have the same
     75 //    values.
     76 //  - Both nodes must have been assigned to the same device (if any).
     77 //
     78 // Example of B.1 : Rewriting nodes to Mkl nodes
     79 // ---------------------------------------------
     80 // Consider a Relu node. Current definition of Relu node looks like:
     81 //
     82 //           O = Relu(A)
     83 //
     84 // Relu has 1 input (A), and 1 output (O).
     85 //
     86 // This rewrite pass will generate a new graph node for Relu (new node is
     87 // called MklRelu) as:
     88 //
     89 //          O, O_m = MklRelu(A, A_m)
     90 //
     91 // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
     92 // the same as input A of Relu; output O is the same as output O of Relu. O_m
     93 // is the additional output tensor that will be set by MklRelu, and it
     94 // represents the Mkl tensor corresponding to O -- in other words, O_m is some
     95 // kind of metadata for O. A_m is the additional input of MklRelu, and it
     96 // represents metadata for A, just as O_m is metadata for O. MklRelu receives
     97 // this metadata from the previous node in the graph.
     98 //
     99 // When a previous node in the graph is an Mkl node, A_m will represent a valid
    100 // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
    101 // a dummy Mkl tensor.
    102 //
    103 // Rewriting rules:
    104 //  - Selection of a node for rewriting happens by registering the op type of
    105 //    the node with the rewriting pass. If the op type is not registered, then
    106 //    no node of this op type will be rewritten.
    107 //  - Number of inputs after rewriting:
    108 //      Since for every input Tensorflow tensor, the rewritten node gets Mkl
    109 //      tensor(s), rewritten node gets 2*N inputs, where N is the number of
    110 //      inputs for the original node.
    111 //  - Number of outputs after rewriting:
    112 //      Since for every output Tensorflow tensor, the rewritten node generates
    113 //      Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
    114 //      number of outputs of the original node.
    115 //  - Ordering of Tensorflow tensors and Mkl tensors:
    116 //      Since every rewritten node has twice as many inputs and outputs as
    117 //      the original node, one could imagine various orderings among
    118 //      Tensorflow tensors and Mkl tensors. E.g., assume an op 'Conv2D' that
    119 //      takes (A, B) as inputs; the new op '_MklConv2D' can take inputs A, B,
    120 //      A_m and B_m in A, A_m, B, B_m order, or it can take them in
    121 //      A, B, A_m, B_m order. In general, N tensors admit N! permutations.
    122 //
    123 //      So the question is: which order do we follow? We support 2 types of
    124 //      orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
    125 //      follows an intuitive order where an Mkl tensor follows the
    126 //      corresponding Tensorflow tensor immediately. In the context of the
    127 //      above example, it will be: A, A_m, B, B_m. Note that the ordering rule
    128 //      applies to both the inputs and outputs. Contiguous ordering means
    129 //      all the Tensorflow tensors are contiguous followed by all the Mkl
    130 //      tensors. We use contiguous ordering as default.
    131 //
    132 // Graph rewrite algorithm:
    133 //      Algorithm: Graph Rewrite
    134 //      Input: Graph G, Names of the nodes to rewrite and their new names
    135 //      Output: Modified Graph G' if the nodes are modified, G otherwise.
    136 //      Start:
    137 //        N = Topological_Sort(G) // N is a set of nodes in toposort order.
    138 //        foreach node n in N
    139 //        do
    140 //          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input.
    141 //          then
    142 //            E = set of <incoming edge and its src_output slot> of n
    143 //            E' = {}   // a new set of edges for rewritten node
    144 //            foreach <e,s> in E
    145 //            do
    146 //              E' U {<e,s>}  // First copy edge which generates Tensorflow
    147 //                            // tensor as it is
    148 //              m = Source node of edge e
    149 //              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
    150 //              then
    151 //                E' U {<m,s+1>}    // If yes, then m will generate an Mkl
    152 //                                  // tensor as an additional output.
    153 //              else
    154 //                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
    155 //                                                 // Mkl tensor.
    156 //                E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
    157 //              fi
    158 //            done
    159 //            n' = Build_New_Node(G,new_name,E')
    160 //            Mark_Rewritten(n')  // Mark the new node as being rewritten.
    161 //          fi
    162 //        done
    163 //
    164 //      Explanation:
    165 //        For graph rewrite, we visit nodes of the input graph in the
    166 //        topological sort order. With this ordering, we visit nodes in the
    167 //        top-to-bottom fashion. We need this order because while visiting a
    168 //        node we want that all of its input nodes are visited and rewritten if
    169 //        applicable. This is because if we need to rewrite a given node
    170 //        then all of its input nodes need to be fixed (in other words they
    171 //        cannot be deleted later.)
    172 //
    173 //        While visiting a node, we first check if the op type of the node is
    174 //        an Mkl op. If it is, then we rewrite that node after constructing
    175 //        new inputs to the node. If the op type of the node is not Mkl op,
    176 //        then we do not rewrite that node.
    177 //
    178 // Handling workspace propagation for certain ops:
    179 //
    180 //        Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
    181 //        passing of a workspace from their respective forward ops. Workspace
    182 //        tensors provide memory for storing results of intermediate operations
    183 //        which are helpful in backward propagation. TensorFlow does not have
    184 //        a notion of a workspace and as a result does not allow producing
    185 //        additional outputs from these forward ops. For these ops, we need
    186 //        to add 2 extra edges between forward ops and their corresponding
    187 //        backward ops - the first extra edge carries a workspace tensor and
    188 //        the second one carries an Mkl tensor for the workspace tensor.
    189 //
    190 //        Example:
    191 //
    192 //        Typical graph for MaxPool and its gradient looks like:
    193 //
    194 //        A = MaxPool(T)
    195 //        B = MaxPoolGrad(X, A, Y)
    196 //
    197 //        We will transform this graph to propagate the workspace as:
    198 //        (with the contiguous ordering)
    199 //
    200 //        A, W, A_m, W_m = MklMaxPool(T, T_m)
    201 //        B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
    202 //
    203 //        Here W is the workspace tensor. Transformed tensor names with the
    204 //        suffix _m are Mkl tensors, and this transformation has been done
    205 //        using the algorithm discussed earlier. The transformation for
    206 //        workspace propagation only adds extra outputs (W, W_m) for a forward
    207 //        op and connects them to the corresponding backward ops.
    208 //
    209 //        Terms:
    210 //
    211 //        Forward op name = name of the op in the forward pass
    212 //          where a workspace tensor originates (MaxPool in this example)
    213 //        Backward op name = name of the op in the backward pass that receives
    214 //          a workspace tensor from the forward op (MaxPoolGrad in the example)
    215 //        Slot = Position of the output or input slot that will be
    216 //               used by the workspace tensor (1 for MklMaxPool as W is the 2nd
    217 //               output of MklMaxPool (0 is 1st); 3 for MklMaxPoolGrad)
    218 //
    219 //        Question:
    220 //
    221 //        How do we associate a backward op to a forward op? There can be more
    222 //        than one op with the exact same name.
    223 //
    224 //        In this example, we associate MaxPoolGrad with MaxPool. But there
    225 //        could be more than one MaxPool op. To solve this problem, we look
    226 //        for a _direct_ edge between a forward op and a backward op (tensor A
    227 //        flows along this edge in the example).
    228 //
    229 //        How do we transform forward and backward ops when there is no direct
    230 //        edge between them? In such a case, we generate dummy tensors for
    231 //        workspace tensors. For the example, transformation of MaxPool will
    232 //        be exactly same as it would be when there is a direct edge between
    233 //        the forward and the backward op --- it is just that MaxPool won't
    234 //        generate any workspace tensor. For MaxPoolGrad, the transformation
    235 //        will also be same, but instead of connecting W and W_m with the
    236 //        outputs of MaxPool, we will produce dummy tensors for them, and we
    237 //        will set workspace_enabled attribute to false.
    238 //
    239 class MklLayoutRewritePass : public GraphOptimizationPass {
    240  public:
    241   MklLayoutRewritePass() {
    242     // NOTE: names are kept in roughly alphabetical order.
    243     csinfo_.addn = "AddN";
    244     csinfo_.avg_pool = "AvgPool";
    245     csinfo_.avg_pool_grad = "AvgPoolGrad";
    246     csinfo_.avg_pool3d = "AvgPool3D";
    247     csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
    248     csinfo_.bias_add = "BiasAdd";
    249     csinfo_.bias_add_grad = "BiasAddGrad";
    250     csinfo_.concat = "Concat";
    251     csinfo_.concatv2 = "ConcatV2";
    252     csinfo_.conv2d = "Conv2D";
    253     csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
    254     csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
    255     csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
    256     csinfo_.conv2d_grad_filter_with_bias =
    257         "__MklDummyConv2DBackpropFilterWithBias";
    258     csinfo_.conv3d = "Conv3D";
    259     csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
    260     csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
    261     csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
    262     csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
    263     csinfo_.depthwise_conv2d_grad_filter =
    264         "DepthwiseConv2dNativeBackpropFilter";
    265     csinfo_.fused_batch_norm = "FusedBatchNorm";
    266     csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
    267     csinfo_.fused_conv2d = "_FusedConv2D";
    268     csinfo_.identity = "Identity";
    269     csinfo_.leakyrelu = "LeakyRelu";
    270     csinfo_.leakyrelu_grad = "LeakyReluGrad";
    271     csinfo_.lrn = "LRN";
    272     csinfo_.lrn_grad = "LRNGrad";
    273     csinfo_.matmul = "MatMul";
    274     csinfo_.max_pool = "MaxPool";
    275     csinfo_.max_pool_grad = "MaxPoolGrad";
    276     csinfo_.max_pool3d = "MaxPool3D";
    277     csinfo_.max_pool3d_grad = "MaxPool3DGrad";
    278     csinfo_.mkl_conv2d = "_MklConv2D";
    279     csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
    280     csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
    281     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
    282     csinfo_.mkl_conv2d_grad_filter_with_bias =
    283         "_MklConv2DBackpropFilterWithBias";
    284     csinfo_.mkl_depthwise_conv2d_grad_input =
    285         "_MklDepthwiseConv2dNativeBackpropInput";
    286     csinfo_.mkl_depthwise_conv2d_grad_filter =
    287         "_MklDepthwiseConv2dNativeBackpropFilter";
    288     csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
    289     csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
    290     csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
    291     csinfo_.pad = "Pad";
    292     csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
    293     csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
    294     csinfo_.quantized_avg_pool = "QuantizedAvgPool";
    295     csinfo_.quantized_concatv2 = "QuantizedConcatV2";
    296     csinfo_.quantized_conv2d = "QuantizedConv2D";
    297     csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
    298     csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
    299     csinfo_.quantized_conv2d_with_bias_and_requantize =
    300         "QuantizedConv2DWithBiasAndRequantize";
    301     csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
    302     csinfo_.quantized_conv2d_and_relu_and_requantize =
    303         "QuantizedConv2DAndReluAndRequantize";
    304     csinfo_.quantized_conv2d_with_bias_and_relu =
    305         "QuantizedConv2DWithBiasAndRelu";
    306     csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
    307         "QuantizedConv2DWithBiasAndReluAndRequantize";
    308     csinfo_.quantized_max_pool = "QuantizedMaxPool";
    309     csinfo_.quantized_conv2d_with_bias_sum_and_relu =
    310         "QuantizedConv2DWithBiasSumAndRelu";
    311     csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
    312         "QuantizedConv2DWithBiasSumAndReluAndRequantize";
    313     csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
    314         "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
    315     csinfo_.relu = "Relu";
    316     csinfo_.relu_grad = "ReluGrad";
    317     csinfo_.relu6 = "Relu6";
    318     csinfo_.relu6_grad = "Relu6Grad";
    319     csinfo_.requantize = "Requantize";
    320     csinfo_.tanh = "Tanh";
    321     csinfo_.tanh_grad = "TanhGrad";
    322     csinfo_.reshape = "Reshape";
    323     csinfo_.slice = "Slice";
    324     csinfo_.softmax = "Softmax";
    325     csinfo_.split = "Split";
    326     csinfo_.transpose = "Transpose";
    327     // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
    328     // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
    329     // MklInputConversion op is added before it.
    330     csinfo_.add = "Add";
    331     csinfo_.maximum = "Maximum";
    332     csinfo_.mul = "Mul";
    333     csinfo_.squared_difference = "SquaredDifference";
    334     csinfo_.sub = "Sub";
    335     // End - element-wise ops. See note above.
    336 
    337     // NOTE: entries are kept in roughly alphabetical order.
    338     rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
    339                       CopyAttrsAddN, AddNRewrite});
    340     rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
    341                       CopyAttrsDataType, AlwaysRewrite});
    342     rinfo_.push_back({csinfo_.avg_pool,
    343                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
    344                       CopyAttrsPooling, AlwaysRewrite});
    345     rinfo_.push_back({csinfo_.avg_pool_grad,
    346                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
    347                       CopyAttrsPooling, AlwaysRewrite});
    348     rinfo_.push_back({csinfo_.avg_pool3d,
    349                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
    350                       CopyAttrsPooling, AlwaysRewrite});
    351     rinfo_.push_back({csinfo_.avg_pool3d_grad,
    352                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
    353                       CopyAttrsPooling, AlwaysRewrite});
    354     rinfo_.push_back({csinfo_.concat,
    355                       mkl_op_registry::GetMklOpName(csinfo_.concat),
    356                       CopyAttrsConcat, AlwaysRewrite});
    357     rinfo_.push_back({csinfo_.concatv2,
    358                       mkl_op_registry::GetMklOpName(csinfo_.concatv2),
    359                       CopyAttrsConcatV2, AlwaysRewrite});
    360     rinfo_.push_back({csinfo_.conv2d,
    361                       mkl_op_registry::GetMklOpName(csinfo_.conv2d),
    362                       CopyAttrsConvCheckConstFilter, AlwaysRewrite});
    363     rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
    364                       CopyAttrsConvCheckConstFilter, AlwaysRewrite});
    365     rinfo_.push_back({csinfo_.conv2d_grad_filter,
    366                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
    367                       CopyAttrsConv, AlwaysRewrite});
    368     rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
    369                       csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
    370                       AlwaysRewrite});
    371     rinfo_.push_back({csinfo_.conv2d_grad_input,
    372                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
    373                       CopyAttrsConv, AlwaysRewrite});
    374     rinfo_.push_back({csinfo_.conv3d,
    375                       mkl_op_registry::GetMklOpName(csinfo_.conv3d),
    376                       CopyAttrsConvCheckConstFilter, AlwaysRewrite});
    377     rinfo_.push_back({csinfo_.conv3d_grad_filter,
    378                       mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
    379                       CopyAttrsConv, AlwaysRewrite});
    380     rinfo_.push_back({csinfo_.conv3d_grad_input,
    381                       mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
    382                       CopyAttrsConv, AlwaysRewrite});
    383     rinfo_.push_back({csinfo_.depthwise_conv2d,
    384                       mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
    385                       CopyAttrsConv2DDepthwiseCheckConstFilter, AlwaysRewrite});
    386     rinfo_.push_back(
    387         {csinfo_.depthwise_conv2d_grad_input,
    388          mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
    389          CopyAttrsConv2DDepthwise, AlwaysRewrite});
    390     rinfo_.push_back(
    391         {csinfo_.depthwise_conv2d_grad_filter,
    392          mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
    393          CopyAttrsConv2DDepthwise, AlwaysRewrite});
    394     rinfo_.push_back({csinfo_.fused_batch_norm,
    395                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
    396                       CopyAttrsFusedBatchNorm, AlwaysRewrite});
    397     rinfo_.push_back(
    398         {csinfo_.fused_batch_norm_grad,
    399          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
    400          CopyAttrsFusedBatchNorm, AlwaysRewrite});
    401     rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
    402                       CopyAttrsFusedConv2D, FusedConv2DRewrite});
    403     rinfo_.push_back({csinfo_.identity,
    404                       mkl_op_registry::GetMklOpName(csinfo_.identity),
    405                       CopyAttrsDataType, AlwaysRewrite});
    406     rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
    407                       CopyAttrsLRN, LrnRewrite});
    408     rinfo_.push_back({csinfo_.lrn_grad,
    409                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
    410                       CopyAttrsLRN, LrnGradRewrite});
    411     rinfo_.push_back({csinfo_.leakyrelu,
    412                       mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
    413                       CopyAttrsLeakyRelu, LeakyReluRewrite});
    414     rinfo_.push_back({csinfo_.leakyrelu_grad,
    415                       mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
    416                       CopyAttrsLeakyRelu, LeakyReluRewrite});
    417     rinfo_.push_back({csinfo_.max_pool,
    418                       mkl_op_registry::GetMklOpName(csinfo_.max_pool),
    419                       CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
    420     rinfo_.push_back({csinfo_.max_pool_grad,
    421                       mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
    422                       CopyAttrsPooling, MaxpoolGradRewrite});
    423     rinfo_.push_back({csinfo_.max_pool3d,
    424                       mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
    425                       CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
    426     rinfo_.push_back({csinfo_.max_pool3d_grad,
    427                       mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
    428                       CopyAttrsPooling, AlwaysRewrite});
    429     rinfo_.push_back({csinfo_.maximum,
    430                       mkl_op_registry::GetMklOpName(csinfo_.maximum),
    431                       CopyAttrsDataType, AlwaysRewrite});
    432     rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
    433                       CopyAttrsDataType, AlwaysRewrite});
    434     rinfo_.push_back({csinfo_.pad_with_conv2d, csinfo_.mkl_pad_with_conv2d,
    435                       CopyAttrsPadWithConv2D, AlwaysRewrite});
    436     rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
    437                       csinfo_.mkl_pad_with_fused_conv2d,
    438                       CopyAttrsPadWithFusedConv2D, AlwaysRewrite});
    439     rinfo_.push_back({csinfo_.quantized_avg_pool,
    440                       mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
    441                       CopyAttrsQuantizedPooling, AlwaysRewrite});
    442     rinfo_.push_back({csinfo_.quantized_concatv2,
    443                       mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
    444                       CopyAttrsConcatV2, AlwaysRewrite});
    445     rinfo_.push_back({csinfo_.quantized_conv2d,
    446                       mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
    447                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
    448     rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
    449                       mkl_op_registry::GetMklOpName(
    450                           csinfo_.quantized_conv2d_with_requantize),
    451                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
    452     rinfo_.push_back(
    453         {csinfo_.quantized_conv2d_with_bias,
    454          mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
    455          CopyAttrsQuantizedConv2D, AlwaysRewrite});
    456     rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
    457                       mkl_op_registry::GetMklOpName(
    458                           csinfo_.quantized_conv2d_with_bias_and_requantize),
    459                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
    460     rinfo_.push_back(
    461         {csinfo_.quantized_conv2d_and_relu,
    462          mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
    463          CopyAttrsQuantizedConv2D, AlwaysRewrite});
    464     rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
    465                       mkl_op_registry::GetMklOpName(
    466                           csinfo_.quantized_conv2d_and_relu_and_requantize),
    467                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
    468     rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
    469                       mkl_op_registry::GetMklOpName(
    470                           csinfo_.quantized_conv2d_with_bias_and_relu),
    471                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
    472     rinfo_.push_back(
    473         {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
    474          mkl_op_registry::GetMklOpName(
    475              csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
    476          CopyAttrsQuantizedConv2D, AlwaysRewrite});
    477     rinfo_.push_back({csinfo_.quantized_max_pool,
    478                       mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
    479                       CopyAttrsQuantizedPooling, AlwaysRewrite});
    480     rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
    481                       mkl_op_registry::GetMklOpName(
    482                           csinfo_.quantized_conv2d_with_bias_sum_and_relu),
    483                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
    484     rinfo_.push_back(
    485         {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
    486          mkl_op_registry::GetMklOpName(
    487              csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
    488          CopyAttrsQuantizedConv2D, AlwaysRewrite});
    489     rinfo_.push_back(
    490         {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
    491          mkl_op_registry::GetMklOpName(
    492              csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
    493          CopyAttrsQuantizedConv2D, AlwaysRewrite});
    494     rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
    495                       CopyAttrsDataType, AlwaysRewrite});
    496     rinfo_.push_back({csinfo_.relu_grad,
    497                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
    498                       CopyAttrsDataType, AlwaysRewrite});
    499     rinfo_.push_back({csinfo_.relu6,
    500                       mkl_op_registry::GetMklOpName(csinfo_.relu6),
    501                       CopyAttrsDataType, AlwaysRewrite});
    502     rinfo_.push_back({csinfo_.relu6_grad,
    503                       mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
    504                       CopyAttrsDataType, AlwaysRewrite});
    505     rinfo_.push_back({csinfo_.requantize,
    506                       mkl_op_registry::GetMklOpName(csinfo_.requantize),
    507                       CopyAttrsRequantize, AlwaysRewrite});
    508     /*
    509     rinfo_.push_back({csinfo_.tanh,
    510                       mkl_op_registry::GetMklOpName(csinfo_.tanh),
    511                       CopyAttrsDataType, AlwaysRewrite});
    512     rinfo_.push_back({csinfo_.tanh_grad,
    513                       mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
    514                       CopyAttrsDataType, AlwaysRewrite});
    515     */
    516     rinfo_.push_back({csinfo_.reshape,
    517                       mkl_op_registry::GetMklOpName(csinfo_.reshape),
    518                       CopyAttrsReshape, AlwaysRewrite});
    519     rinfo_.push_back({csinfo_.slice,
    520                       mkl_op_registry::GetMklOpName(csinfo_.slice),
    521                       CopyAttrsSlice, AlwaysRewrite});
    522     rinfo_.push_back({csinfo_.softmax,
    523                       mkl_op_registry::GetMklOpName(csinfo_.softmax),
    524                       CopyAttrsDataType, AlwaysRewrite});
    525 
    526     rinfo_.push_back({csinfo_.squared_difference,
    527                       mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
    528                       CopyAttrsDataType, AlwaysRewrite});
    529     rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
    530                       CopyAttrsDataType, AlwaysRewrite});
    531 
    532     // Add info about which ops to add workspace edge to and the slots.
    533     wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
    534     wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
    535     wsinfo_.push_back(
    536         {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
    537 
    538     // Add a rule for merging nodes
    539     minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
    540                       csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});
    541 
    542     minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
    543                       csinfo_.conv2d_grad_filter_with_bias,
    544                       GetConv2DBackpropFilterOrBiasAddGrad});
    545     // Merge Pad and Conv2d, only if the pad op is "Pad"
    546     // Doesn't merge if pad op is "PadV2" or "MirrorPad"
    547     minfo_.push_back(
    548         {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});
    549 
    550     minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
    551                       csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});
    552 
     553     // The fusion patterns in "finfo_" that show up first will get applied
     554     // first. For example, given graph "A->B->C->D" and finfo_ = {A->B->C to
     555     // ABC, A->B->C->D to ABCD}, the first pattern gets applied first, so the
     556     // final graph will be ABC->D.
    557 
    558     //
    559     // Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
    560     // (NHWC) + Transpose (NHWC->
    561     // NCHW)" into "Conv2D (NCHW)". Such patterns occur frequently in Keras.
    562     // Note: we use the term "merge" to combine (exactly) 2 nodes into one,
    563     // while "fusion" is for 3+ nodes situation.
    564     //
    565 
    566     // Transpose + Conv2d + Transpose:
    567     std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
    568                                           NCHW::dim::W, NCHW::dim::C};
    569     std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
    570                                           NHWC::dim::H, NHWC::dim::W};
    571     auto CheckForTransposeToNHWC =
    572         std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
    573     auto CheckForConv2dOp =
    574         std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
    575     auto CheckForTransposeToNCHW =
    576         std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
    577     auto FuseConv2D =
    578         std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
    579                   std::placeholders::_2, std::placeholders::_3, "NCHW");
    580     finfo_.push_back(
    581         {"transpose-elimination for Conv2D",
    582          {CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
    583          // CheckForMklOp
    584          FuseConv2D,
    585          CopyAttrsConv});
    586   }
    587 
  // Standard interface to run pass.
  //
  // @input options - options object handed to the pass by its caller.
  // @return Status reporting the outcome of the run.
  Status Run(const GraphOptimizationPassOptions& options);
    590 
  // Helper function which does most of the heavy lifting for rewriting
  // Mkl nodes to propagate Mkl tensor as additional output.
  //
  // Extracts common functionality between the Run public interface and
  // the test interface.
  //
  // @input g - pointer to the graph the pass operates on; the graph may be
  //            mutated.
  // @return true, if and only if graph is mutated; false otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);
    599 
    600   /// Structure to specify the name of an original node, its new name after
    601   /// rewrite, the number of inputs to the original node, the function to
    602   /// be used to copy attributes for the op, and the rule (if any) which
    603   /// must hold for rewriting the node
    604   typedef struct {
    605     string name;      // Original name of op of the node in the graph
    606     string new_name;  // New name of the op of the node in the graph
    607     // A function handler to copy attributes from an old node to a new node.
    608     std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
    609     // A rule under which to rewrite this node
    610     std::function<bool(const Node*)> rewrite_rule;
    611   } RewriteInfo;
    612 
    613   /// Structure to specify a forward op, a backward op, and the slot numbers
    614   /// in the forward and backward ops where we will add a workspace edge.
    615   typedef struct {
    616     string fwd_op;    // Name of a forward op in the graph
    617     string bwd_op;    // Name of a backward op in the graph
    618     int fwd_slot;     // Output slot in the forward op node where actual
    619                       // output tensor resides
    620     int bwd_slot;     // Input slot in the backward op node where actual
    621                       // input tensor resides
    622     int ws_fwd_slot;  // Output slot in the forward op node where workspace
    623                       // edge is added
    624     int ws_bwd_slot;  // Input slot in the backward op node where workspace
    625                       // edge is added
    626   } WorkSpaceInfo;
    627 
    628   /// Structure to specify information used in node merge of 2 operators
    629   typedef struct {
    630     string op1;       // Node string for one operator.
    631     string op2;       // Node string for second operator.
    632     string new_node;  // Name of the node after merge
    633     // Function that enables user of the node merger to specify how to find
    634     // second operator given the first operator.
    635     std::function<Node*(const Node*)> get_node_to_be_merged;
    636   } MergeInfo;
    637 
    638   // Structure to specify information used in node fusion of 3+ operators
    639   typedef struct {
    640     std::string pattern_name;  // Name to describe this pattern, such as
    641                                // "Transpose_Mklop_Transpose".
    642     std::vector<std::function<bool(const Node*)> >
    643         node_checkers;  // Extra restriction checker for these ops
    644     std::function<Status(
    645         std::unique_ptr<Graph>*, std::vector<Node*>&,
    646         std::function<void(const Node*, NodeBuilder* nb, bool)>)>
    647         fuse_func;
    648     std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
    649   } FusionInfo;
    650 
    651   //
    652   // Dimension indices for 2D tensor.
    653   //
    654   struct NCHW {
    655     enum dim { N = 0, C = 1, H = 2, W = 3 };
    656   };
    657 
    658   struct NHWC {
    659     enum dim { N = 0, H = 1, W = 2, C = 3 };
    660   };
    661 
    662   //
    663   // dimension indices for 3D tensor.
    664   //
    665   struct NCDHW {
    666     enum dim { N = 0, C = 1, D = 2, H = 3, W = 4 };
    667   };
    668 
    669   struct NDHWC {
    670     enum dim { N = 0, D = 1, H = 2, W = 3, C = 4 };
    671   };
    672 
    673   /// Structure to store all constant strings
    674   /// NOTE: names are alphabetically sorted.
    675   typedef struct {
    676     string addn;
    677     string add;
    678     string avg_pool;
    679     string avg_pool_grad;
    680     string avg_pool3d;
    681     string avg_pool3d_grad;
    682     string bias_add;
    683     string bias_add_grad;
    684     string concat;
    685     string concatv2;
    686     string conv2d;
    687     string conv2d_with_bias;
    688     string conv2d_grad_input;
    689     string conv2d_grad_filter;
    690     string conv2d_grad_filter_with_bias;
    691     string conv3d;
    692     string conv3d_grad_input;
    693     string conv3d_grad_filter;
    694     string depthwise_conv2d;
    695     string depthwise_conv2d_grad_input;
    696     string depthwise_conv2d_grad_filter;
    697     string fused_batch_norm;
    698     string fused_batch_norm_grad;
    699     string fused_conv2d;
    700     string identity;
    701     string leakyrelu;
    702     string leakyrelu_grad;
    703     string lrn;
    704     string lrn_grad;
    705     string matmul;
    706     string max_pool;
    707     string max_pool_grad;
    708     string max_pool3d;
    709     string max_pool3d_grad;
    710     string maximum;
    711     string mkl_conv2d;
    712     string mkl_conv2d_grad_input;
    713     string mkl_conv2d_grad_filter;
    714     string mkl_conv2d_grad_filter_with_bias;
    715     string mkl_conv2d_with_bias;
    716     string mkl_depthwise_conv2d_grad_input;
    717     string mkl_depthwise_conv2d_grad_filter;
    718     string mkl_fused_conv2d;
    719     string mkl_pad_with_conv2d;
    720     string mkl_pad_with_fused_conv2d;
    721     string mul;
    722     string pad;
    723     string pad_with_conv2d;
    724     string pad_with_fused_conv2d;
    725     string quantized_avg_pool;
    726     string quantized_conv2d;
    727     string quantized_conv2d_with_requantize;
    728     string quantized_conv2d_with_bias;
    729     string quantized_conv2d_with_bias_and_requantize;
    730     string quantized_conv2d_and_relu;
    731     string quantized_conv2d_and_relu_and_requantize;
    732     string quantized_conv2d_with_bias_and_relu;
    733     string quantized_conv2d_with_bias_and_relu_and_requantize;
    734     string quantized_concatv2;
    735     string quantized_max_pool;
    736     string quantized_conv2d_with_bias_sum_and_relu;
    737     string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
    738     string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
    739     string relu;
    740     string relu_grad;
    741     string relu6;
    742     string relu6_grad;
    743     string requantize;
    744     string tanh;
    745     string tanh_grad;
    746     string transpose;
    747     string reshape;
    748     string slice;
    749     string softmax;
    750     string split;
    751     string squared_difference;
    752     string sub;
    753   } ConstStringsInfo;
    754 
 private:
  /// Maintain info about nodes to rewrite: one RewriteInfo entry per
  /// rewrite rule registered by the constructor.
  std::vector<RewriteInfo> rinfo_;

  /// Maintain info about nodes to add workspace edge: one WorkSpaceInfo
  /// entry per forward/backward op pair registered by the constructor.
  std::vector<WorkSpaceInfo> wsinfo_;

  /// Maintain info about nodes to be merged (2-node merge rules).
  std::vector<MergeInfo> minfo_;

  /// Maintain info about nodes to be fused (3+ node fusion patterns).
  /// Patterns that appear earlier are applied first.
  std::vector<FusionInfo> finfo_;

  /// Maintain structure of constant strings. Declared static, so a single
  /// instance is shared by all objects of this class.
  static ConstStringsInfo csinfo_;
    770 
    771  private:
    772   // Is OpDef::ArgDef a list type? It could be N * T or list(type).
    773   // Refer to opdef.proto for details of list type.
    774   inline bool ArgIsList(const OpDef::ArgDef& arg) const {
    775     return !arg.type_list_attr().empty() || !arg.number_attr().empty();
    776   }
    777 
    778   // Get length of a list in 'n' if 'arg' is of list type. Refer to
    779   // description of ArgIsList for definition of list type.
    780   inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) {
    781     CHECK_EQ(ArgIsList(arg), true);
    782     int N = 0;
    783     const string attr_name = !arg.type_list_attr().empty()
    784                                  ? arg.type_list_attr()
    785                                  : arg.number_attr();
    786     if (!arg.type_list_attr().empty()) {
    787       std::vector<DataType> value;
    788       TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
    789       N = value.size();
    790     } else {
    791       TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
    792     }
    793     return N;
    794   }
    795 
    796   // Can op represented by node 'n' run on DEVICE_CPU?
    797   // Op can run on CPU with MKL if the runtime assigned device or the
    798   // user requested device contains device CPU, or both are empty.
    799   bool CanOpRunOnCPUDevice(const Node* n) {
    800     bool result = true;
    801     string reason;
    802 
    803     // Substring that should be checked for in device name for CPU device.
    804     const char* const kCPUDeviceSubStr = "CPU";
    805 
    806     // If Op has been specifically assigned to a non-CPU device, then No.
    807     if (!n->assigned_device_name().empty() &&
    808         !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
    809       result = false;
    810       reason = "Op has been assigned a runtime device that is not CPU.";
    811     }
    812 
    813     // If user has specifically assigned this op to a non-CPU device, then No.
    814     if (!n->def().device().empty() &&
    815         !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
    816       result = false;
    817       reason = "User has assigned a device that is not CPU.";
    818     }
    819 
    820     if (result == false) {
    821       VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
    822               << n->type_string() << ", reason: " << reason;
    823     }
    824 
    825     // Otherwise Yes.
    826     return result;
    827   }
    828 
  // Return a node that can be merged with input node 'n'.
  //
  // @input n - graph node for which a merge partner is looked up.
  // @return pointer to the node if we can find such a
  // node. Otherwise, it returns nullptr.
  Node* CheckForNodeMerge(const Node* n) const;
    834 
  // Merge node 'm' with node 'n'.
  // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with
  // Conv2DBackpropFilter.
  // NOTE(review): the constructor also registers Pad + Conv2D and
  // Pad + _FusedConv2D merge rules in minfo_; presumably those are handled
  // here as well - confirm against the definition.
  //
  // Input nodes m and n may be deleted if the call to
  // this function is successful. Attempting to use the pointers
  // after the call may result in undefined behavior.
  //
  // @input g - input graph, m - graph node, n - graph node to be merged with m
  // @return Status::OK(), if merging is successful and supported.
  //         Returns appropriate Status error code otherwise.
  //         Graph is updated in case nodes are merged. Otherwise, it is
  //         not updated.
  Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);
    849 
  // Helper functions to merge different kinds of node pairs. Each one takes
  // the graph 'g' and the two nodes 'm' and 'n' of the pair and returns a
  // Status describing the outcome.
  Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
  Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
  Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
                                                  Node* m, Node* n);
    855 
    856   // Find BiasAdd or Conv2D node that can be merged with input node 'm'.
    857   // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be
    858   // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd
    859   // node that can be merged with 'm'.
    860   static Node* GetConv2DOrBiasAdd(const Node* m) {
    861     CHECK_NOTNULL(m);
    862     Node* n = nullptr;
    863 
    864     DataType T_m;
    865     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
    866 
    867     // Don't try to merge if datatype is not DT_FLOAT
    868     if (T_m != DT_FLOAT) return n;
    869 
    870     if (m->type_string() == csinfo_.bias_add) {
    871       // If a is BiasAdd, then Conv2D is 0th input of BiasAdd.
    872       TF_CHECK_OK(m->input_node(0, &n));
    873     } else {
    874       CHECK_EQ(m->type_string(), csinfo_.conv2d);
    875       // Go over all output edges and search for BiasAdd Node.
    876       // 0th input of BiasAdd is Conv2D.
    877       for (const Edge* e : m->out_edges()) {
    878         if (!e->IsControlEdge() &&
    879             e->dst()->type_string() == csinfo_.bias_add &&
    880             e->dst_input() == 0) {
    881           n = e->dst();
    882           break;
    883         }
    884       }
    885     }
    886 
    887     if (n == nullptr) {
    888       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
    889               << "Conv2D and BiasAdd node for merging. Input node: "
    890               << m->DebugString();
    891     }
    892 
    893     return n;
    894   }
    895 
    896   // Find Pad or Conv2D node that can be merged with input node 'm'.
    897   // If input 'm' is Pad, then check if there exists Conv2D node that can be
    898   // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad
    899   // node that can be merged with 'm'.
    900   static Node* GetPadOrConv2D(const Node* m) {
    901     DCHECK(m);
    902     Node* n = nullptr;
    903 
    904     DataType T_m;
    905     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
    906 
    907     // Don't try to merge if datatype is not DT_FLOAT
    908     if (T_m != DT_FLOAT) return n;
    909 
    910     const Node* conv_node;
    911     if (m->type_string() == csinfo_.pad) {
    912       // If m is Pad, then Conv2D is the output of Pad.
    913       for (const Edge* e : m->out_edges()) {
    914         if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) {
    915           n = e->dst();
    916           conv_node = n;
    917           break;
    918         }
    919       }
    920     } else {
    921       DCHECK_EQ(m->type_string(), csinfo_.conv2d);
    922       // If m is conv2D, Go over all input edges
    923       // and search for Pad  Node.
    924       for (const Edge* e : m->in_edges()) {
    925         if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
    926           n = e->src();
    927           conv_node = m;
    928           break;
    929         }
    930       }
    931     }
    932     // Check if only VALID type of padding is used
    933     // or not.
    934     if (n != nullptr) {
    935       string padding;
    936       TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
    937       if (padding != "VALID")
    938         // Then do not merge.
    939         // Only VALID type of padding in conv op can be
    940         // merged with Pad op.
    941         n = nullptr;
    942     } else {
    943       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
    944               << "Pad and Conv2D node for merging. Input node: "
    945               << m->DebugString();
    946     }
    947 
    948     return n;
    949   }
    950 
    951   // Find Pad or _FusedConv2D node that can be merged with input node 'm'.
    952   // If input 'm' is Pad, then check if there exists _FusedConv2D node that can
    953   // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there
    954   // exists Pad node that can be merged with 'm'.
    955   static Node* GetPadOrFusedConv2D(const Node* m) {
    956     DCHECK(m);
    957     Node* n = nullptr;
    958 
    959     const Node* conv_node;
    960     if (m->type_string() == csinfo_.pad) {
    961       // If m is Pad, then _FusedConv2D is the output of Pad.
    962       for (const Edge* e : m->out_edges()) {
    963         if (!e->IsControlEdge() &&
    964             e->dst()->type_string() == csinfo_.fused_conv2d) {
    965           n = e->dst();
    966           conv_node = n;
    967           break;
    968         }
    969       }
    970     } else {
    971       DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d);
    972       // If m is _FusedConv2D, Go over all input edges
    973       // and search for Pad node.
    974       for (const Edge* e : m->in_edges()) {
    975         if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
    976           n = e->src();
    977           conv_node = m;
    978           break;
    979         }
    980       }
    981     }
    982     // Check if only VALID type of padding is used or not.
    983     if (n != nullptr) {
    984       string padding;
    985       TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
    986       if (padding != "VALID") {
    987         // Then do not merge.
    988         n = nullptr;
    989         VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
    990                 << "nodes but cannot merge them. Only conv ops with padding "
    991                 << "type VALID can be merged with Pad op Input node: "
    992                 << m->DebugString();
    993       }
    994     } else {
    995       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
    996               << "Pad and _FusedConv2D node for merging. Input node: "
    997               << m->DebugString();
    998     }
    999 
   1000     return n;
   1001   }
   1002 
   1003   // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
   1004   // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
   1005   // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
   1006   // then check if there exists Conv2DBackpropFilter node that can be merged
   1007   // with 'm'.
   1008   //
   1009   // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
   1010   // would look like:
   1011   //
   1012   // _ = Conv2DBackpropFilter(F, _, G)
   1013   // _ = BiasAddGrad(G)
   1014   //
   1015   // So 1st input of BiasAddGrad connects with 3rd input of
   1016   // Conv2DBackpropFilter and vice versa.
   1017   static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
   1018     CHECK_NOTNULL(m);
   1019     Node* n = nullptr;
   1020 
   1021     DataType T_m;
   1022     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
   1023 
   1024     // Don't try to merge if datatype is not DT_FLOAT
   1025     if (T_m != DT_FLOAT) return n;
   1026 
   1027     if (m->type_string() == csinfo_.bias_add_grad) {
   1028       // Get 1st input 'g' of BiasAddGrad.
   1029       Node* g = nullptr;
   1030       TF_CHECK_OK(m->input_node(0, &g));
   1031       // Now traverse all outgoing edges from g that have destination node as
   1032       // Conv2DBackpropFilter.
   1033       for (const Edge* e : g->out_edges()) {
   1034         if (!e->IsControlEdge() &&
   1035             e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
   1036             e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
   1037           n = e->dst();
   1038           break;
   1039         }
   1040       }
   1041     } else {
   1042       CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
   1043       // Get 3rd input 'g' of Conv2DBackpropFilter.
   1044       Node* g = nullptr;
   1045       TF_CHECK_OK(m->input_node(2, &g));
   1046       // Now traverse all outgoing edges from g that have destination node as
   1047       // BiasAddGrad.
   1048       for (const Edge* e : g->out_edges()) {
   1049         if (!e->IsControlEdge() &&
   1050             e->dst()->type_string() == csinfo_.bias_add_grad &&
   1051             e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
   1052           n = e->dst();
   1053           break;
   1054         }
   1055       }
   1056     }
   1057 
   1058     if (n == nullptr) {
   1059       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
   1060               << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
   1061               << "Input node: " << m->DebugString();
   1062     }
   1063     return n;
   1064   }
   1065 
  // Return the nodes that can be fused with input node 'n'.
  //
  // @input n - graph node from which a fusion pattern is looked up.
  // @return tuple. If we can find such nodes, the first
  // element of the tuple is true. Otherwise, it's false.
  // NOTE(review): the second and third elements presumably carry the matched
  // nodes and the FusionInfo of the matched pattern - confirm against the
  // definition.
  std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
  CheckForNodeFusion(Node* n) const;
   1072 
  // Fuse nodes in the vector "nodes".
  //
  // @input g - input graph, nodes - nodes to be fused, fi - description of
  //            the fusion pattern to apply.
  // @return Status describing the outcome of the fusion.
  Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
                  const MklLayoutRewritePass::FusionInfo fi);
   1076 
  // Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
  // mklop("NCHW").
  // Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
  //
  // @input g - input graph, nodes - the nodes of the matched pattern,
  //            copy_attrs - function used to copy attributes to the new
  //            node, data_format - data format string for the fused op
  //            (the constructor binds "NCHW").
  // @return Status describing the outcome of the fusion.
  static Status FuseTransposeMklOpTranspose(
      std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
      std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
      string data_format);
   1084 
   1085   static bool CheckForTranspose(const Node* node, std::vector<int> perm) {
   1086     // Check if node's type is "Transpose"
   1087     if (node->type_string() != "Transpose") return false;
   1088 
   1089     // If "Transpose" has multiple output data edges, also don't fuse it.
   1090     if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false;
   1091 
   1092     // Check if has out control edge. If true, this is a training graph.
   1093     // Currently we focus on inference and do no fusion in training.
   1094     // Note: this constraint will eventually be removed, if we enabled this
   1095     // fusion for training
   1096     // in the future.
   1097     for (const Edge* e : node->out_edges()) {
   1098       if (e->IsControlEdge()) {
   1099         return false;
   1100       }
   1101     }
   1102 
   1103     // If "Transpose" has input control edges, don't fuse on it.
   1104     for (const Edge* e : node->in_edges()) {
   1105       if (e->IsControlEdge()) {
   1106         return false;
   1107       }
   1108     }
   1109 
   1110     // We compared the tensor containing the permutation order ("perm_node")
   1111     // with our desired order ("perm"). If they're exactly match, this check
   1112     // succeed and returns true.
   1113     for (const Edge* e : node->in_edges()) {
   1114       if (!e->IsControlEdge()) {
   1115         const Node* perm_node = e->src();
   1116 
   1117         const int kPermTensorIndex = 1;
   1118         if (perm_node->type_string() == "Const" &&
   1119             e->dst_input() == kPermTensorIndex) {
   1120           // we find the "perm" node, now try to retrieve its value.
   1121           const TensorProto* proto = nullptr;
   1122           TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto));
   1123 
   1124           DataType type;
   1125           GetNodeAttr(perm_node->def(), "dtype", &type);
   1126 
   1127           // Here we directly access to the "tensor_content", rather than
   1128           // "int_val". This is because we find "int_val" is
   1129           // not set properly under some circumstances.
   1130           if (type == DT_INT32) {
   1131             const int type_size = 4;
   1132             const int* tensor_content =
   1133                 reinterpret_cast<const int*>(proto->tensor_content().c_str());
   1134             const int tensor_content_size =
   1135                 proto->tensor_content().size() / type_size;
   1136 
   1137             std::vector<int> perm_value(tensor_content,
   1138                                         tensor_content + tensor_content_size);
   1139 
   1140             return perm_value == perm;
   1141           } else if (type == DT_INT64) {
   1142             const int type_size = 8;
   1143             const long* tensor_content =
   1144                 reinterpret_cast<const long*>(proto->tensor_content().c_str());
   1145             const int tensor_content_size =
   1146                 proto->tensor_content().size() / type_size;
   1147 
   1148             std::vector<long> perm_value(tensor_content,
   1149                                          tensor_content + tensor_content_size);
   1150             std::vector<long> long_perm(perm.cbegin(), perm.cend());
   1151 
   1152             return perm_value == long_perm;
   1153           }
   1154           return false;
   1155         }
   1156       }
   1157     }
   1158     return false;
   1159   }
   1160 
   1161   static bool CheckForMklOp(const Node* node, string name = "") {
   1162     if (node == nullptr) return false;
   1163 
   1164     if (!name.empty() && node->type_string() != name) {
   1165       return false;
   1166     }
   1167 
   1168     // if mklop has multiple outputs, don't fuse it.
   1169     if (node->num_outputs() > 1) return false;
   1170 
   1171     if (node->out_edges().size() > 1) return false;
   1172 
   1173     DataType T;
   1174     TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T));
   1175     return mkl_op_registry::IsMklOp(
   1176         mkl_op_registry::GetMklOpName(node->type_string()), T);
   1177   }
   1178 
  // Check if the node 'n' has any applicable rewrite rule.
  // We check for 2 scenarios for rewrite.
  // NOTE(review): the two scenarios are not described here; see the
  // definition of CheckForNodeRewrite for what they are.
  // CheckForQuantizedNodeRewrite is the counterpart declared for quantized
  // nodes.
  //
  // @input n - graph node to look up a rewrite rule for.
  // @return RewriteInfo* for the applicable rewrite rule
  const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
  const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const;
   1185 
   1186   // Default rewrite rule to be used in scenario 1 for rewrite.
   1187   // @return - true (since we want to always rewrite)
   1188   static bool AlwaysRewrite(const Node* n) { return true; }
   1189 
   1190   // Check if we are performing pooling on depth or batch. If it is, then we
   1191   // do not rewrite MaxPool node to Mkl version.
   1192   // @return - true (if it is not a depth/batch wise pooling case);
   1193   //           false otherwise.
   1194   static bool NonDepthBatchWisePoolRewrite(const Node* n) {
   1195     CHECK_NOTNULL(n);
   1196 
   1197     string data_format_str;
   1198     TensorFormat data_format;
   1199     std::vector<int32> ksize, strides;
   1200     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
   1201     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
   1202     CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
   1203     CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
   1204 
   1205     // Condition that specifies non-batch-wise and non-depth-wise pooling.
   1206     if (GetTensorDim(ksize, data_format, 'N') == 1 &&
   1207         GetTensorDim(strides, data_format, 'N') == 1 &&
   1208         GetTensorDim(ksize, data_format, 'C') == 1 &&
   1209         GetTensorDim(strides, data_format, 'C') == 1) {
   1210       return true;
   1211     }
   1212 
   1213     return false;
   1214   }
   1215 
   1216   // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized
   1217   // path. The unoptimized path is slow. Thus we dont rewrite the node
   1218   // and use default Eigen. But for depth_radius=2, MKL DNN optimized
   1219   // path is taken, i.e., eigen node is rewritten by MKl DNN node.
   1220   static bool LrnRewrite(const Node* n) {
   1221     CHECK_NOTNULL(n);
   1222 
   1223     int depth_radius;
   1224     CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true);
   1225 
   1226     // if the depth_radius of LRN is not 2, don't rewrite the node by MKL DNN
   1227     // and use eigen node instead
   1228     if (depth_radius == 2) {
   1229       return true;
   1230     }
   1231     VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which"
   1232             << "case is not optimized by Intel MKL, thus using Eigen op"
   1233             << "for LRN ";
   1234 
   1235     return false;
   1236   }
   1237 
   1238   static bool LrnGradRewrite(const Node* n) {
   1239     CHECK_NOTNULL(n);
   1240     bool do_rewrite = false;
   1241 
   1242     for (const Edge* e : n->in_edges()) {
   1243       // Rewrite only if there is corresponding LRN, i.e workspace is available
   1244       if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
   1245           e->src()->type_string() ==
   1246               mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
   1247           e->src_output() == 0) {
   1248         do_rewrite = true;
   1249         break;
   1250       }
   1251     }
   1252     return do_rewrite;
   1253   }
   1254 
   1255   // MKL-DNN's LeakyRelu(feature) = feature          (if feature > 0), or
   1256   //                                feature * alpha  (otherwise),
   1257   // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
   1258   // These two algorithms are not consistent when alpha > 1,
   1259   // so we only rewrite LeakyRelu to MKL OP when alpha <= 1.
   1260   static bool LeakyReluRewrite(const Node* n) {
   1261     DCHECK(n);
   1262 
   1263     float alpha;
   1264     bool has_attr = GetNodeAttr(n->def(), "alpha", &alpha).ok();
   1265     DCHECK(has_attr);
   1266 
   1267     // If the alpha of LeakyRelu is less than 1, rewrite the node.
   1268     // Otherwise eigen node is used instead.
   1269     if (alpha <= 1) {
   1270       return true;
   1271     }
   1272     VLOG(1) << "LeakyReluRewrite: The model sets alpha is greater than 1 "
   1273             << "which case is not optimized by Intel MKL, thus using Eigen op"
   1274             << "for LeakyRelu ";
   1275 
   1276     return false;
   1277   }
   1278 
   1279   static bool MaxpoolGradRewrite(const Node* n) {
   1280     CHECK_NOTNULL(n);
   1281     bool do_rewrite = false;
   1282     for (const Edge* e : n->in_edges()) {
   1283       // Rewrite only if there is corresponding Maxpool, i.e workspace is
   1284       // available
   1285       if (e->dst()->type_string() == csinfo_.max_pool_grad &&
   1286           e->dst_input() == 1 &&
   1287           e->src()->type_string() ==
   1288               mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
   1289           e->src_output() == 0) {
   1290         do_rewrite = true;
   1291         break;
   1292       }
   1293     }
   1294     return do_rewrite;
   1295   }
   1296 
   1297   static bool AddNRewrite(const Node* n) {
   1298     CHECK_NOTNULL(n);
   1299 
   1300     int num;
   1301     CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true);
   1302 
   1303     // Condition that specifies non-batch-wise and non-depth-wise pooling.
   1304     if (num == 2) {
   1305       return true;
   1306     }
   1307 
   1308     return false;
   1309   }
   1310 
   1311   static bool FusedConv2DRewrite(const Node* n) {
   1312     // MKL DNN currently doesn't support all fusions that grappler fuses
   1313     // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
   1314     // it includes those we support.
   1315     DataType T;
   1316     if (!GetNodeAttr(n->def(), "T", &T).ok() ||
   1317         !mkl_op_registry::IsMklOp(csinfo_.mkl_fused_conv2d, T)) {
   1318       return false;
   1319     }
   1320 
   1321     std::vector<string> fused_ops;
   1322     TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
   1323     return (fused_ops == std::vector<string>{"BiasAdd"} ||
   1324             fused_ops == std::vector<string>{"Relu"} ||
   1325             fused_ops == std::vector<string>{"BiasAdd", "Relu"});
   1326   }
   1327 
   1328   // Rewrites input node to a new node specified by its matching rewrite info.
   1329   //
   1330   // Method first searches matching rewrite info for input node and then
   1331   // uses that info to rewrite.
   1332   //
   1333   // Input node may be deleted in case of rewrite. Attempt to use the node
   1334   // after the call can result in undefined behaviors.
   1335   //
   1336   // @input  g - input graph, n - Node to be rewritten,
   1337   //         ri - matching rewriteinfo
   1338   // @return Status::OK(), if the input node is rewritten;
   1339   //         Returns appropriate Status error code otherwise.
   1340   //         Graph is updated in case the input node is rewritten.
   1341   //         Otherwise, it is not updated.
   1342   Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
   1343 
   1344   // Get nodes that will feed a list of TF tensors to the new
   1345   // node that we are constructing.
   1346   //
  // @input inputs - inputs to old node that we are using for constructing
  //                 new inputs,
   1350   // @input input_idx - the index in the 'inputs' vector pointing to the
   1351   //                    current input that we have processed so far
   1352   // @output input_idx - index will be incremented by the number of nodes
   1353   //                     from 'inputs' that are processed
   1354   // @input list_length - The expected length of list of TF tensors
   1355   // @output output_nodes - the list of new nodes creating TF tensors
   1356   //
   1357   // @return None
   1358   void GetNodesProducingTFTensorList(
   1359       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
   1360       int* input_idx, int list_length,
   1361       std::vector<NodeBuilder::NodeOut>* output_nodes);
   1362 
   1363   // Get nodes that will feed a list of Mkl tensors to the new
   1364   // node that we are constructing.
   1365   //
   1366   // @input g - input graph,
   1367   // @input orig_node - Original node that we are rewriting
   1368   // @input inputs - inputs to old node that we are using for constructing
   1369   //                 new inputs,
   1370   // @input input_idx - the index in the 'inputs' vector pointing to the
   1371   //                    current input that we have processed so far
   1372   // @output input_idx - index will be incremented by the number of nodes
   1373   //                     from 'inputs' that are processed
   1374   // @input list_length - The expected length of list of Mkl tensors
   1375   // @output output_nodes - the list of new nodes creating Mkl tensors
   1376   //
   1377   // @return None
   1378   void GetNodesProducingMklTensorList(
   1379       std::unique_ptr<Graph>* g, Node* orig_node,
   1380       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
   1381       int* input_idx, int list_length,
   1382       std::vector<NodeBuilder::NodeOut>* output_nodes);
   1383 
   1384   // Get a node that will feed an Mkl tensor to the new
   1385   // node that we are constructing. The output node could be (1) 'n'
   1386   // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
   1387   // if 'n' is not an Mkl layer.
   1388   //
   1389   // @input g - input graph,
   1390   // @input orig_node - Original node that we are rewriting,
   1391   // @input n - Node based on which we are creating Mkl node,
   1392   // @input n_output_slot - the output slot of node 'n'
   1393   //            which is feeding to the node that we are constructing
   1394   // @output mkl_node - the new node that will feed Mkl tensor
   1395   // @output mkl_node_output_slot - the slot number of mkl_node that
   1396   //                                will feed the tensor
   1397   // @return None
   1398   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
   1399                                  Node* n, int n_output_slot, Node** mkl_node,
   1400                                  int* mkl_node_output_slot);
   1401 
   1402   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
   1403   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
   1404   // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
   1405   // producing workspace edges if 'are_workspace_tensors_available' is true.
   1406   // Otherwise, 'workspace_tensors' is empty vector.
   1407   //
   1408   // For details, refer to 'Ordering of inputs after rewriting' section in the
   1409   // documentation above.
   1410   //
  // Returns the number of input slots added to 'nb' (Tensorflow and Mkl slots
  // combined, including workspace slots if any); the return type is int.
   1413   int SetUpContiguousInputs(
   1414       std::unique_ptr<Graph>* g,
   1415       const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
   1416       NodeBuilder* nb, Node* old_node,
   1417       std::vector<NodeBuilder::NodeOut>* workspace_tensors,
   1418       bool are_workspace_tensors_available);
   1419 
   1420   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
   1421   // in graph 'g'. Original node is input in 'orig_node'.
   1422   //
   1423   // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
   1424   // section in the documentation above.
   1425   //
   1426   // Returns Status::OK() if setting up inputs is successful, otherwise
   1427   // returns appropriate status code.
   1428   Status SetUpInputs(std::unique_ptr<Graph>* g,
   1429                      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
   1430                      NodeBuilder* nb, Node* orig_node);
   1431 
   1432   // Add workspace edge on the input or output side of Node 'orig_node' by using
   1433   // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
   1434   // adding workspace edge then do not add it. Workspace Tensorflow and Mkl
   1435   // tensors, if they need to be added, will be set into these tensors.
   1436   // If we set workspace tensors, then are_ws_tensors_added should be true.
   1437   void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
   1438                                 NodeBuilder* nb,
   1439                                 std::vector<NodeBuilder::NodeOut>* ws_tensors,
   1440                                 bool* are_ws_tensors_added);
   1441 
   1442   // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
   1443   // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
  // 'g'. Returns true if fixup was done; otherwise, it returns false.
   1445   bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
   1446                                   const Edge* e_metadata);
   1447 
   1448   // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
   1449   // connected? If not, then fix them. This is needed because a graph may have
   1450   // some input Mkl metadata edges incorrectly setup after node merge and
   1451   // rewrite passes. This could happen because GetReversePostOrder function may
   1452   // not provide topologically sorted order if a graph contains cycles. The
   1453   // function returns true if at least one Mkl metadata edge for node 'n' was
   1454   // fixed. Otherwise, it returns false.
   1455   //
   1456   // Example:
   1457   //
   1458   // X = MklConv2D(_, _, _)
   1459   // Y = MklConv2DWithBias(_, _, _, _, _, _)
   1460   // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
   1461   //
   1462   // For a graph such as shown above, note that 3rd argument of MklAdd contains
   1463   // DummyMklTensor. Actually, it should be getting the Mkl metadata from
   1464   // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
   1465   // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
   1466   // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
   1467   // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
   1468   // data edges (1st and 2nd arguments of MklAdd).
   1469   bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
   1470 
   1471   // Functions specific to operators to copy attributes
   1472   // We need operator-specific function to copy attributes because the framework
   1473   // does not provide any generic function for it.
   1474   // NOTE: names are alphabetically sorted.
   1475   static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb,
   1476                             bool change_format = false);
   1477   static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb,
   1478                                    bool change_format = false);
   1479   static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb,
   1480                               bool change_format = false);
   1481   static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb,
   1482                                 bool change_format = false);
   1483   static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
   1484                             bool change_format = false);
   1485   static void CopyAttrsConv2DDepthwise(const Node* orig_node, NodeBuilder* nb,
   1486                                        bool change_format = false);
   1487   static void CopyAttrsConv2DDepthwiseCheckConstFilter(
   1488       const Node* orig_node, NodeBuilder* nb, bool change_format = false);
   1489   static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
   1490                                             NodeBuilder* nb,
   1491                                             bool change_format = false);
   1492   static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb,
   1493                                 bool change_format = false);
   1494   static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb,
   1495                                       bool change_format = false);
   1496   static void CopyAttrsLeakyRelu(const Node* orig_node, NodeBuilder* nb,
   1497                                  bool change_format = false);
   1498   static void CopyAttrsFusedConv2D(const Node* orig_node, NodeBuilder* nb,
   1499                                    bool change_format = false);
   1500   static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
   1501                            bool change_format = false);
   1502   static void CopyAttrsPadWithConv2D(const Node* orig_node, NodeBuilder* nb,
   1503                                      bool change_format = false);
   1504   static void CopyAttrsPadWithFusedConv2D(const Node* orig_node,
   1505                                           NodeBuilder* nb,
   1506                                           bool change_format = false);
   1507   static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
   1508                                         const Node* orig_node2, NodeBuilder* nb,
   1509                                         bool change_format = false);
   1510   static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
   1511                                              const Node* orig_node2,
   1512                                              NodeBuilder* nb,
   1513                                              bool change_format = false);
   1514   static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
   1515                                bool change_format = false);
   1516   static void CopyAttrsQuantizedPooling(const Node* orig_node, NodeBuilder* nb,
   1517                                         bool change_format = false);
   1518   static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
   1519                                        bool change_format = false);
   1520   static void CopyAttrsQuantizedConcat(const Node* orig_node, NodeBuilder* nb,
   1521                                        bool change_format = false);
   1522   static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb,
   1523                                bool change_format = false);
   1524   static void CopyAttrsRequantize(const Node* orig_node, NodeBuilder* nb,
   1525                                   bool change_format = false);
   1526   static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb,
   1527                              bool change_format = false);
   1528   static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb,
   1529                              bool change_format = false);
   1530   static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
   1531                                   const std::vector<int32>& strides,
   1532                                   const std::vector<int32>& dilations,
   1533                                   bool change_format = false);
   1534 
   1535   // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
   1536   // using node for original node 'orig_node' and return it in '*out'.
   1537   // TODO(nhasabni) We should move this to mkl_util.h
   1538   void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
   1539                              Node* orig_node);
   1540   void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
   1541                                    Node* orig_node);
   1542 };
   1543 
// Out-of-class definition of the static member 'csinfo_' of
// MklLayoutRewritePass.
MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;

// We register Mkl rewrite pass for phase 1 in post partitioning group.
// We register it here so that we get a complete picture of all users of Mkl
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
// The registration below is compiled in only when ENABLE_MKL is defined.
#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
#endif  // ENABLE_MKL
   1554 
   1555 //////////////////////////////////////////////////////////////////////////
   1556 //           Helper functions for creating new node
   1557 //////////////////////////////////////////////////////////////////////////
   1558 
   1559 static void FillInputs(const Node* n,
   1560                        gtl::InlinedVector<Node*, 4>* control_edges,
   1561                        gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
   1562   control_edges->clear();
   1563   for (const Edge* e : n->in_edges()) {
   1564     if (e->IsControlEdge()) {
   1565       control_edges->push_back(e->src());
   1566     } else {
   1567       (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
   1568     }
   1569   }
   1570   std::sort(control_edges->begin(), control_edges->end());
   1571   if (n->op_def().is_commutative()) {
   1572     // For commutative inputs, we sort the input by the input Node*
   1573     // to get a canonical ordering (so that add(a,b) and add(b, a) will
   1574     // hash to the same value if is_commutative is true for 'add').
   1575     std::sort(in->begin(), in->end());
   1576   }
   1577 }
   1578 
   1579 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
   1580     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
   1581     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
   1582   CHECK_LT(*input_idx, inputs.size());
   1583   CHECK_GT(list_length, 0);
   1584   CHECK_NOTNULL(output_nodes);
   1585   output_nodes->reserve(list_length);
   1586 
   1587   while (list_length != 0) {
   1588     CHECK_GT(list_length, 0);
   1589     CHECK_LT(*input_idx, inputs.size());
   1590     Node* n = inputs[*input_idx].first;
   1591     int slot = inputs[*input_idx].second;
   1592     // If input node 'n' is just producing a single tensor at
   1593     // output slot 'slot' then we just add that single node.
   1594     output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
   1595     (*input_idx)++;
   1596     list_length--;
   1597   }
   1598 }
   1599 
   1600 // TODO(nhasabni) We should move this to mkl_util.h.
   1601 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
   1602                                                  Node** out, Node* orig_node) {
   1603   // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
   1604   // dummy Mkl tensor. 8 = 2*size_t.
   1605   const DataType dt = DataTypeToEnum<uint8>::v();
   1606   TensorProto proto;
   1607   proto.set_dtype(dt);
   1608   uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
   1609   proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8));
   1610   TensorShape dummy_shape({8});
   1611   dummy_shape.AsProto(proto.mutable_tensor_shape());
   1612   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
   1613                   .Attr("value", proto)
   1614                   .Attr("dtype", dt)
   1615                   .Device(orig_node->def().device())  // We place this node on
   1616                                                       // the same device as the
   1617                                                       // device of the original
   1618                                                       // node.
   1619                   .Finalize(&**g, out));
   1620   CHECK_NOTNULL(*out);  // Make sure we got a valid object before using it
   1621 
   1622   // If number of inputs to the original node is > 0, then we add
   1623   // control dependency between 1st input (index 0) of the original node and
   1624   // the dummy Mkl node. This is needed because control-flow ops such as Enter,
   1625   // Merge, etc, require frame_name of the dummy Mkl node to be same as the
   1626   // rewritten node. Adding control edge between 1st input of the original node
   1627   // and the dummy Mkl node ensures that the dummy node is in the same frame
   1628   // as the original node. Choosing 1st input is not necessary - any input of
   1629   // the original node is fine because all the inputs of a node are always in
   1630   // the same frame.
   1631   if (orig_node->num_inputs() > 0) {
   1632     Node* orig_input0 = nullptr;
   1633     TF_CHECK_OK(
   1634         orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
   1635     // Allow duplicate while adding control edge as it would fail (return
   1636     // NULL) if we try to add duplicate edge.
   1637     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
   1638   }
   1639 
   1640   (*out)->set_assigned_device_name(orig_node->assigned_device_name());
   1641 }
   1642 
   1643 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
   1644     std::unique_ptr<Graph>* g, Node* orig_node,
   1645     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
   1646     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
   1647   CHECK_LT(*input_idx, inputs.size());
   1648   CHECK_GT(list_length, 0);
   1649   CHECK_NOTNULL(output_nodes);
   1650   output_nodes->reserve(list_length);
   1651 
   1652   while (list_length != 0) {
   1653     CHECK_GT(list_length, 0);
   1654     CHECK_LT(*input_idx, inputs.size());
   1655     Node* n = inputs[*input_idx].first;
   1656     int slot = inputs[*input_idx].second;
   1657     // If 'n' is producing a single tensor, then create a single Mkl tensor
   1658     // node.
   1659     Node* mkl_node = nullptr;
   1660     int mkl_node_output_slot = 0;
   1661     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
   1662                               &mkl_node_output_slot);
   1663     output_nodes->push_back(
   1664         NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
   1665     (*input_idx)++;
   1666     list_length--;
   1667   }
   1668 }
   1669 
   1670 // Get an input node that will feed Mkl tensor to the new
   1671 // node that we are constructing. An input node could be (1) 'n'
   1672 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
   1673 // if 'n' is not an Mkl layer.
   1674 void MklLayoutRewritePass::GetNodeProducingMklTensor(
   1675     std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
   1676     Node** mkl_node, int* mkl_node_output_slot) {
   1677   CHECK_NOTNULL(n);
   1678   CHECK_NOTNULL(mkl_node);
   1679   CHECK_NOTNULL(mkl_node_output_slot);
   1680 
   1681   // If this is an MKL op, then it will create extra output for MKL layout.
   1682   DataType T;
   1683   if (GetNodeAttr(n->def(), "T", &T).ok() &&
   1684       mkl_op_registry::IsMklOp(n->type_string(), T)) {
   1685     // If this is an MKL op, then it will generate an edge that will receive
   1686     // Mkl tensor from a node.
   1687     // output slot number for Mkl tensor would be N+slot number of TensorFlow
   1688     // tensor, where N is total number of TensorFlow tensors.
   1689     *mkl_node = n;
   1690     *mkl_node_output_slot =
   1691         GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
   1692   } else {
   1693     // If we have not visited the node and rewritten it, then we need
   1694     // to create a dummy node that will feed a dummy Mkl tensor to this node.
   1695     // DummyMklTensor node has no input and generates only 1 output
   1696     // (dummy Mkl tensor) as output slot number 0.
   1697     GetDummyMklTensorNode(g, mkl_node, orig_node);
   1698     CHECK_NOTNULL(*mkl_node);
   1699     *mkl_node_output_slot = 0;
   1700   }
   1701 }
   1702 
   1703 int MklLayoutRewritePass::SetUpContiguousInputs(
   1704     std::unique_ptr<Graph>* g,
   1705     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
   1706     NodeBuilder* nb, Node* old_node,
   1707     std::vector<NodeBuilder::NodeOut>* workspace_tensors,
   1708     bool are_workspace_tensors_available) {
   1709   CHECK_NOTNULL(workspace_tensors);
   1710   CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
   1711 
   1712   // TODO(nhasabni): Temporary solution to connect filter input of
   1713   // BackpropInput with the converted filter from Conv2D.
   1714   bool do_connect_conv2d_backprop_input_filter = false;
   1715   Node* conv2d_node = nullptr;
   1716   // Filter node is 2nd input (slot index 1) of Conv2D.
   1717   int kConv2DFilterInputSlotIdx = 1;
   1718   int kConv2DBackpropInputFilterInputSlotIdx = 1;
   1719   int kConv2DFilterOutputSlotIdx = 1;
   1720   if (old_node->type_string() == csinfo_.conv2d_grad_input) {
   1721     // We need to find Conv2D node from Conv2DBackpropInput.
   1722     // For that let's first find filter node that is 2nd input (slot 1)
   1723     // of BackpropInput.
   1724     Node* filter_node = nullptr;
   1725     TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
   1726                                      &filter_node));
   1727     CHECK_NOTNULL(filter_node);
   1728 
   1729     // Now check which nodes receive from filter_node. Filter feeds as
   1730     // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
   1731     // _MklFusedConv2D.
   1732     for (const Edge* e : filter_node->out_edges()) {
   1733       if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
   1734            e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
   1735            e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d ||
   1736            e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
   1737            e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
   1738           e->dst_input() == kConv2DFilterInputSlotIdx
   1739           /* filter is 2nd input of Conv2D and _MklConv2D. */) {
   1740         if (conv2d_node != nullptr) {
   1741           VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
   1742                   << " feeding multiple Conv2D nodes: "
   1743                   << filter_node->DebugString();
   1744           // We will not connect filter input of Conv2DBackpropInput
   1745           // to be safe here.
   1746           do_connect_conv2d_backprop_input_filter = false;
   1747           break;
   1748         } else {
   1749           conv2d_node = e->dst();
   1750           do_connect_conv2d_backprop_input_filter = true;
   1751         }
   1752       }
   1753     }
   1754   }
   1755 
   1756   // Number of input slots to original op
   1757   // Input slots are represented by .Input() calls in REGISTER_OP.
   1758   int old_node_input_slots = old_node->op_def().input_arg_size();
   1759   // Actual number of inputs can be greater than or equal to number
   1760   // of Input slots because inputs of type list could be unfolded.
   1761   CHECK_GE(old_node_inputs.size(), old_node_input_slots);
   1762   int nn_slot_idx = 0;  // slot index for inputs of new node
   1763 
   1764   // Let's copy all inputs (TF tensors) of original node to new node.
   1765   int iidx = 0;
   1766   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
   1767     // An input slot could be a single tensor or a list. We need
   1768     // to handle this case accordingly.
   1769     CHECK_LT(iidx, old_node_inputs.size());
   1770     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
   1771     if (ArgIsList(arg)) {
   1772       std::vector<NodeBuilder::NodeOut> new_node_inputs;
   1773       int N = GetTensorListLength(arg, old_node);
   1774       GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
   1775                                     &new_node_inputs);
   1776       nb->Input(new_node_inputs);
   1777       nn_slot_idx++;
   1778     } else {
   1779       // Special case for connecting filter input of Conv2DBackpropInput
   1780       if (do_connect_conv2d_backprop_input_filter &&
   1781           iidx == kConv2DBackpropInputFilterInputSlotIdx) {
   1782         nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
   1783       } else {
   1784         nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
   1785       }
   1786       iidx++;
   1787       nn_slot_idx++;
   1788     }
   1789   }
   1790 
   1791   // If workspace tensors are available for this op and we are using
   1792   // contiguous ordering then we need to add Tensorflow tensor for
   1793   // workspace here because Tensorflow tensor for workspace is the
   1794   // last tensor in the list of Tensorflow tensors.
   1795   if (are_workspace_tensors_available) {
   1796     CHECK_EQ(workspace_tensors->size(), 2);
   1797     // Tensorflow tensor
   1798     nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
   1799     nn_slot_idx++;
   1800   }
   1801 
   1802   // Let's now setup all Mkl inputs to a new node.
   1803   // Number of Mkl inputs must be same as number of TF inputs.
   1804   iidx = 0;
   1805   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
   1806     // An input slot could be a single tensor or a list. We need
   1807     // to handle this case accordingly.
   1808     CHECK_LT(iidx, old_node_inputs.size());
   1809     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
   1810     if (ArgIsList(arg)) {
   1811       std::vector<NodeBuilder::NodeOut> new_node_inputs;
   1812       int N = GetTensorListLength(arg, old_node);
   1813       GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
   1814                                      &new_node_inputs);
   1815       nb->Input(new_node_inputs);
   1816       nn_slot_idx++;
   1817     } else {
   1818       Node* mkl_node = nullptr;
   1819       int mkl_node_output_slot = 0;
   1820       // Special case for connecting filter input of Conv2DBackpropInput
   1821       if (do_connect_conv2d_backprop_input_filter &&
   1822           iidx == kConv2DBackpropInputFilterInputSlotIdx) {
   1823         GetNodeProducingMklTensor(g, old_node, conv2d_node,
   1824                                   kConv2DFilterOutputSlotIdx, &mkl_node,
   1825                                   &mkl_node_output_slot);
   1826       } else {
   1827         GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
   1828                                   old_node_inputs[iidx].second, &mkl_node,
   1829                                   &mkl_node_output_slot);
   1830       }
   1831       nb->Input(mkl_node, mkl_node_output_slot);
   1832       iidx++;
   1833       nn_slot_idx++;
   1834     }
   1835   }
   1836 
   1837   // If workspace tensors are available for this op and we are using
   1838   // contiguous ordering then we need to add Mkl tensor for
   1839   // workspace here because Mkl tensor for workspace is the
   1840   // last tensor in the list of Mkl tensors.
   1841   if (are_workspace_tensors_available) {
   1842     CHECK_EQ(workspace_tensors->size(), 2);
   1843     // Mkl tensor
   1844     nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
   1845     nn_slot_idx++;
   1846   }
   1847 
   1848   return nn_slot_idx;
   1849 }
   1850 
   1851 Status MklLayoutRewritePass::SetUpInputs(
   1852     std::unique_ptr<Graph>* g,
   1853     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
   1854     NodeBuilder* nb, Node* old_node) {
   1855   // Let's check if we need to add workspace tensors for this node.
   1856   // We add workspace edge only for MaxPool, LRN and BatchNorm.
   1857   std::vector<NodeBuilder::NodeOut> workspace_tensors;
   1858   bool are_workspace_tensors_available = false;
   1859 
   1860   // Avoid workspace check for QuantizedConv2D and the fused
   1861   // Ops as they don't have attribute: "T".
   1862   std::vector<string> quant_ops{
   1863       "QuantizedConv2D",
   1864       "QuantizedConv2DWithBias",
   1865       "QuantizedConv2DAndRelu",
   1866       "QuantizedConv2DWithBiasAndRelu",
   1867       "QuantizedConv2DWithBiasSumAndRelu",
   1868       "QuantizedConv2DAndRequantize",
   1869       "QuantizedConv2DWithBiasAndRequantize",
   1870       "QuantizedConv2DAndReluAndRequantize",
   1871       "QuantizedConv2DWithBiasAndReluAndRequantize",
   1872       "QuantizedConv2DWithBiasSumAndReluAndRequantize",
   1873       "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"};
   1874   bool should_check_workspace =
   1875       std::find(std::begin(quant_ops), std::end(quant_ops),
   1876                 old_node->type_string()) == std::end(quant_ops);
   1877   if (should_check_workspace)
   1878     AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
   1879                              &are_workspace_tensors_available);
   1880 
   1881   int new_node_input_slots = 0;
   1882   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
   1883     // TODO(nhasabni): implement this function just for same of completion.
   1884     // We do not use interleaved ordering right now.
   1885     return Status(
   1886         error::Code::UNIMPLEMENTED,
   1887         "Interleaved ordering of tensors is currently not supported.");
   1888   } else {
   1889     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
   1890     new_node_input_slots = SetUpContiguousInputs(
   1891         g, old_node_inputs, nb, old_node, &workspace_tensors,
   1892         are_workspace_tensors_available);
   1893   }
   1894 
   1895   // Sanity check
   1896   int old_node_input_slots = old_node->op_def().input_arg_size();
   1897   if (!are_workspace_tensors_available) {
   1898     // If we are not adding workspace tensors for this op, then the total
   1899     // number of input slots to the new node _must_ be 2 times the number
   1900     // of input slots to the original node: N original Tensorflow tensors and
   1901     // N for Mkl tensors corresponding to each Tensorflow tensors.
   1902     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
   1903   } else {
   1904     // If we are adding workspace tensors for this op, then the total
   1905     // The total number of input slots to new node _must_ be 2 times the number
   1906     // of input slots to the original node: N original Tensorflow tensors and
   1907     // N for Mkl tensors corresponding to each Tensorflow tensors plus 2
   1908     // (for workspace Tensorflow tensor and workspace Mkl tensor).
   1909     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
   1910   }
   1911 
   1912   return Status::OK();
   1913 }
   1914 
   1915 //////////////////////////////////////////////////////////////////////////
   1916 //           Helper functions related to workspace pass
   1917 //////////////////////////////////////////////////////////////////////////
   1918 
   1919 // TODO(nhasabni) We should move this to mkl_util.h.
   1920 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
   1921     std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
   1922   // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent
   1923   // workspace tensor.
   1924   GetDummyMklTensorNode(g, out, orig_node);
   1925 }
   1926 
   1927 void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
   1928     std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
   1929     std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
   1930   bool workspace_edge_added = false;  // Default initializer
   1931   CHECK_NOTNULL(are_ws_tensors_added);
   1932   *are_ws_tensors_added = false;  // Default initializer
   1933 
   1934   DataType T;
   1935   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   1936   for (auto ws : wsinfo_) {
   1937     if (orig_node->type_string() == ws.fwd_op &&
   1938         mkl_op_registry::IsMklOp(
   1939             mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
   1940       // If this op is a fwd op, then we need to check if there is an
   1941       // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
   1942       // an edge, then we just add an attribute on this node for setting
   1943       // workspace_passed to true. We don't add actual workspace edge
   1944       // in this node. Actual workspace edge gets added in the backward
   1945       // op for this node.
   1946       for (const Edge* e : orig_node->out_edges()) {
   1947         if (e->src_output() == ws.fwd_slot &&
   1948             e->dst()->type_string() == ws.bwd_op &&
   1949             e->dst_input() == ws.bwd_slot) {
   1950           nb->Attr("workspace_enabled", true);
   1951           VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
   1952                   << orig_node->type_string();
   1953           workspace_edge_added = true;
   1954           // We found the edge that we were looking for, so break.
   1955           break;
   1956         }
   1957       }
   1958 
   1959       if (!workspace_edge_added) {
   1960         // If we are here, then we did not find backward operator for this
   1961         // node.
   1962         nb->Attr("workspace_enabled", false);
   1963       }
   1964     } else if (orig_node->type_string() == ws.bwd_op &&
   1965                mkl_op_registry::IsMklOp(
   1966                    mkl_op_registry::GetMklOpName(orig_node->type_string()),
   1967                    T)) {
   1968       // If this op is a bwd op, then we need to add workspace edge and
   1969       // it's Mkl tensor edge between its corresponding fwd op and this
   1970       // op. Corresponding fwd op is specified in 'fwd_op' field of
   1971       // workspace info. fwd_slot and bwd_slot in workspace info specify
   1972       // an edge between which slots connect forward and backward op.
   1973       // Once all these criteria match, we add a workspace edge between
   1974       // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
   1975       // determined by interleaved/contiguous ordering. Function
   1976       // DataIndexToMetaDataIndex tells us the location of Mkl tensor
   1977       // from the location of the Tensorflow tensor.
   1978       for (const Edge* e : orig_node->in_edges()) {
   1979         if (e->src_output() == ws.fwd_slot &&
   1980             // We would have rewritten the forward op, so we need to use
   1981             // GetMklOpName call to get its Mkl name.
   1982             e->src()->type_string() ==
   1983                 mkl_op_registry::GetMklOpName(ws.fwd_op) &&
   1984             e->dst_input() == ws.bwd_slot) {
   1985           nb->Attr("workspace_enabled", true);
   1986           CHECK_NOTNULL(ws_tensors);
   1987           // Add workspace edge between fwd op and bwd op.
   1988           ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
   1989           // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
   1990           ws_tensors->push_back(NodeBuilder::NodeOut(
   1991               e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
   1992                                                  e->src()->num_outputs())));
   1993           *are_ws_tensors_added = true;
   1994           // In terms of input ordering, we add these calls to add Input
   1995           // here because workspace edge (and its Mkl tensor) is the last
   1996           // edge in the fwdop and bwdop. So all inputs before workspace
   1997           // tensor have been added by SetUpInputs function.
   1998           VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
   1999                   << orig_node->type_string();
   2000           workspace_edge_added = true;
   2001           // We found the edge that we were looking for, so break.
   2002           break;
   2003         }
   2004       }
   2005 
   2006       // If we are here means we did not find fwd op that feeds to this
   2007       // bwd op. So in this case, we need to generate dummy tensors for
   2008       // workspace input and Mkl tensor for workspace, and set
   2009       // workspace_enabled to false.
   2010       if (!workspace_edge_added) {
   2011         nb->Attr("workspace_enabled", false);
   2012         Node* dmt_ws = nullptr;      // Dummy tensor for workspace
   2013         Node* dmt_mkl_ws = nullptr;  // Dummy Mkl tensor for workspace
   2014         GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
   2015         GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
   2016         CHECK_NOTNULL(dmt_ws);
   2017         CHECK_NOTNULL(dmt_mkl_ws);
   2018         CHECK_NOTNULL(ws_tensors);
   2019         // We add dummy tensor as workspace tensor.
   2020         ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
   2021         // We add dummy tensor as Mkl tensor for workspace tensor.
   2022         ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
   2023         *are_ws_tensors_added = true;
   2024         VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
   2025                 << orig_node->type_string();
   2026       }
   2027     } else {
   2028       // If this node does not match any workspace info, then we do not
   2029       // do anything special for workspace propagation for it.
   2030     }
   2031   }
   2032 }
   2033 
   2034 //////////////////////////////////////////////////////////////////////////
   2035 // Op-specific functions to copy attributes from old node to new node
   2036 //////////////////////////////////////////////////////////////////////////
   2037 
   2038 void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
   2039                                                          NodeBuilder* nb,
   2040                                                          bool change_format) {
   2041   DataType T;
   2042   string padding;
   2043   std::vector<int32> strides;
   2044   std::vector<int32> dilations;
   2045 
   2046   // Get all attributes from old node.
   2047   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2048   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   2049   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
   2050   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   2051 
   2052   Node* filter_node = nullptr;
   2053   orig_node->input_node(1, &filter_node);
   2054 
   2055   // Add attributes to new node.
   2056   nb->Attr("T", T);
   2057   nb->Attr("padding", padding);
   2058   nb->Attr("is_filter_const", filter_node->IsConstant());
   2059 
   2060   // Add attributes related to `data_format`.
   2061   CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
   2062 }
   2063 
   2064 void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
   2065                                          bool change_format) {
   2066   DataType T;
   2067   string padding;
   2068   std::vector<int32> strides;
   2069   std::vector<int32> dilations;
   2070 
   2071   // Get all attributes from old node.
   2072   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2073   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   2074   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
   2075   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   2076 
   2077   // Add attributes to new node.
   2078   nb->Attr("T", T);
   2079   nb->Attr("padding", padding);
   2080 
   2081   // Add attributes related to `data_format`.
   2082   CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
   2083 }
   2084 
   2085 // Used in rinfo when replacing __MklDummyPadWithConv2D by _MklPadWithConv2D
   2086 void MklLayoutRewritePass::CopyAttrsPadWithConv2D(const Node* orig_node,
   2087                                                   NodeBuilder* nb,
   2088                                                   bool change_format) {
   2089   DataType Tpaddings;
   2090   DataType T;
   2091   string data_format;
   2092   string padding;
   2093   std::vector<int32> strides;
   2094   std::vector<int32> dilations;
   2095   bool use_cudnn_on_gpu;
   2096 
   2097   // Get all attributes from old node.
   2098   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2099   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   2100   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
   2101   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   2102   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   2103   TF_CHECK_OK(
   2104       GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
   2105   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tpaddings", &Tpaddings));
   2106 
   2107   Node* filter_node = nullptr;
   2108   orig_node->input_node(1, &filter_node);
   2109 
   2110   // Add attributes to new node.
   2111   nb->Attr("T", T);
   2112   nb->Attr("strides", strides);
   2113   nb->Attr("dilations", dilations);
   2114   nb->Attr("padding", padding);
   2115   nb->Attr("is_filter_const", filter_node->IsConstant());
   2116   nb->Attr("data_format", data_format);
   2117   nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
   2118   nb->Attr("Tpaddings", Tpaddings);
   2119 }
   2120 
   2121 void MklLayoutRewritePass::CopyAttrsPadWithFusedConv2D(const Node* orig_node,
   2122                                                        NodeBuilder* nb,
   2123                                                        bool change_format) {
   2124   DataType Tpaddings;
   2125 
   2126   CopyAttrsFusedConv2D(orig_node, nb, change_format);
   2127 
   2128   // Get attributes from old node.
   2129   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tpaddings", &Tpaddings));
   2130   // Check if filter is a constant.
   2131   Node* filter_node = nullptr;
   2132   orig_node->input_node(1, &filter_node);
   2133 
   2134   // Add attributes to new node.
   2135   nb->Attr("Tpaddings", Tpaddings);
   2136   nb->Attr("is_filter_const", filter_node->IsConstant());
   2137 }
   2138 
   2139 // Used with MergePadWithConv2D
   2140 void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1,
   2141                                                      const Node* orig_node2,
   2142                                                      NodeBuilder* nb,
   2143                                                      bool change_format) {
   2144   DataType Tpaddings;
   2145   DataType T;
   2146   string data_format;
   2147   string padding;
   2148   std::vector<int32> strides;
   2149   std::vector<int32> dilations;
   2150   bool use_cudnn_on_gpu;
   2151 
   2152   // Get all attributes from old node 1.
   2153   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T));
   2154   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides));
   2155   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations));
   2156   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding));
   2157   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format));
   2158   TF_CHECK_OK(
   2159       GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
   2160   // Get all attributes from old node 2.
   2161   TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings));
   2162 
   2163   // Add attributes to new node.
   2164   nb->Attr("T", T);
   2165   nb->Attr("strides", strides);
   2166   nb->Attr("dilations", dilations);
   2167   nb->Attr("padding", padding);
   2168   nb->Attr("data_format", data_format);
   2169   nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
   2170   nb->Attr("Tpaddings", Tpaddings);
   2171 }
   2172 
   2173 void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
   2174     const Node* fused_conv2d, const Node* pad, NodeBuilder* nb,
   2175     bool change_format) {
   2176   DataType T;
   2177   int num_args;
   2178   string data_format;
   2179   string padding;
   2180   std::vector<int32> strides;
   2181   std::vector<int32> dilations;
   2182   float epsilon;
   2183   std::vector<string> fused_ops;
   2184   DataType Tpaddings;
   2185 
   2186   // Get all attributes from old node.
   2187   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
   2188   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args));
   2189   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides));
   2190   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding));
   2191   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format));
   2192   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
   2193   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
   2194   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
   2195   TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));
   2196 
   2197   // Add attributes to new node.
   2198   nb->Attr("T", T);
   2199   nb->Attr("num_args", num_args);
   2200   nb->Attr("strides", strides);
   2201   nb->Attr("padding", padding);
   2202   nb->Attr("data_format", data_format);
   2203   nb->Attr("dilations", dilations);
   2204   nb->Attr("epsilon", epsilon);
   2205   nb->Attr("Tpaddings", Tpaddings);
   2206   nb->Attr("fused_ops", fused_ops);
   2207 }
   2208 
   2209 void MklLayoutRewritePass::CopyAttrsConv2DDepthwise(const Node* orig_node,
   2210                                                     NodeBuilder* nb,
   2211                                                     bool change_format) {
   2212   DataType T;
   2213   string data_format;
   2214   string padding;
   2215   std::vector<int32> strides;
   2216   std::vector<int32> dilations;
   2217 
   2218   // Get all attributes from old node.
   2219   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2220   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   2221   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
   2222   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   2223   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   2224 
   2225   // Add attributes to new node.
   2226   nb->Attr("T", T);
   2227   nb->Attr("strides", strides);
   2228   nb->Attr("dilations", dilations);
   2229   nb->Attr("padding", padding);
   2230   nb->Attr("data_format", data_format);
   2231 }
   2232 
   2233 void MklLayoutRewritePass::CopyAttrsConv2DDepthwiseCheckConstFilter(
   2234     const Node* orig_node, NodeBuilder* nb, bool change_format) {
   2235   DataType T;
   2236   string data_format;
   2237   string padding;
   2238   std::vector<int32> strides;
   2239   std::vector<int32> dilations;
   2240 
   2241   // Get all attributes from old node.
   2242   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2243   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   2244   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
   2245   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   2246   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   2247 
   2248   Node* filter_node = nullptr;
   2249   orig_node->input_node(1, &filter_node);
   2250 
   2251   // Add attributes to new node.
   2252   nb->Attr("T", T);
   2253   nb->Attr("strides", strides);
   2254   nb->Attr("dilations", dilations);
   2255   nb->Attr("padding", padding);
   2256   nb->Attr("is_filter_const", filter_node->IsConstant());
   2257   nb->Attr("data_format", data_format);
   2258 }
   2259 
   2260 void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb,
   2261                                          bool change_format) {
   2262   DataType T;
   2263   int N;
   2264 
   2265   // Get all attributes from old node.
   2266   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2267   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
   2268 
   2269   // Add attributes to new node.
   2270   nb->Attr("T", T);
   2271   nb->Attr("N", N);
   2272 }
   2273 
   2274 void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
   2275                                                 NodeBuilder* nb,
   2276                                                 bool change_format) {
   2277   DataType T;
   2278   string data_format;
   2279   std::vector<int32> strides;
   2280 
   2281   // Get all attributes from old node.
   2282   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2283   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   2284   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   2285 
   2286   // Add attributes to new node.
   2287   nb->Attr("T", T);
   2288   nb->Attr("strides", strides);
   2289   nb->Attr("data_format", data_format);
   2290 }
   2291 
   2292 void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
   2293                                         bool change_format) {
   2294   DataType T;
   2295   int depth_radius;
   2296   float bias;
   2297   float alpha;
   2298   float beta;
   2299 
   2300   // Get all attributes from old node.
   2301   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2302   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius));
   2303   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias));
   2304   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
   2305   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta));
   2306 
   2307   // Add attributes to new node.
   2308   nb->Attr("T", T);
   2309   nb->Attr("depth_radius", depth_radius);
   2310   nb->Attr("bias", bias);
   2311   nb->Attr("alpha", alpha);
   2312   nb->Attr("beta", beta);
   2313 }
   2314 
   2315 void MklLayoutRewritePass::CopyAttrsLeakyRelu(const Node* orig_node,
   2316                                               NodeBuilder* nb,
   2317                                               bool change_format) {
   2318   DataType T;
   2319   float alpha;
   2320 
   2321   // Get all attributes from old node.
   2322   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2323   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
   2324 
   2325   // Add attributes to new node.
   2326   nb->Attr("T", T);
   2327   nb->Attr("alpha", alpha);
   2328 }
   2329 
   2330 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
   2331                                             NodeBuilder* nb,
   2332                                             bool change_format) {
   2333   DataType T;
   2334   string data_format;
   2335   string padding;
   2336   std::vector<int32> ksize, strides;
   2337 
   2338   // Get all attributes from old node.
   2339   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2340   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
   2341   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   2342   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   2343   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   2344 
   2345   // Add attributes to new node.
   2346   nb->Attr("T", T);
   2347   nb->Attr("ksize", ksize);
   2348   nb->Attr("strides", strides);
   2349   nb->Attr("padding", padding);
   2350   nb->Attr("data_format", data_format);
   2351 }
   2352 
   2353 void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
   2354                                              NodeBuilder* nb,
   2355                                              bool change_format) {
   2356   DataType T;
   2357 
   2358   // Get all attributes from old node.
   2359   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2360 
   2361   // Add attributes to new node.
   2362   nb->Attr("T", T);
   2363 }
   2364 
   2365 void MklLayoutRewritePass::CopyAttrsQuantizedPooling(const Node* orig_node,
   2366                                                      NodeBuilder* nb,
   2367                                                      bool change_format) {
   2368   DataType T;
   2369   string padding;
   2370   std::vector<int32> ksize, strides;
   2371 
   2372   // Get all attributes from old node.
   2373   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2374   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
   2375   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   2376   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   2377 
   2378   // Add attributes to new node.
   2379   nb->Attr("T", T);
   2380   nb->Attr("ksize", ksize);
   2381   nb->Attr("strides", strides);
   2382   nb->Attr("padding", padding);
   2383 }
   2384 
   2385 void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
   2386                                                     NodeBuilder* nb,
   2387                                                     bool change_format) {
   2388   DataType Tinput, Tfilter, out_type;
   2389   string padding;
   2390   string data_format("NHWC");
   2391   std::vector<int32> strides, dilations, padding_list;
   2392   bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");
   2393 
   2394   // Get all attributes from old node.
   2395   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
   2396   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter));
   2397   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
   2398   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   2399   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   2400   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
   2401   if (has_padding_list) {
   2402     TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
   2403   }
   2404 
   2405   Node* filter_node = nullptr;
   2406   orig_node->input_node(1, &filter_node);
   2407 
   2408   // Add attributes to new node.
   2409   nb->Attr("Tinput", Tinput);
   2410   nb->Attr("Tfilter", Tfilter);
   2411   nb->Attr("out_type", out_type);
   2412   nb->Attr("padding", padding);
   2413   nb->Attr("is_filter_const", filter_node->IsConstant());
   2414   nb->Attr("strides", strides);
   2415   nb->Attr("dilations", dilations);
   2416   nb->Attr("T", out_type);  // added "T" for facilitating MklToTf conversion.
   2417   nb->Attr("data_format", data_format);
   2418   if (has_padding_list) {
   2419     nb->Attr("padding_list", padding_list);
   2420   }
   2421 
   2422   // Requantization attr Tbias.
   2423   DataType Tbias;
   2424   Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
   2425   if (bias_status.ToString() == "OK") nb->Attr("Tbias", Tbias);
   2426 }
   2427 
   2428 void MklLayoutRewritePass::CopyAttrsRequantize(const Node* orig_node,
   2429                                                NodeBuilder* nb,
   2430                                                bool change_format) {
   2431   DataType Tinput, out_type;
   2432 
   2433   // Get all attributes from old node.
   2434   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
   2435   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
   2436 
   2437   // Add attributes to new node.
   2438   nb->Attr("Tinput", Tinput);
   2439   nb->Attr("out_type", out_type);
   2440 }
   2441 
   2442 void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
   2443                                             NodeBuilder* nb,
   2444                                             bool change_format) {
   2445   DataType T;
   2446   DataType Tshape;
   2447 
   2448   // Get all attributes from old node.
   2449   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2450   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
   2451 
   2452   // Add attributes to new node.
   2453   nb->Attr("T", T);
   2454   nb->Attr("Tshape", Tshape);
   2455 }
   2456 
   2457 void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node,
   2458                                           NodeBuilder* nb, bool change_format) {
   2459   DataType T;
   2460   DataType Index;
   2461 
   2462   // Get all attributes from old node.
   2463   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2464   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index));
   2465 
   2466   // Add attributes to new node.
   2467   nb->Attr("T", T);
   2468   nb->Attr("Index", Index);
   2469 }
   2470 
   2471 void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
   2472                                           NodeBuilder* nb, bool change_format) {
   2473   DataType T;
   2474   string data_format;
   2475   int num_split;
   2476 
   2477   // Get all attributes from old node.
   2478   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2479   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split));
   2480   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   2481 
   2482   // Add attributes to new node.
   2483   nb->Attr("T", T);
   2484   nb->Attr("num_split", num_split);
   2485   nb->Attr("data_format", data_format);
   2486 }
   2487 
   2488 void MklLayoutRewritePass::CopyFormatAttrsConv(
   2489     const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
   2490     const std::vector<int32>& dilations, bool change_format) {
   2491   string data_format;
   2492 
   2493   if (!change_format) {
   2494     nb->Attr("strides", strides);
   2495     nb->Attr("dilations", dilations);
   2496 
   2497     TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   2498     nb->Attr("data_format", data_format);
   2499   } else {
   2500     std::vector<int32> new_strides;
   2501     std::vector<int32> new_dilations;
   2502     if (strides.size() == 5) {
   2503       // `strides` and `dilations` also need to be changed according to
   2504       // `data_format`. In this case, from `NDHWC` to `NCDHW`.
   2505       new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
   2506                      strides[NDHWC::dim::D], strides[NDHWC::dim::H],
   2507                      strides[NDHWC::dim::W]};
   2508 
   2509       new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C],
   2510                        dilations[NDHWC::dim::D], dilations[NDHWC::dim::H],
   2511                        dilations[NDHWC::dim::W]};
   2512     } else {
   2513       // `strides` and `dilations` also need to be changed according to
   2514       // `data_format`. In this case, from `NHWC` to `NCHW`.
   2515 
   2516       new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
   2517                      strides[NHWC::dim::H], strides[NHWC::dim::W]};
   2518 
   2519       new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C],
   2520                        dilations[NHWC::dim::H], dilations[NHWC::dim::W]};
   2521     }
   2522     nb->Attr("strides", new_strides);
   2523     nb->Attr("dilations", new_dilations);
   2524   }
   2525 }
   2526 
   2527 void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node,
   2528                                            NodeBuilder* nb,
   2529                                            bool change_format) {
   2530   DataType T;
   2531   int N;
   2532 
   2533   // Get all attributes from old node.
   2534   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2535   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
   2536 
   2537   // Add attributes to new node.
   2538   nb->Attr("T", T);
   2539   nb->Attr("N", N);
   2540 }
   2541 
   2542 void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node,
   2543                                              NodeBuilder* nb,
   2544                                              bool change_format) {
   2545   DataType T;
   2546   int N;
   2547   DataType tidx;
   2548 
   2549   // Get all attributes from old node.
   2550   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2551   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
   2552   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx));
   2553 
   2554   // Add attributes to new node.
   2555   nb->Attr("T", T);
   2556   nb->Attr("N", N);
   2557   nb->Attr("Tidx", tidx);
   2558 }
   2559 
   2560 void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
   2561                                                    NodeBuilder* nb,
   2562                                                    bool change_format) {
   2563   DataType T;
   2564   float epsilon;
   2565   string data_format;
   2566   bool is_training;
   2567 
   2568   // Get all attributes from old node.
   2569   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2570   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
   2571   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   2572   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training));
   2573 
   2574   // Add attributes to new node.
   2575   nb->Attr("T", T);
   2576   nb->Attr("epsilon", epsilon);
   2577   nb->Attr("data_format", data_format);
   2578   nb->Attr("is_training", is_training);
   2579 }
   2580 
   2581 void MklLayoutRewritePass::CopyAttrsFusedConv2D(const Node* orig_node,
   2582                                                 NodeBuilder* nb,
   2583                                                 bool change_format) {
   2584   DataType T;
   2585   int num_args;
   2586   float epsilon;
   2587   string data_format;
   2588   string padding;
   2589   std::vector<int32> strides;
   2590   std::vector<int32> dilations;
   2591   std::vector<string> fused_ops;
   2592 
   2593   // Get all attributes from old node.
   2594   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   2595   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_args", &num_args));
   2596   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
   2597   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
   2598   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
   2599   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
   2600   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "fused_ops", &fused_ops));
   2601   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
   2602 
   2603   Node* filter_node = nullptr;
   2604   orig_node->input_node(1, &filter_node);
   2605 
   2606   // Add attributes to new node.
   2607   nb->Attr("T", T);
   2608   nb->Attr("num_args", num_args);
   2609   nb->Attr("strides", strides);
   2610   nb->Attr("padding", padding);
   2611   nb->Attr("is_filter_const", filter_node->IsConstant());
   2612   nb->Attr("data_format", data_format);
   2613   nb->Attr("dilations", dilations);
   2614   nb->Attr("fused_ops", fused_ops);
   2615   nb->Attr("epsilon", epsilon);
   2616 }
   2617 
   2618 //////////////////////////////////////////////////////////////////////////
   2619 //           Helper functions related to node merge pass
   2620 //////////////////////////////////////////////////////////////////////////
   2621 
   2622 Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
   2623   // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
   2624   // once we support BiasAddGrad as Mkl layer.
   2625 
   2626   // Search for all matching mergeinfo.
   2627   // We allow more than one match for extensibility.
   2628   std::vector<const MergeInfo*> matching_mi;
   2629   for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
   2630     if (a->type_string() == mi->op1 || a->type_string() == mi->op2) {
   2631       matching_mi.push_back(&*mi);
   2632     }
   2633   }
   2634 
   2635   for (const MergeInfo* mi : matching_mi) {
   2636     // Get the operand with which 'a' can be merged.
   2637     Node* b = nullptr;
   2638     if ((b = mi->get_node_to_be_merged(a)) == nullptr) {
   2639       continue;
   2640     }
   2641 
   2642     // Get the control edges and input of node
   2643     const int N_in = a->num_inputs();
   2644     gtl::InlinedVector<Node*, 4> a_control_edges;
   2645     gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
   2646     FillInputs(a, &a_control_edges, &a_in);
   2647 
   2648     const int B_in = b->num_inputs();
   2649     gtl::InlinedVector<Node*, 4> b_control_edges;
   2650     gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
   2651     FillInputs(b, &b_control_edges, &b_in);
   2652 
   2653     // Shouldn't merge if a and b have different control edges.
   2654     if (a_control_edges != b_control_edges) {
   2655       continue;
   2656     } else {
   2657       // We found a match.
   2658       return b;
   2659     }
   2660   }
   2661 
   2662   return nullptr;
   2663 }
   2664 
   2665 Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
   2666                                                     Node* m, Node* n) {
   2667   CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
   2668              n->type_string() == csinfo_.conv2d)) ||
   2669                ((n->type_string() == csinfo_.bias_add &&
   2670                  m->type_string() == csinfo_.conv2d)),
   2671            true);
   2672 
   2673   // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
   2674   // BiasAdd is successor node, and Conv2D predecessor node.
   2675   Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
   2676   Node* succ = m->type_string() == csinfo_.bias_add ? m : n;
   2677 
   2678   // 1. Get all attributes from input nodes.
   2679   DataType T_pred, T_succ;
   2680   string padding;
   2681   std::vector<int32> strides;
   2682   std::vector<int32> dilations;
   2683   string data_format_pred, data_format_succ;
   2684   bool use_cudnn_on_gpu;
   2685   TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
   2686   TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
   2687   TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
   2688   TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
   2689   TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations));
   2690   TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
   2691   TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
   2692   TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
   2693   // We check to ensure that data formats of both succ and pred are same.
   2694   // We expect them to be same, so we can enforce this as assert.
   2695   // But assert can be too strict, so we enforce this as a check.
   2696   // If the check fails, then we do not merge two nodes.
   2697   // We also do same check for devices.
   2698   if (data_format_pred != data_format_succ || T_pred != T_succ ||
   2699       pred->assigned_device_name() != succ->assigned_device_name() ||
   2700       pred->def().device() != succ->def().device()) {
   2701     return Status(error::Code::INVALID_ARGUMENT,
   2702                   "data_format or T attribute or devices of Conv2D and "
   2703                   "BiasAdd do not match. Will skip node merge optimization");
   2704   }
   2705 
   2706   const int succ_num = succ->num_inputs();
   2707   gtl::InlinedVector<Node*, 4> succ_control_edges;
   2708   gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
   2709   FillInputs(succ, &succ_control_edges, &succ_in);
   2710 
   2711   const int pred_num = pred->num_inputs();
   2712   gtl::InlinedVector<Node*, 4> pred_control_edges;
   2713   gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
   2714   FillInputs(pred, &pred_control_edges, &pred_in);
   2715 
   2716   // We need to ensure that Conv2D only feeds to BiasAdd (some other operator is
   2717   // not expecting output of Conv2D). If this is not the case, then we cannot
   2718   // merge Conv2D with BiasAdd.
   2719   const int kFirstOutputSlot = 0;
   2720   for (const Edge* e : pred->out_edges()) {
   2721     if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
   2722       return Status(error::Code::INVALID_ARGUMENT,
   2723                     "Conv2D does not feed to BiasAdd, or "
   2724                     "it feeds BiasAdd but has multiple outputs. "
   2725                     "Will skip node merge optimization");
   2726     }
   2727   }
   2728 
   2729   // 2. Get inputs from both the nodes.
   2730   // Find the 2 inputs from the conv and the bias from the add Bias.
   2731   // Get operand 0, 1 of conv2D.
   2732   CHECK_EQ(pred->in_edges().size(), 2);  // Conv2D must have 2 inputs.
   2733   // Get operand 1 of add_bias
   2734   // BiasAdd must have 2 inputs: Conv, bias
   2735   CHECK_EQ(succ->in_edges().size(), 2);
   2736 
   2737   // We will use the node name of BiasAdd as the name of new node
   2738   // Build new node. We use same name as original node, but change the op
   2739   // name.
   2740   NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
   2741   nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
   2742   // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
   2743   nb.Input(pred_in[1].first, pred_in[1].second);  // In2 of Conv2D
   2744   // In1 of BiasAdd is same as output of Conv2D.
   2745   nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
   2746 
   2747   // Copy attributes from Conv2D to Conv2DWithBias.
   2748   CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb);
   2749 
   2750   // Copy the device assigned to old node to new node.
   2751   nb.Device(succ->def().device());
   2752 
   2753   // Create node.
   2754   Node* new_node;
   2755   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
   2756   CHECK_NOTNULL(new_node);
   2757 
   2758   // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
   2759   // node are already copied in BuildNode. We handle control edges now.
   2760   for (const Edge* e : pred->in_edges()) {
   2761     if (e->IsControlEdge()) {
   2762       // Allow duplicate while adding control edge as it would fail (return
   2763       // NULL) if we try to add duplicate edge.
   2764       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
   2765     }
   2766   }
   2767   for (const Edge* e : succ->in_edges()) {
   2768     if (e->IsControlEdge()) {
   2769       // Allow duplicate while adding control edge as it would fail (return
   2770       // NULL) if we try to add duplicate edge.
   2771       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
   2772     }
   2773   }
   2774 
   2775   // Incoming edges are fixed, we will fix the outgoing edges now.
   2776   // First, we will fix outgoing control edges from 'pred' node.
   2777   for (const Edge* e : pred->out_edges()) {
   2778     if (e->IsControlEdge()) {
   2779       // Allow duplicate while adding control edge as it would fail (return
   2780       // NULL) if we try to add duplicate edge.
   2781       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
   2782     }
   2783   }
   2784 
   2785   // Second, we will fix outgoing control and data edges from 'succ' node.
   2786   for (const Edge* e : succ->out_edges()) {
   2787     if (e->IsControlEdge()) {
   2788       // Allow duplicate while adding control edge as it would fail (return
   2789       // NULL) if we try to add duplicate edge.
   2790       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
   2791     } else {
   2792       // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
   2793       // output (at slot 0).
   2794       const int kConv2DWithBiasOutputSlot = 0;
   2795       CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(),
   2796                                   e->dst_input()));
   2797     }
   2798   }
   2799 
   2800   // Copy device assigned to old node to new node.
   2801   // It's ok to use pred or succ as we have enforced a check that
   2802   // both have same device assigned.
   2803   new_node->set_assigned_device_name(pred->assigned_device_name());
   2804 
   2805   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
   2806           << ", and node: " << succ->DebugString()
   2807           << ", into node:" << new_node->DebugString();
   2808 
   2809   (*g)->RemoveNode(succ);
   2810   (*g)->RemoveNode(pred);
   2811 
   2812   return Status::OK();
   2813 }
   2814 
   2815 Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g,
   2816                                                 Node* m, Node* n) {
   2817   DCHECK((m->type_string() == csinfo_.pad &&
   2818           (n->type_string() == csinfo_.conv2d ||
   2819            n->type_string() == csinfo_.fused_conv2d)) ||
   2820          (n->type_string() == csinfo_.pad &&
   2821           (m->type_string() == csinfo_.conv2d ||
   2822            m->type_string() == csinfo_.fused_conv2d)));
   2823 
   2824   bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d ||
   2825                          m->type_string() == csinfo_.fused_conv2d;
   2826   // Conv2D is successor node, and Pad predecessor node.
   2827   Node* pred = m->type_string() == csinfo_.pad ? m : n;
   2828   Node* succ = m->type_string() == csinfo_.pad ? n : m;
   2829 
   2830   // 1. Get all attributes from input nodes.
   2831   DataType T_pred, T_succ;
   2832   string padding;
   2833   std::vector<int32> strides;
   2834   std::vector<int32> dilations;
   2835   string data_format_pred, data_format_succ;
   2836 
   2837   TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
   2838   TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
   2839   TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding));
   2840   TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides));
   2841   TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations));
   2842   // Check if the devices of both succ and pred are the same.
   2843   // Assert is not used because it can be too strict.
   2844   // Don't need to check for data formats because it is not available in Pad.
   2845   if (T_pred != T_succ ||
   2846       pred->assigned_device_name() != succ->assigned_device_name() ||
   2847       pred->def().device() != succ->def().device()) {
   2848     return Status(error::Code::INVALID_ARGUMENT,
   2849                   "T attribute or devices of Conv2D and "
   2850                   "Pad do not match. Will skip node merge optimization");
   2851   }
   2852 
   2853   const int succ_num = succ->num_inputs();
   2854   gtl::InlinedVector<Node*, 4> succ_control_edges;
   2855   gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
   2856   FillInputs(succ, &succ_control_edges, &succ_in);
   2857 
   2858   const int pred_num = pred->num_inputs();
   2859   gtl::InlinedVector<Node*, 4> pred_control_edges;
   2860   gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
   2861   FillInputs(pred, &pred_control_edges, &pred_in);
   2862 
   2863   // We need to ensure that Pad only feeds to Conv2D (some other operator is
   2864   // not expecting output of Pad). If this is not the case, then we cannot
   2865   // merge Conv2D with Pad.
   2866   const int kFirstOutputSlot = 0;
   2867   for (const Edge* e : pred->out_edges()) {
   2868     if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
   2869       return Status(error::Code::INVALID_ARGUMENT,
   2870                     "Pad does not feed to Conv2D, or "
   2871                     "it feeds Conv2D but has multiple outputs. "
   2872                     "Will skip node merge optimization");
   2873     }
   2874   }
   2875 
   2876   // 2. Get inputs from both the nodes.
   2877 
   2878   // Pad must have 2 data inputs: "input" and paddings.
   2879   int PadDataInputEdges = 0;
   2880   for (const Edge* e : pred->in_edges()) {
   2881     if (!e->IsControlEdge()) {
   2882       PadDataInputEdges++;
   2883     }
   2884   }
   2885   DCHECK_EQ(PadDataInputEdges, 2);
   2886 
   2887   // Conv2D must have 2 data inputs: Pad output and Filter
   2888   // FusedConv2D have 3 data inputs: Pad output, Filter and Args;
   2889   int ConvDataInputEdges = 0;
   2890   for (const Edge* e : succ->in_edges()) {
   2891     if (!e->IsControlEdge()) {
   2892       ConvDataInputEdges++;
   2893     }
   2894   }
   2895 
   2896   DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2);
   2897 
   2898   // We will use the node name of Conv2D as the name of new node
   2899   // Build new node. We use same name as original node, but change the op
   2900   // name.
   2901 
   2902   NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d
   2903                                                : csinfo_.pad_with_conv2d);
   2904   nb.Input(pred_in[0].first, pred_in[0].second);  // In1 (input data)  of Pad
   2905   // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
   2906   nb.Input(succ_in[1].first, succ_in[1].second);  // In2 (filter) of conv2d
   2907   // In1 of Conv2D is same as output of Pad.
   2908   // Thus, only need to add In2 of Conv2D
   2909 
   2910   if (is_fused_conv2d) {
   2911     // FusedConv2D has one additional input, args
   2912     std::vector<NodeBuilder::NodeOut> args;
   2913     args.emplace_back(succ_in[2].first, succ_in[2].second);
   2914     nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{
   2915         args});                                     // In3 (args) of FusedConv2D
   2916     nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
   2917     // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D.
   2918     CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ),
   2919                                    const_cast<const Node*>(pred), &nb);
   2920   } else {
   2921     nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
   2922     // Copy attributes from Pad and conv2D to PadWithConv2D.
   2923     CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ),
   2924                               const_cast<const Node*>(pred), &nb);
   2925   }
   2926 
   2927   // Copy the device assigned to old node to new node.
   2928   nb.Device(succ->def().device());
   2929 
   2930   // Create node.
   2931   Node* new_node;
   2932   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
   2933   DCHECK(new_node);
   2934 
   2935   // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
   2936   // node are already copied in BuildNode.
   2937   // We handle control edges now.
   2938   for (const Edge* e : pred->in_edges()) {
   2939     if (e->IsControlEdge()) {
   2940       // Don't allow duplicate edge
   2941       (*g)->AddControlEdge(e->src(), new_node, false);
   2942     }
   2943   }
   2944   for (const Edge* e : succ->in_edges()) {
   2945     if (e->IsControlEdge()) {
   2946       // Don't allow duplicate edge
   2947       (*g)->AddControlEdge(e->src(), new_node, false);
   2948     }
   2949   }
   2950 
   2951   // Incoming edges are fixed, we will fix the outgoing edges now.
   2952   // First, we will fix outgoing control edges from 'pred' node.
   2953   for (const Edge* e : pred->out_edges()) {
   2954     if (e->IsControlEdge()) {
   2955       // Don't allow duplicate edge
   2956       (*g)->AddControlEdge(new_node, e->dst(), false);
   2957     }
   2958   }
   2959 
   2960   // Second, we will fix outgoing control and data edges from 'succ' node.
   2961   for (const Edge* e : succ->out_edges()) {
   2962     if (e->IsControlEdge()) {
   2963       // Allow duplicate while adding control edge as it would fail (return
   2964       // NULL) if we try to add duplicate edge.
   2965       (*g)->AddControlEdge(new_node, e->dst(), false);
   2966     } else {
   2967       // Conv2D has only 1 output (at slot 0) and merged node also has only 1
   2968       // output (at slot 0).
   2969       const int kPadWithConv2DOutputSlot = 0;
   2970       (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(),
   2971                     e->dst_input());
   2972     }
   2973   }
   2974 
   2975   // Copy device assigned to old node to new node.
   2976   // It's ok to use pred or succ as we have enforced a check that
   2977   // both have same device assigned.
   2978   new_node->set_assigned_device_name(pred->assigned_device_name());
   2979 
   2980   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
   2981           << ", and node: " << succ->DebugString()
   2982           << ", into node:" << new_node->DebugString();
   2983 
   2984   (*g)->RemoveNode(succ);
   2985   (*g)->RemoveNode(pred);
   2986 
   2987   return Status::OK();
   2988 }
   2989 
   2990 Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
   2991     std::unique_ptr<Graph>* g, Node* m, Node* n) {
   2992   CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
   2993              n->type_string() == csinfo_.conv2d_grad_filter)) ||
   2994                ((n->type_string() == csinfo_.bias_add_grad &&
   2995                  m->type_string() == csinfo_.conv2d_grad_filter)),
   2996            true);
   2997 
   2998   // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
   2999   Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
   3000   Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;
   3001 
   3002   // Sanity check for attributes from input nodes.
   3003   DataType T_b, T_f;
   3004   string data_format_b, data_format_f;
   3005   TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
   3006   TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
   3007   TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
   3008   TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
   3009   if (data_format_b != data_format_f || T_b != T_f ||
   3010       badd->assigned_device_name() != fltr->assigned_device_name() ||
   3011       badd->def().device() != fltr->def().device()) {
   3012     return Status(error::Code::INVALID_ARGUMENT,
   3013                   "data_format or T attribute or devices of "
   3014                   "Conv2DBackpropFilter and BiasAddGrad do not match. "
   3015                   "Will skip node merge optimization");
   3016   }
   3017 
   3018   // We will use the node name of Conv2DBackpropFilter as the name of new node.
   3019   // This is because BackpropFilterWithBias is going to emit bias output also.
   3020   NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
   3021   // Since Conv2DBackpropFilterWithBias has same number of inputs as
   3022   // Conv2DBackpropFilter, we can just copy input edges directly. We dont need
   3023   // to copy any data input of BiasAddGrad because that input also goes to
   3024   // Conv2DBackpropFilter.
   3025   const int fltr_ins = fltr->num_inputs();
   3026   gtl::InlinedVector<Node*, 4> fltr_control_edges;
   3027   gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
   3028   FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
   3029   for (int idx = 0; idx < fltr_ins; idx++) {
   3030     nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
   3031   }
   3032 
   3033   // Copy attributes from Conv2DBackpropFilter.
   3034   CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
   3035 
   3036   // Copy the device assigned to old node to new node.
   3037   nb.Device(fltr->def().device());
   3038 
   3039   // Create node.
   3040   Node* new_node;
   3041   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
   3042   CHECK_NOTNULL(new_node);
   3043 
   3044   // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to
   3045   // new 'new_node' node are already copied in BuildNode. We handle control
   3046   // edges now.
   3047   for (const Edge* e : badd->in_edges()) {
   3048     if (e->IsControlEdge()) {
   3049       // Allow duplicate while adding control edge as it would fail (return
   3050       // NULL) if we try to add duplicate edge.
   3051       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
   3052     }
   3053   }
   3054   for (const Edge* e : fltr->in_edges()) {
   3055     if (e->IsControlEdge()) {
   3056       // Allow duplicate while adding control edge as it would fail (return
   3057       // NULL) if we try to add duplicate edge.
   3058       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
   3059     }
   3060   }
   3061 
   3062   // Incoming edges are fixed, we will fix the outgoing edges now.
   3063   // First, we will fix outgoing control edges from 'badd' node.
   3064   // Conv2DBackpropFilter has 1 output -- filter_grad.
   3065   // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
   3066   // bias_grad. But filter_grad is at same slot number (0) in both the
   3067   // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while
   3068   // it is at slot number 0 in BiasAddGrad.
   3069   const int kMergedNodeFilterGradOutputIdx = 0;
   3070   const int kMergedNodeBiasGradOutputIdx = 1;
   3071 
   3072   for (const Edge* e : badd->out_edges()) {
   3073     if (e->IsControlEdge()) {
   3074       // Allow duplicate while adding control edge as it would fail (return
   3075       // NULL) if we try to add duplicate edge.
   3076       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
   3077     } else {
   3078       CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
   3079                                   e->dst(), e->dst_input()));
   3080     }
   3081   }
   3082 
   3083   // Second, we will fix outgoing control and data edges from 'fltr' node.
   3084   for (const Edge* e : fltr->out_edges()) {
   3085     if (e->IsControlEdge()) {
   3086       // We allow duplicate edge for this case since we already add control
   3087       // edge from new_node in line 3990. Line below could be adding same
   3088       // edge to same destination again. In such case, if we do not allow
   3089       // duplicate edge, then this call will fail.
   3090       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
   3091     } else {
   3092       CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
   3093                                   e->dst(), e->dst_input()));
   3094     }
   3095   }
   3096 
   3097   // Copy device assigned to old node to new node.
   3098   // It's ok to use badd or fltr as we have enforced a check that
   3099   // both have same device assigned.
   3100   new_node->set_assigned_device_name(badd->assigned_device_name());
   3101 
   3102   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
   3103           << ", and node: " << fltr->DebugString()
   3104           << ", into node:" << new_node->DebugString();
   3105 
   3106   (*g)->RemoveNode(badd);
   3107   (*g)->RemoveNode(fltr);
   3108 
   3109   return Status::OK();
   3110 }
   3111 
   3112 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
   3113                                        Node* n) {
   3114   CHECK_NOTNULL(m);
   3115   CHECK_NOTNULL(n);
   3116 
   3117   if (((m->type_string() == csinfo_.bias_add &&
   3118         n->type_string() == csinfo_.conv2d)) ||
   3119       ((n->type_string() == csinfo_.bias_add &&
   3120         m->type_string() == csinfo_.conv2d))) {
   3121     return this->MergeConv2DWithBiasAdd(g, m, n);
   3122   }
   3123   if ((m->type_string() == csinfo_.pad &&
   3124        (n->type_string() == csinfo_.conv2d ||
   3125         (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) ||
   3126       (n->type_string() == csinfo_.pad &&
   3127        (m->type_string() == csinfo_.conv2d ||
   3128         (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) {
   3129     return this->MergePadWithConv2D(g, m, n);
   3130   }
   3131 
   3132   if (((m->type_string() == csinfo_.bias_add_grad &&
   3133         n->type_string() == csinfo_.conv2d_grad_filter)) ||
   3134       ((n->type_string() == csinfo_.bias_add_grad &&
   3135         m->type_string() == csinfo_.conv2d_grad_filter))) {
   3136     return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
   3137   }
   3138 
   3139   return Status(error::Code::UNIMPLEMENTED,
   3140                 "Unimplemented case for node merge optimization.");
   3141 }
   3142 
   3143 //////////////////////////////////////////////////////////////////////////
   3144 //           Helper functions for node rewrite
   3145 //////////////////////////////////////////////////////////////////////////
   3146 
   3147 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
   3148                                          Node* orig_node,
   3149                                          const RewriteInfo* ri) {
   3150   CHECK_NOTNULL(ri);
   3151   CHECK_NOTNULL(orig_node);
   3152 
   3153   VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
   3154 
   3155   // Get all inputs.
   3156   int num_inputs = orig_node->in_edges().size();
   3157 
   3158   // Drop count for control edges from inputs
   3159   for (const Edge* e : orig_node->in_edges()) {
   3160     if (e->IsControlEdge()) {
   3161       num_inputs--;
   3162     }
   3163   }
   3164 
   3165   gtl::InlinedVector<Node*, 4> control_edges;
   3166   gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
   3167   FillInputs(orig_node, &control_edges, &inputs);
   3168 
   3169   // Build new node. We use same name as original node, but change the op name.
   3170   NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
   3171   // Copy user-specified device assigned to original node to new node.
   3172   nb.Device(orig_node->def().device());
   3173   // Set up new inputs to the rewritten node.
   3174   Status s = SetUpInputs(g, inputs, &nb, orig_node);
   3175   if (s != Status::OK()) {
   3176     return s;
   3177   }
   3178 
   3179   const bool kPartialCopyAttrs = false;
   3180   ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs);
   3181 
   3182   // Set the Mkl layer label for this op.
   3183   if (DataTypeIsQuantized(orig_node->input_type(0)) ||
   3184       DataTypeIsQuantized(orig_node->output_type(0))) {
   3185     nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
   3186   } else {
   3187     nb.Attr("_kernel", mkl_op_registry::kMklOpLabel);
   3188   }
   3189   // Finalize graph and get new node.
   3190   Node* new_node = nullptr;
   3191   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
   3192   CHECK_NOTNULL(new_node);
   3193 
   3194   // Incoming data edges from 'orig_node' node to new 'new_node' node are
   3195   // already copied in BuildNode. We need to handle control edges now.
   3196   for (const Edge* e : orig_node->in_edges()) {
   3197     if (e->IsControlEdge()) {
   3198       // Allow duplicate while adding control edge as it would fail (return
   3199       // NULL) if we try to add duplicate edge.
   3200       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
   3201     }
   3202   }
   3203 
   3204   // Copy outgoing edges from 'orig_node' node to new
   3205   // 'new_node' node, since the output also follows same ordering among
   3206   // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
   3207   // tensors appropriately. Specifically, nth output of the original node
   3208   // will become 2*nth output of the Mkl node for the interleaved ordering
   3209   // of the tensors. For the contiguous ordering of the tensors, it will be n.
   3210   // GetTensorDataIndex provides this mapping function.
   3211   for (const Edge* e : orig_node->out_edges()) {
   3212     if (e->IsControlEdge()) {
   3213       // Allow duplicate while adding control edge as it would fail (return
   3214       // NULL) if we try to add duplicate edge.
   3215       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
   3216     } else {
   3217       CHECK_NOTNULL((*g)->AddEdge(
   3218           new_node,
   3219           GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
   3220           e->dst(), e->dst_input()));
   3221     }
   3222   }
   3223 
   3224   // Copy the runtime device assigned from original code to new node.
   3225   new_node->set_assigned_device_name(orig_node->assigned_device_name());
   3226 
   3227   // Delete original node and mark new node as rewritten.
   3228   (*g)->RemoveNode(orig_node);
   3229 
   3230   VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
   3231   return Status::OK();
   3232 }
   3233 
   3234 // TODO(mdfaijul): Is there any other elegent way to check for quantized ops
   3235 // having attributes other than "T"?
   3236 // Current implementation reflects only QuantizedConv2D and its fused Ops.
   3237 const MklLayoutRewritePass::RewriteInfo*
   3238 MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
   3239   DataType Tinput, Tfilter;
   3240   if (!(GetNodeAttr(n->def(), "Tinput", &Tinput).ok() &&
   3241         GetNodeAttr(n->def(), "Tfilter", &Tfilter).ok())) {
   3242     return nullptr;
   3243   }
   3244   if (mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
   3245                                Tinput, Tfilter)) {
   3246     for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
   3247       if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
   3248         return &*ri;
   3249       }
   3250     }
   3251   }
   3252   return nullptr;
   3253 }
   3254 
   3255 const MklLayoutRewritePass::RewriteInfo*
   3256 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
   3257   CHECK_NOTNULL(n);
   3258 
   3259   // QuntizedOps may have attributes other than "T", so decoupled the check
   3260   // with a function, CheckForQuantizedNodeRewrite(const Node*).
   3261   const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n);
   3262   if (ri != nullptr) return ri;
   3263 
   3264   // First check if node along with its type is supported by MKL layer.
   3265   // We do not want to rewrite an op into Mkl op if types are not supported.
   3266   // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
   3267   // MklRelu if type is INT32.
   3268   DataType T;
   3269   if (!GetNodeAttr(n->def(), "T", &T).ok()) {
   3270     return nullptr;
   3271   }
   3272 
   3273   // We make an exception for __MklDummyConv2DWithBias,
   3274   // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their
   3275   // names do not match Mkl node names.
   3276   if (n->type_string() != csinfo_.conv2d_with_bias &&
   3277       n->type_string() != csinfo_.pad_with_conv2d &&
   3278       n->type_string() != csinfo_.pad_with_fused_conv2d &&
   3279       n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
   3280       n->type_string() != csinfo_.fused_conv2d &&
   3281       !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
   3282                                 T)) {
   3283     return nullptr;
   3284   }
   3285 
   3286   // For elementwise node, we reuse the Eigen implementation and pass the MKL
   3287   // metadata tensor through so we can avoid conversions. However, if all
   3288   // incoming edges are in TF format, we don't need all this overhead, so
   3289   // replace the elementwise node only if at least one of its parents is a MKL
   3290   // node.
   3291   //
   3292   // Identity nodes can also skip replacement if they are not being served by
   3293   // any MKL nodes.
   3294   //
   3295   // TODO(vrane): Add implementation for element-wise ops that doesn't reuse
   3296   // eigen code to reduce cross-library dependency.
   3297   VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string();
   3298   if (mkl_op_registry::IsMklElementWiseOp(
   3299           mkl_op_registry::GetMklOpName(n->type_string()), T) ||
   3300       n->type_string().find("Identity") != string::npos) {
   3301     VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string();
   3302     bool incoming_mkl_edge = false;
   3303     int num_parent = 0;
   3304     for (auto parent : n->in_edges()) {
   3305       if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) {
   3306         VLOG(1) << "ELEMENTWISE: parent " << num_parent++
   3307                 << " is MKL op: " << parent->src()->type_string();
   3308         incoming_mkl_edge = true;
   3309         break;
   3310       } else {
   3311         VLOG(1) << "ELEMENTWISE: parent " << num_parent++
   3312                 << " is NON-MKL op: " << parent->src()->type_string();
   3313       }
   3314     }
   3315     if (incoming_mkl_edge == false) {
   3316       VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which "
   3317                  "has no MKL "
   3318                  "parents.";
   3319       return nullptr;
   3320     } else {
   3321       VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string()
   3322               << " which has MKL parents";
   3323     }
   3324   }
   3325 
   3326   // We now check if rewrite rule applies for this op. If rewrite rule passes
   3327   // for this op, then we rewrite it to Mkl op.
   3328   // Find matching RewriteInfo and then check that rewrite rule applies.
   3329   for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
   3330     if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
   3331       return &*ri;
   3332     }
   3333   }
   3334 
   3335   // Else return not found.
   3336   return nullptr;
   3337 }
   3338 
   3339 //////////////////////////////////////////////////////////////////////////
   3340 //           Helper functions for node fusion
   3341 //////////////////////////////////////////////////////////////////////////
   3342 Status MklLayoutRewritePass::FuseTransposeMklOpTranspose(
   3343     std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
   3344     std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
   3345     string data_format) {
   3346   Node* transpose_to_nhwc = nodes[0];
   3347   Node* mklop = nodes[1];
   3348   Node* transpose_to_nchw = nodes[2];
   3349 
   3350   const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs();
   3351   gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges;
   3352   gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in(
   3353       transpose_nhwc_num_inputs);
   3354   FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges,
   3355              &transpose_nhwc_in);
   3356 
   3357   const int mklop_num_inputs = mklop->num_inputs();
   3358   gtl::InlinedVector<Node*, 4> mklop_control_edges;
   3359   gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs);
   3360   FillInputs(mklop, &mklop_control_edges, &mklop_in);
   3361 
   3362   const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs();
   3363   gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges;
   3364   gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in(
   3365       transpose_nchw_num_inputs);
   3366   FillInputs(transpose_to_nchw, &transpose_nchw_control_edges,
   3367              &transpose_nchw_in);
   3368 
   3369   // We use same name as original node, but change the op
   3370   // type.
   3371   NodeBuilder nb(mklop->name(), mklop->type_string());
   3372 
   3373   // Storing the output slots of the input nodes.
   3374   for (int i = 0; i < mklop_num_inputs; i++) {
   3375     if (mklop_in[i].first == transpose_to_nhwc) {
   3376       // Fill "x":
   3377       nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second);
   3378     } else {
   3379       // Fill inputs other than "x":
   3380       nb.Input(mklop_in[i].first, mklop_in[i].second);
   3381     }
   3382   }
   3383 
   3384   copy_attrs(const_cast<const Node*>(mklop), &nb, true);
   3385   nb.Attr("data_format", data_format);
   3386 
   3387   // Copy the device assigned to old node to new node.
   3388   nb.Device(mklop->def().device());
   3389 
   3390   // Create node.
   3391   Node* new_node;
   3392   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
   3393   DCHECK(new_node);
   3394 
   3395   // Fill outputs.
   3396   for (const Edge* e : transpose_to_nchw->out_edges()) {
   3397     if (!e->IsControlEdge()) {
   3398       const int kTransposeWithMklOpOutputSlot = 0;
   3399       auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot,
   3400                                     e->dst(), e->dst_input());
   3401       DCHECK(new_edge);
   3402     }
   3403   }
   3404 
   3405   // Copy device assigned to old node to new node.
   3406   new_node->set_assigned_device_name(mklop->assigned_device_name());
   3407 
   3408   // Copy requested_device and assigned_device_name_index
   3409   new_node->set_requested_device(mklop->requested_device());
   3410   new_node->set_assigned_device_name_index(mklop->assigned_device_name_index());
   3411 
   3412   (*g)->RemoveNode(transpose_to_nhwc);
   3413   (*g)->RemoveNode(mklop);
   3414   (*g)->RemoveNode(transpose_to_nchw);
   3415 
   3416   return Status::OK();
   3417 }
   3418 
   3419 Status MklLayoutRewritePass::FuseNode(
   3420     std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
   3421     const MklLayoutRewritePass::FusionInfo fi) {
   3422   return fi.fuse_func(g, nodes, fi.copy_attrs);
   3423 }
   3424 
   3425 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
   3426 MklLayoutRewritePass::CheckForNodeFusion(Node* a) const {
   3427   // Stores matched nodes, in the same order as node_checkers.
   3428   std::vector<Node*> nodes;
   3429 
   3430   for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) {
   3431     //
   3432     // Make sure node "a" and its succeding nodes (b, c ...), match the pattern
   3433     // defined in fusion info (ops[0], ops[1], ...),
   3434     // a.k.a. "a->b->c" matches "op1->op2->op3"
   3435     //
   3436 
   3437     // Stores the first unvisted outgoing edge of each matched node in "nodes".
   3438     std::stack<EdgeSet::const_iterator> current_neighbor_stack;
   3439     nodes.clear();
   3440 
   3441     auto node_checker = fi->node_checkers.begin();
   3442     if (a != nullptr && (*node_checker)(a)) {
   3443       nodes.push_back(a);
   3444       current_neighbor_stack.push(a->out_edges().begin());
   3445       ++node_checker;
   3446     }
   3447 
   3448     while (!nodes.empty()) {
   3449       auto& current_neighbor_iter = current_neighbor_stack.top();
   3450 
   3451       if (current_neighbor_iter != nodes.back()->out_edges().end()) {
   3452         // Found an unvisited edge. Goes through the edge to get the neighbor.
   3453         Node* neighbor_node = (*current_neighbor_iter)->dst();
   3454         ++current_neighbor_stack.top();  // Retrieves the next unvisited edge.
   3455 
   3456         if ((*node_checker)(neighbor_node)) {
   3457           // Found a match. Stores the node and moves to the next checker.
   3458           nodes.push_back(neighbor_node);
   3459           current_neighbor_stack.push(neighbor_node->out_edges().begin());
   3460           if (++node_checker == fi->node_checkers.end()) {
   3461             return make_tuple(true, nodes, *fi);
   3462           }
   3463         }
   3464       } else {
   3465         // Removes the current node since none of its neighbor leads to a
   3466         // further match.
   3467         nodes.pop_back();
   3468         current_neighbor_stack.pop();
   3469         --node_checker;
   3470       }
   3471     }
   3472   }
   3473 
   3474   return make_tuple(false, std::vector<Node*>(), FusionInfo());
   3475 }
   3476 
   3477 ///////////////////////////////////////////////////////////////////////////////
   3478 //              Post-rewrite Mkl metadata fixup pass
   3479 ///////////////////////////////////////////////////////////////////////////////
   3480 bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
   3481                                                       const Edge* e_data,
   3482                                                       const Edge* e_metadata) {
   3483   if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
   3484     return false;
   3485   }
   3486 
   3487   Node* n_data = e_data->src();
   3488   int n_data_op_slot = e_data->src_output();
   3489   int n_metadata_op_slot =
   3490       GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs());
   3491 
   3492   // If the source of meta edge is a constant node (producing dummy Mkl metadata
   3493   // tensor), then we will need to fix.
   3494   if (IsConstant(e_metadata->src())) {
   3495     Node* e_metadata_dst = e_metadata->dst();
   3496     int e_metadata_in_slot = e_metadata->dst_input();
   3497     CHECK_NOTNULL((*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst,
   3498                                 e_metadata_in_slot));
   3499 
   3500     (*g)->RemoveEdge(e_metadata);
   3501     return true;
   3502   }
   3503 
   3504   return false;
   3505 }
   3506 
   3507 bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
   3508                                                Node* n) {
   3509   bool result = false;
   3510 
   3511   // If graph node is not Mkl node, then return.
   3512   DataType T = DT_INVALID;
   3513   if (!GetNodeAttr(n->def(), "T", &T).ok() ||
   3514       !mkl_op_registry::IsMklOp(n->type_string(), T)) {
   3515     return result;
   3516   }
   3517 
   3518   // If it is Mkl node, then check if the input edges to this node that carry
   3519   // Mkl metadata are linked up correctly with the source node.
   3520 
   3521   // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
   3522   // data tensors + n for Mkl metadata tensors). We need to check for correct
   3523   // connection of n metadata tensors only.
   3524   int num_data_inputs = n->num_inputs() / 2;
   3525   for (int idx = 0; idx < num_data_inputs; idx++) {
   3526     // Get the edge connecting input slot with index (idx).
   3527     const Edge* e = nullptr;
   3528     TF_CHECK_OK(n->input_edge(idx, &e));
   3529 
   3530     // If e is control edge, then skip.
   3531     if (e->IsControlEdge()) {
   3532       continue;
   3533     }
   3534 
   3535     // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl
   3536     // node, then we don't need to do anything.
   3537     Node* e_src = e->src();
   3538     if (GetNodeAttr(e_src->def(), "T", &T).ok() &&
   3539         mkl_op_registry::IsMklOp(e_src->type_string(), T)) {
   3540       // Source node for edge 'e' is Mkl node.
   3541       // Destination node and destination input slot of e is node 'n' and 'idx'
   3542       // resp.
   3543       CHECK_EQ(e->dst(), n);
   3544       CHECK_EQ(e->dst_input(), idx);
   3545 
   3546       // Let's get edge that carries Mkl metadata corresponding to Mkl data edge
   3547       // 'e'. For that, let's first get the input slot of 'n' where the meta
   3548       // edge will feed the value.
   3549       int e_meta_in_slot =
   3550           GetTensorMetaDataIndex(e->dst_input(), n->num_inputs());
   3551       const Edge* e_meta = nullptr;
   3552       TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));
   3553 
   3554       // Let's check if we need to fix this meta edge.
   3555       if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
   3556         result = true;
   3557       }
   3558     }
   3559   }
   3560 
   3561   return result;
   3562 }
   3563 
   3564 ///////////////////////////////////////////////////////////////////////////////
   3565 //              Run function for the pass
   3566 ///////////////////////////////////////////////////////////////////////////////
   3567 
   3568 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
   3569   bool result = false;
   3570   CHECK_NOTNULL(g);
   3571 
   3572   DumpGraph("Before running MklLayoutRewritePass", &**g);
   3573 
   3574   std::vector<Node*> order;
   3575   GetReversePostOrder(**g, &order);  // This will give us topological sort.
   3576   for (Node* n : order) {
   3577     // If node is not an op or it cannot run on CPU device, then skip.
   3578     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
   3579       continue;
   3580     }
   3581 
   3582     Node* m = nullptr;
   3583     if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
   3584       // Check if the node 'n' can be merged with any other node. If it can
   3585       // be 'm' contains the node with which it can be merged.
   3586       string n1_name = n->name();
   3587       string n2_name = m->name();
   3588 
   3589       VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
   3590               << n2_name << " for merging";
   3591 
   3592       if (MergeNode(g, n, m) == Status::OK()) {
   3593         VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
   3594                 << n2_name;
   3595         result = true;
   3596       }
   3597     }
   3598   }
   3599 
   3600   DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);
   3601 
   3602   order.clear();
   3603   GetReversePostOrder(**g, &order);  // This will give us topological sort.
   3604   for (Node* n : order) {
   3605     // If node is not an op or it cannot run on CPU device, then skip.
   3606     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
   3607       continue;
   3608     }
   3609 
   3610     auto check_result = CheckForNodeFusion(n);
   3611     bool found_pattern = std::get<0>(check_result);
   3612     std::vector<Node*> nodes = std::get<1>(check_result);
   3613     const FusionInfo fi = std::get<2>(check_result);
   3614 
   3615     // if "found_pattern" is true, we can do the fusion.
   3616     if (found_pattern) {
   3617       if (FuseNode(g, nodes, fi) == Status::OK()) {
   3618         result = true;
   3619       }
   3620     }
   3621   }
   3622   DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g);
   3623 
   3624   order.clear();
   3625   GetReversePostOrder(**g, &order);  // This will give us topological sort.
   3626   for (Node* n : order) {
   3627     // If node is not an op or it cannot run on CPU device, then skip.
   3628     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
   3629       continue;
   3630     }
   3631 
   3632     const RewriteInfo* ri = nullptr;
   3633     // We will first search if node is to be rewritten.
   3634     if ((ri = CheckForNodeRewrite(n)) != nullptr) {
   3635       string node_name = n->name();
   3636       string op_name = n->type_string();
   3637 
   3638       VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
   3639               << " with op " << op_name << " for rewrite using"
   3640               << " layout optimization.";
   3641 
   3642       if (RewriteNode(g, n, ri) == Status::OK()) {
   3643         VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
   3644                 << " with op " << op_name << " for Mkl layout optimization.";
   3645         result = true;
   3646       }
   3647     }
   3648   }
   3649 
   3650   DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);
   3651 
   3652   order.clear();
   3653   GetReversePostOrder(**g, &order);  // This will give us topological sort.
   3654   for (Node* n : order) {
   3655     // If node is not an op or it cannot run on CPU device, then skip.
   3656     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
   3657       continue;
   3658     }
   3659     if (FixMklMetaDataEdges(g, n)) {
   3660       string node_name = n->name();
   3661       string op_name = n->type_string();
   3662 
   3663       VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
   3664               << node_name << " with op " << op_name;
   3665       result = true;
   3666     }
   3667   }
   3668   DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
   3669             &**g);
   3670 
   3671   return result;
   3672 }
   3673 
   3674 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
   3675   return MklLayoutRewritePass().RunPass(g);
   3676 }
   3677 
   3678 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
   3679   if (options.graph == nullptr && options.partition_graphs == nullptr) {
   3680     return Status::OK();
   3681   }
   3682   if (DisableMKL()) {
   3683     VLOG(2) << "TF-MKL: Disabling MKL";
   3684     return Status::OK();
   3685   }
   3686 
   3687   auto process_graph = [&](std::unique_ptr<Graph>* g) {
   3688     // Get the ownership of a graph
   3689     std::unique_ptr<Graph>* ng = std::move(g);
   3690     RunPass(ng);
   3691     // Return the ownership of a graph back
   3692     g->reset(ng->release());
   3693   };
   3694 
   3695   if (kMklLayoutRewritePassGroup !=
   3696       OptimizationPassRegistry::POST_PARTITIONING) {
   3697     // For any pre-partitioning phase, a graph is stored in options.graph.
   3698     process_graph(options.graph);
   3699   } else {
   3700     // For post partitioning phase, graphs are stored in
   3701     // options.partition_graphs.
   3702     for (auto& pg : *options.partition_graphs) {
   3703       process_graph(&pg.second);
   3704     }
   3705   }
   3706 
   3707   return Status::OK();
   3708 }
   3709 
   3710 }  // namespace tensorflow
   3711 
   3712 #endif
   3713