Home | History | Annotate | Download | only in graph
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifdef INTEL_MKL
     17 
     18 #include "tensorflow/core/graph/mkl_layout_pass.h"
     19 #include "tensorflow/core/graph/mkl_graph_util.h"
     20 
     21 #include <algorithm>
     22 #include <string>
     23 #include <vector>
     24 
     25 #include "tensorflow/core/framework/op.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/graph/graph.h"
     28 #include "tensorflow/core/graph/graph_constructor.h"
     29 #include "tensorflow/core/graph/testlib.h"
     30 #include "tensorflow/core/kernels/ops_util.h"
     31 #include "tensorflow/core/lib/random/simple_philox.h"
     32 #include "tensorflow/core/lib/strings/str_util.h"
     33 #include "tensorflow/core/lib/strings/stringprintf.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 #include "tensorflow/core/platform/protobuf.h"
     36 #include "tensorflow/core/platform/test.h"
     37 #include "tensorflow/core/platform/test_benchmark.h"
     38 
     39 namespace tensorflow {
     40 
     41 #ifdef INTEL_MKL_ML
     42 
     43 namespace {
     44 
// Fully-qualified device names used to pin every node of a test graph to a
// single CPU (or GPU) device before the rewrite pass runs.
const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0";
const char kGPUDevice[] = "/job:a/replica:0/task:0/device:GPU:0";
     47 
     48 static void InitGraph(const string& s, Graph* graph,
     49                       const string& device = kCPUDevice) {
     50   GraphDef graph_def;
     51 
     52   auto parser = protobuf::TextFormat::Parser();
     53   //  parser.AllowRelaxedWhitespace(true);
     54   CHECK(parser.MergeFromString(s, &graph_def)) << s;
     55   GraphConstructorOptions opts;
     56   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
     57 
     58   for (Node* node : graph->nodes()) {
     59     node->set_assigned_device_name(device);
     60   }
     61 }
     62 
// Test fixture: owns a Graph, runs the MKL layout rewrite pass over it, and
// exposes a deterministic string form of the graph for exact comparison.
class MklLayoutPassTest : public ::testing::Test {
 public:
  MklLayoutPassTest() : graph_(OpRegistry::Global()) {}

  // Builds graph_ from a text-format GraphDef and snapshots its canonical
  // string form into original_.
  void InitGraph(const string& s, const string& device = kCPUDevice) {
    ::tensorflow::InitGraph(s, &graph_, device);
    original_ = CanonicalGraphString(&graph_);
  }

  // Only real ops participate in the canonical string (skips _SOURCE/_SINK).
  static bool IncludeNode(const Node* n) { return n->IsOp(); }

  // Endpoint label for an edge: "name" for slot 0, "name:control" for the
  // control slot, "name:<slot>" otherwise.
  static string EdgeId(const Node* n, int index) {
    if (index == 0) {
      return n->name();
    } else if (index == Graph::kControlSlot) {
      return strings::StrCat(n->name(), ":control");
    } else {
      return strings::StrCat(n->name(), ":", index);
    }
  }

  // Deterministic textual form of the graph: sorted "name(op)" entries joined
  // by ';', then '|', then sorted "src->dst" edge entries joined by ';'.
  string CanonicalGraphString(Graph* g) {
    std::vector<string> nodes;
    std::vector<string> edges;
    for (const Node* n : g->nodes()) {
      if (IncludeNode(n)) {
        nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
      }
    }
    for (const Edge* e : g->edges()) {
      if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
        edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
                                        EdgeId(e->dst(), e->dst_input())));
      }
    }
    // Canonicalize
    std::sort(nodes.begin(), nodes.end());
    std::sort(edges.begin(), edges.end());
    return strings::StrCat(str_util::Join(nodes, ";"), "|",
                           str_util::Join(edges, ";"));
  }

  // Runs the MKL layout rewrite pass over graph_ in place and returns the
  // resulting canonical graph string. Also logs before/after forms.
  string DoMklLayoutOptimizationPass() {
    string before = CanonicalGraphString(&graph_);
    LOG(ERROR) << "Before MKL layout rewrite pass: " << before;

    // NOTE(review): the unique_ptr wrapper is intentionally leaked. The pass
    // API takes a std::unique_ptr<Graph>*, but graph_ is a fixture member and
    // must not be deleted; heap-allocating the wrapper and never freeing it
    // keeps its destructor from ever running on &graph_.
    std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
    RunMklLayoutRewritePass(ug);

    string result = CanonicalGraphString(&graph_);
    LOG(ERROR) << "After MKL layout rewrite pass:  " << result;
    return result;
  }

  const string& OriginalGraph() const { return original_; }

  Graph graph_;      // graph under test; rewritten in place by the pass
  string original_;  // canonical form captured at InitGraph time
};
    122 
// Dummy stateful source ops used by the test graphs below. The _MklInput
// variants emit uint8 outputs that stand in for Mkl metadata tensors;
// _MklInput2 provides two such outputs so tests can pick non-zero slots.
REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("_MklInput2")
    .Output("o: uint8")
    .Output("o1: uint8")
    .SetIsStateful();
    132 
    133 /////////////////////////////////////////////////////////////////////
//  Unit tests related to node merge optimization
    135 /////////////////////////////////////////////////////////////////////
    136 
// Sanity check: a graph with no rewrite candidates passes through unchanged.
TEST_F(MklLayoutPassTest, Basic) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Zeta);D(Zeta)|"
            "A->C;A->D;B->C:1;B->D:1");
}
    149 
    150 // Test set 1: Conv2D + AddBias
    151 
    152 // C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved ordering)
    153 // C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous ordering)
// _MklConv2D followed by BiasAdd should merge into one _MklConv2DWithBias.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}");
  // C and E fuse into _MklConv2DWithBias (named E); DMT/_0 is the dummy Mkl
  // tensor generated for the bias input, fed via E's slot 5.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
            "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->E;"
            "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;"
            "N->E:4;Y->Z:1");
}
    183 
    184 // C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved)
    185 // C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous)
    186 // Test for correct output slots selected
// Same merge as above, but the Mkl metadata inputs come from non-zero output
// slots (M:1, N:1); the merged node must preserve those slot numbers.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput2'}"
      "node { name: 'N' op: '_MklInput2'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'M:1', 'N:1']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}");
  // Note M:1->E:3 and N:1->E:4 — the source slots survive the merge.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
            "M(_MklInput2);N(_MklInput2);Y(Input);Z(Zeta)|A->E;"
            "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;"
            "M:1->E:3;N:1->E:4;Y->Z:1");
}
    216 
    217 // C=Conv2D(A,B); E=BiasAdd(C,D); Z=Zeta(E,Y);
    218 // This is a case of node rewrite followed by node merge.
    219 // We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D
    220 // with BiasAdd to produce _MklConv2DWithBias.
// Rewrite-then-merge: plain Conv2D is first rewritten to _MklConv2D, then
// merged with BiasAdd into _MklConv2DWithBias; three dummy Mkl tensors
// (DMT/_0..2) are created since the original graph had no Mkl inputs.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Zeta)|"
            "A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;"
            "DMT/_2->E:5;E->Z;Y->Z:1");
}
    249 
    250 // Graph contains only _MklConv2D, no AddBias.
// Negative: _MklConv2D with no downstream BiasAdd — nothing to merge, graph
// must be unchanged.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);M(_MklInput);N(_MklInput)|"
            "A->C;B->C:1;M->C:2;N->C:3");
}
    268 
    269 // _MklConv2D output does not go to BiasAdd.
// Negative: a BiasAdd exists but is not fed by the _MklConv2D output, so no
// merge should occur.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D', 'E'] }");  // Output of _MklConv2D does not go to BiasAdd.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
            "M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3");
}
    293 
    294 // _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Zeta).
    295 // Merge should not be done in such case.
// Negative: _MklConv2D has a second consumer (Zeta G), so merging it into
// the BiasAdd would orphan that edge — no merge should occur.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D', 'E'] }"  // Conv2D has two outputs.
                              // No merge should happen.
      "node { name: 'G' op: 'Zeta'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['C', 'E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
            "G(Zeta);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;"
            "E->F:1;E->G:1;M->C:2;N->C:3");
}
    324 
    325 // data_format attribute value mismatch. Merge should not be done
    326 // in such case.
// Negative: conv uses 'NCHW' but BiasAdd uses 'NHCW' (deliberately bogus),
// so the data_format attrs differ and the merge must be rejected.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NHCW' } }"
      " input: ['C', 'D'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);M(_MklInput);"
            "N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
}
    349 
    350 // Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
    351 // rewrite tests
    352 
    353 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
    354 // and BackpropInput
// BiasAddGrad J should be rewritten to _MklConv2DWithBiasBackpropBias when
// both BackpropFilter and BackpropInput are present with matching inputs.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'O' op: '_MklInput'}"
      "node { name: 'D' op: '_MklConv2DWithBias'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
      "node { name: 'E' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['D', 'A']}"
      "node { name: 'F' op: 'Int32Input'}"
      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
      "node { name: 'H' op: 'Int32Input'}"
      "node { name: 'I' op: '_MklConv2DBackpropInput'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
      "node { name: 'J' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['E'] }");
  // J becomes _MklConv2DWithBiasBackpropBias; DMT/_0 supplies its dummy Mkl
  // metadata input (J:1) under a control edge from E.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
            "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
            "I(_MklConv2DBackpropInput);J(_MklConv2DWithBiasBackpropBias);"
            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G;B->D:1;"
            "B->I:1;C->D:2;D->E;DMT/_0->J:1;E->G:2;E->I:2;E->J;"
            "E:control->DMT/_0:control;F->G:1;H->I;M->D:3;M->G:3;M->I:3;"
            "N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
}
    402 
    403 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
    404 // and BackpropInput. But nodes do not match criteria for rewrite. So
    405 // rewrite should not happen.
// Negative: BackpropFilter's inputs are permuted ('E','F','A' instead of
// 'A','F','E'), failing the rewrite criteria — J stays a plain BiasAddGrad.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'O' op: '_MklInput'}"
      "node { name: 'D' op: '_MklConv2DWithBias'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
      "node { name: 'E' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['D', 'A']}"
      "node { name: 'F' op: 'Int32Input'}"
      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
      "node { name: 'H' op: 'Int32Input'}"
      "node { name: 'I' op: '_MklConv2DBackpropInput'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
      "node { name: 'J' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
            "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
            "I(_MklConv2DBackpropInput);J(BiasAddGrad);"
            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
            "B->I:1;C->D:2;D->E;E->G;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
            "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
}
    452 
    453 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
    454 // and BackpropInput. But nodes do not match criteria for rewrite. So
    455 // rewrite should not happen.
// Negative: _MklConv2DWithBias's inputs are permuted ('B','A','C' instead of
// 'A','B','C'), failing the rewrite criteria — J stays a plain BiasAddGrad.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'O' op: '_MklInput'}"
      "node { name: 'D' op: '_MklConv2DWithBias'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['B', 'A', 'C', 'M', 'N', 'O']}"
      "node { name: 'E' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['D', 'A']}"
      "node { name: 'F' op: 'Int32Input'}"
      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
      "node { name: 'H' op: 'Int32Input'}"
      "node { name: 'I' op: '_MklConv2DBackpropInput'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
      "node { name: 'J' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
            "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
            "I(_MklConv2DBackpropInput);J(BiasAddGrad);"
            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
            "B->I:1;C->D:2;D->E;E->G:2;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
            "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
}
    502 
    503 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
// BiasAddGrad H should be rewritten to _MklConv2DWithBiasBackpropBias when
// only BackpropFilter (no BackpropInput) is present with matching inputs.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'O' op: '_MklInput'}"
      "node { name: 'D' op: '_MklConv2DWithBias'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
      "node { name: 'E' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['D', 'A']}"
      "node { name: 'F' op: 'Int32Input'}"
      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
      "node { name: 'H' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['E'] }");
  // H becomes _MklConv2DWithBiasBackpropBias; DMT/_0 feeds its Mkl metadata
  // input H:1 under a control edge from E.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
            "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);"
            "H(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
            "O(_MklInput)|A->D;A->E:1;A->G;B->D:1;C->D:2;D->E;DMT/_0->H:1;"
            "E->G:2;E->H;E:control->DMT/_0:control;F->G:1;M->D:3;M->G:3;"
            "N->D:4;N->G:4;O->D:5;O->G:5");
}
    542 
    543 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
    544 // But BackpropFilter node inputs do not satisfy criteria for rewrite.
// Negative: BackpropFilter inputs are permuted ('E','F','A'), so the
// BiasAddGrad rewrite criteria fail — H stays a plain BiasAddGrad.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'O' op: '_MklInput'}"
      "node { name: 'D' op: '_MklConv2DWithBias'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
      "node { name: 'E' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['D', 'A']}"
      "node { name: 'F' op: 'Int32Input'}"
      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
      "node { name: 'H' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
            "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
            "C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
            "O->G:5");
}
    582 
// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
// But BackpropFilter node inputs do not satisfy criteria for rewrite.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) {
  // Same topology as the positive BpropFilter case, but G's data inputs are
  // ordered ['A', 'F', 'E', ...] instead of matching the forward node D, so
  // the rewrite criteria are not met.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'O' op: '_MklInput'}"
      "node { name: 'D' op: '_MklConv2DWithBias'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['B', 'A', 'C', 'M', 'N', 'O']}"
      "node { name: 'E' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['D', 'A']}"
      "node { name: 'F' op: 'Int32Input'}"
      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
      "node { name: 'H' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['E'] }");
  // Expected graph: H remains a plain BiasAddGrad, i.e. no rewrite happened.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
            "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
            "C->D:2;D->E;E->G:2;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
            "O->G:5");
}
    622 
// No _MklConv2DWithBias in context, but _MklConv2D in context.
// No rewrite for BiasAddGrad should happen.
// C=_MklConv2D(A,M,B,N); D=Zeta(C,A); E=BiasAddGrad(D) (for interleaved)
// C=_MklConv2D(A,B,M,N); D=Zeta(C,A); E=BiasAddGrad(D) (for contiguous)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
  // Forward context contains _MklConv2D (no bias), so BiasAddGrad must not be
  // rewritten to a backprop-bias node.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}"
      "node { name: 'D' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D'] }");
  // Expected graph: E stays BiasAddGrad; all original edges are preserved.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Zeta);E(BiasAddGrad);"
            "M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
            "M->C:2;N->C:3");
}
    652 
// No Conv2D in the context for BiasAddGrad. No rewrite should happen.
// C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D) {
  // Polygamma (an arbitrary non-Conv2D op) feeds the chain, so there is no
  // Conv2D context and BiasAddGrad must be left untouched.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Polygamma'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D'] }");
  // Expected graph: unchanged node ops and edges.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}
    673 
// No Conv2D in the context for BiasAddGrad, but MatMul in context.
// Rewrite should happen, but name of BiasAddGrad does not change.
// C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) {
  // MatMul rather than Conv2D is in the context; per the expected string the
  // op name "BiasAddGrad" is retained in the output graph.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'MatMul'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'transpose_a'      value { b: false } }"
      " attr { key: 'transpose_b'      value { b: false } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D'] }");
  // Expected graph: op names and edges unchanged.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}
    697 
// Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests
// C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) {
  // MatMul feeds the BiasAddGrad context; per the expected string the node
  // keeps the "BiasAddGrad" op name and the graph structure is unchanged.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'MatMul'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'transpose_a'      value { b: false } }"
      " attr { key: 'transpose_b'      value { b: false } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}
    720 
// No MatMul in the context for BiasAddGrad. No rewrite should happen.
// C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Negative_NoMatMul) {
  // Polygamma replaces MatMul in the chain, so no MatMul context exists and
  // BiasAddGrad must remain untouched.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Polygamma'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D'] }");
  // Expected graph: identical to the input graph.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}
    741 
    742 /////////////////////////////////////////////////////////////////////
    743 //  Unit tests related to rewriting node to Mkl node
    744 /////////////////////////////////////////////////////////////////////
    745 
// Single Conv2D Op; No Mkl layer on the input and on the output.
// We will generate dummy Mkl tensor as 2nd input of Conv2D.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'C'] }");
  // Expected: C becomes _MklConv2D; dummy Mkl tensors DMT/_0 and DMT/_1
  // (Const nodes) are added as C's 3rd/4th inputs, with control edges from A
  // so they execute in the same frame as C's inputs.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Zeta);DMT/_0(Const);"
            "DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
            "DMT/_1->C:3");
}
    767 
// 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will
// have 2 outputs, both of which will be inputs to next Conv2D.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'C']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }");
  // Expected: both convs rewritten; C's Mkl metadata output (C:2) feeds D:3,
  // so only one dummy tensor (DMT/_2) is needed for D's remaining Mkl input.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->C;A->D;"
            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;B->C:1;C->D:1;C->E;"
            "C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
}
    797 
// Conv2D with DT_HALF, which is not supported by Mkl.
// (The node must not be rewritten; it stays a plain Conv2D.)
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
  InitGraph(
      "node { name: 'A' op: 'HalfInput'}"
      "node { name: 'B' op: 'HalfInput'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_HALF } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
      " input: ['B', 'C'] }");
  // Expected: graph is unchanged -- no _MklConv2D, no dummy Mkl tensors.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(HalfInput);B(HalfInput);C(Conv2D);D(Zeta)|"
            "A->C;B->C:1;B->D;C->D:1");
}
    816 
// Conv2DBackpropFilter with no Mkl inputs: rewritten to
// _MklConv2DBackpropFilter with three dummy Mkl tensors appended.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Int32Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Conv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }");
  // Expected: DMT/_0..2 feed D:3..5; control edges come from A (D's first
  // data input).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);"
            "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
            "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
            "DMT/_1->D:4;DMT/_2->D:5");
}
    838 
// Conv2DBackpropInput with no Mkl inputs: rewritten to
// _MklConv2DBackpropInput with three dummy Mkl tensors appended.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Int32Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Conv2DBackpropInput'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['B', 'A', 'C']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }");
  // Expected: control edges for the dummies come from B, which is D's first
  // data input here (the shape tensor).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);"
            "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
            "A->D:1;A->E;B->D;B:control->DMT/_0:control;"
            "B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;"
            "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
    860 
// Concat Op test: Concat with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
  // A is the concat-dim Const; B is an InputList supplying two tensors.
  InitGraph(
      "node { name: 'A' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'InputList'"
      " attr { key: 'N'                value { i: 2 } }}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Concat'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['A', 'B:0', 'B:1']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }");
  // Expected: D becomes _MklConcat; three dummy Mkl tensors are appended
  // (one per original input), with control edges from A.
  EXPECT_EQ(
      DoMklLayoutOptimizationPass(),
      "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
      "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
      "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
      "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
    885 
// Concat with 2 Mkl layers feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
  // Both data inputs of the Concat (E and F) are Conv2Ds, which the pass
  // rewrites first; their Mkl metadata outputs then feed the Mkl Concat.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'F' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['C', 'D']}"
      "node { name: 'G' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'H' op: 'Concat'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['G', 'E', 'F']}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'H'] }");
  // Expected: E:2 and F:2 (Mkl metadata) feed H:4/H:5; only the concat-dim
  // input G needs a dummy tensor (DMT/_4 -> H:3).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
            "F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
            "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
            "DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;"
            "G:control->DMT/_4:control;H->I:1");
}
    928 
// Concat with 1 Mkl and 1 non-Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
  // One Concat input (E) is a Conv2D (becomes Mkl); the other (F) is a plain
  // Zeta, so a dummy Mkl tensor must be supplied for F's metadata slot.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D']}"
      "node { name: 'G' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'H' op: 'Concat'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['G', 'E', 'F']}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'H'] }");
  // Expected: E:2 feeds H:4; DMT/_2 and DMT/_3 fill the metadata slots for
  // the concat-dim G (H:3) and the non-Mkl input F (H:5).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);"
            "H(_MklConcat);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
            "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;"
            "G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1");
}
    964 
// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
  // Unlike Concat, ConcatV2 takes the concat-dim (A) as its LAST input.
  InitGraph(
      "node { name: 'A' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'InputList'"
      " attr { key: 'N'                value { i: 2 } }}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'ConcatV2'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['B:0', 'B:1', 'A']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }");
  // Expected: D becomes _MklConcatV2; dummies' control edges come from B,
  // the first data input.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D:2;B->D;B:1->D:1;"
            "B:control->DMT/_0:control;B:control->DMT/_1:control;"
            "B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;"
            "DMT/_1->D:4;DMT/_2->D:5");
}
    990 
// ConcatV2 with 2 Mkl layers feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
  // Both data inputs of ConcatV2 (E, F) are Conv2Ds that get rewritten; the
  // concat-dim G is the last input.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'F' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['C', 'D']}"
      "node { name: 'G' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'H' op: 'ConcatV2'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['E', 'F', 'G']}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'H'] }");
  // Expected: E:2/F:2 metadata feed H:3/H:4; the concat-dim slot H:5 gets
  // dummy DMT/_4, whose control edge comes from E (H's first data input).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
            "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
            "C:control->DMT/_0:control;C:control->DMT/_1:control;"
            "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
            "DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;"
            "F:2->H:4;G->H:2;H->I:1");
}
   1034 
// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
  // E is a Conv2D (becomes Mkl); F is a plain Zeta, so its metadata slot and
  // the concat-dim slot both need dummy Mkl tensors.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D']}"
      "node { name: 'G' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'H' op: 'ConcatV2'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['E', 'F', 'G']}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'H'] }");
  // Expected: E:2 feeds H:3; DMT/_2 and DMT/_3 fill H:4 (for F) and H:5
  // (for G), with control edges from E.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);"
            "H(_MklConcatV2);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
            "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;"
            "E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;"
            "G->H:2;H->I:1");
}
   1072 
   1073 TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) {
   1074   InitGraph(
   1075       "node { name: 'A' op: 'Input'}"
   1076       "node { name: 'B' op: 'Relu'"
   1077       " attr { key: 'T'                value { type: DT_FLOAT } }"
   1078       " input: ['A'] }"
   1079       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
   1080       " input: ['A', 'B'] }");
   1081   EXPECT_EQ(DoMklLayoutOptimizationPass(),
   1082             "A(Input);B(_MklRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;"
   1083             "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
   1084 }
   1085 
   1086 TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) {
   1087   InitGraph(
   1088       "node { name: 'A' op: 'Input'}"
   1089       "node { name: 'B' op: 'Input'}"
   1090       "node { name: 'C' op: 'ReluGrad'"
   1091       " attr { key: 'T'                value { type: DT_FLOAT } }"
   1092       " input: ['A', 'B'] }"
   1093       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
   1094       " input: ['A', 'C'] }");
   1095   EXPECT_EQ(DoMklLayoutOptimizationPass(),
   1096             "A(Input);B(Input);C(_MklReluGrad);D(Zeta);DMT/_0(Const);"
   1097             "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
   1098             "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
   1099 }
   1100 
// Relu feeding ReluGrad: both rewritten; Relu's Mkl metadata output (B:1)
// feeds ReluGrad's corresponding metadata input (C:3).
TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Relu'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'ReluGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'C'] }");
  // Expected: only C's first metadata slot (C:2) needs a dummy (DMT/_1);
  // B:1->C:3 carries the real Mkl metadata from the forward Relu.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklRelu);C(_MklReluGrad);D(Zeta);DMT/_0(Const);"
            "DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
            "DMT/_1->C:2");
}
   1118 
   1119 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
   1120   InitGraph(
   1121       "node { name: 'A' op: 'Input'}"
   1122       "node { name: 'B' op: 'AvgPool'"
   1123       " attr { key: 'T'            value { type: DT_FLOAT } }"
   1124       " attr { key: 'data_format'  value { s: 'NCHW' } }"
   1125       " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
   1126       " attr { key: 'padding'      value { s: 'VALID' } }"
   1127       " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
   1128       " input: ['A'] }"
   1129       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
   1130       " input: ['A', 'B'] }");
   1131   EXPECT_EQ(DoMklLayoutOptimizationPass(),
   1132             "A(Input);B(_MklAvgPool);C(Zeta);DMT/_0(Const)|A->B;A->C;"
   1133             "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
   1134 }
   1135 
// AvgPoolGrad with plain inputs: rewritten to _MklAvgPoolGrad with two dummy
// Mkl metadata tensors appended.
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Int32Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'AvgPoolGrad' "
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'C'] }");
  // Expected: DMT/_0 and DMT/_1 feed C:2/C:3; control edges come from A,
  // C's first data input (the shape tensor).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);"
            "DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
            "DMT/_1->C:3");
}
   1155 
// AvgPool feeding AvgPoolGrad: both rewritten; the pool's Mkl metadata
// output (B:1) feeds the grad node's corresponding metadata input (C:3).
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'I' op: 'Int32Input'}"
      "node { name: 'B' op: 'AvgPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'AvgPoolGrad' "
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['I', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'C'] }");
  // Expected: B:1->C:3 carries the real metadata; only C:2 needs a dummy
  // (DMT/_1, control edge from I, C's first data input).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);"
            "DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;"
            "B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;"
            "I:control->DMT/_1:control");
}
   1182 
// FusedBatchNormGrad (5 data inputs) is rewritten to _MklFusedBatchNormGrad;
// the pass appends five dummy Mkl-metadata Const inputs (slots 5-9), one per
// original data input, all control-tied to input A.
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'FusedBatchNormGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'epsilon'      value { f: 0.0001 } }"
      " attr { key: 'is_training'  value { b: true } }"
      " input: ['A', 'B', 'C', 'D', 'E'] }"
      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'F'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
            "F(_MklFusedBatchNormGrad);G(Zeta)|A->F;A->G;"
            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
            "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
            "E->F:4;F->G:1");
}
   1208 
// Forward FusedBatchNorm: same wiring expectation as the Grad variant above —
// rewrite to _MklFusedBatchNorm plus five dummy Mkl-metadata Const inputs.
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'FusedBatchNorm'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'epsilon'      value { f: 0.0001 } }"
      " attr { key: 'is_training'  value { b: true } }"
      " input: ['A', 'B', 'C', 'D', 'E'] }"
      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'F'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
            "F(_MklFusedBatchNorm);G(Zeta)|A->F;A->G;"
            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
            "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
            "E->F:4;F->G:1");
}
   1234 
   1235 /////////////////////////////////////////////////////////////////////
   1236 //  Unit tests related to rewriting node for workspace edges
   1237 /////////////////////////////////////////////////////////////////////
   1238 
/* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. */
// All four ops become _Mkl variants; forward/backward pairs are additionally
// connected through extra workspace slots (e.g. B:2/B:3 -> G:6/G:7 and
// C:2/C:3 -> E:5/E:7 in the expected edge list below).
TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'LRN'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['B'] }"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'MaxPoolGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['B', 'C', 'D'] }"
      "node { name: 'F' op: 'Input'}"
      "node { name: 'G' op: 'LRNGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['E', 'F', 'B'] }"
      "node { name: 'H' op: 'Input'}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['H', 'G'] }");
  EXPECT_EQ(
      DoMklLayoutOptimizationPass(),
      "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
      "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
      "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
      "B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;"
      "C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;"
      "E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I");
}
   1287 
/* Test LRN->LRNGrad replacement by workspace nodes. */
// Matched LRN/LRNGrad pair: both rewritten, with the forward op's workspace
// outputs (B:2, B:3) fed into the grad op's workspace inputs (E:6, E:7).
TEST_F(MklLayoutPassTest, LRN_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'LRN'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'LRNGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['C', 'D', 'B'] }"
      "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);E(_MklLRNGrad);F(Zeta)|"
            "A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;"
            "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
            "D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
}
   1319 
/* Test LRN->LRNGrad replacement when only one of them is present. */
// Only the forward LRN exists: it is still rewritten to _MklLRN (with one
// dummy Mkl-metadata input), but no workspace edges are created.
TEST_F(MklLayoutPassTest, LRN_Negative1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'LRN'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklLRN);C(Zeta);DMT/_0(Const)|"
            "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
   1338 
/* Test LRN->LRNGrad replacement when only one of them is present. */
// Only the LRNGrad exists: rewritten to _MklLRNGrad, with all five missing
// Mkl-metadata/workspace inputs filled by dummy DMT Const nodes (slots 3-7).
TEST_F(MklLayoutPassTest, LRN_Negative2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'LRNGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['A', 'B', 'C'] }"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
            "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
            "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
}
   1363 
/* Test LRN->LRNGrad negative case, where single LRN feeds
   2 LRNGrad nodes at different slots. */
// E consumes B at the expected input slot (2) and receives workspace edges
// (B:2->E:6, B:3->E:7); F consumes B at slot 1, so it instead gets dummy
// DMT Const nodes for its workspace/metadata inputs.
TEST_F(MklLayoutPassTest, LRN_Negative3) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'LRN'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'LRNGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['C', 'D', 'B'] }"
      "node { name: 'F' op: 'LRNGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['C', 'B', 'D'] }"
      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['E', 'F'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
            "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Zeta)|A->B;"
            "A:control->DMT/_0:control;B->E:2;"
            "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;"
            "C:control->DMT/_1:control;C:control->DMT/_2:control;"
            "C:control->DMT/_3:control;C:control->DMT/_4:control;"
            "C:control->DMT/_5:control;C:control->DMT/_6:control;"
            "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
            "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
}
   1409 
/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
// Matched pair: both rewritten to _Mkl ops, and MaxPool's workspace outputs
// (B:2, B:3) flow into MaxPoolGrad's workspace inputs (E:5, E:7).
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'MaxPoolGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['C', 'B', 'D'] }"
      "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Zeta)|"
            "A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;"
            "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
            "D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
}
   1439 
// Test MaxPool->MaxPoolGrad replacement when only one of them is present.
// In this case, we will rewrite MaxPool node but workspace edges will not
// be present.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  // B is rewritten to _MklMaxPool with a single dummy Mkl-metadata input.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklMaxPool);C(Zeta);DMT/_0(Const)|"
            "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
   1459 
// Test MaxPoolGrad replacement when only one of them is present.
// In this case, we will rewrite MaxPoolGrad and for workspace tensor and
// its Mkl part, we will generate dummy tensor.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'MaxPoolGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A', 'B', 'C'] }"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }");
  // All five extra inputs of _MklMaxPoolGrad (slots 3-7) come from DMT nodes.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
            "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
            "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
}
   1485 
// Test MaxPool handling for batch-wise pooling (NCHW)
// No rewrite should take place in such case
// (ksize pools over the batch dimension: {2,1,1,1} in NCHW).
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 2, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   1503 
// Test MaxPool handling for batch-wise pooling (NCHW)
// No rewrite should take place in such case
// (strides step over the batch dimension: {2,1,1,1} in NCHW).
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 2, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   1521 
// Test MaxPool handling for depth-wise pooling (NCHW)
// No rewrite should take place in such case
// (ksize pools over the channel dimension: {1,2,1,1} in NCHW; the old
// comment said NHWC, but the node's data_format below is NCHW).
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:2, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   1539 
// Test MaxPool handling for depth-wise pooling (NCHW)
// No rewrite should take place in such case
// (strides step over the channel dimension: {1,2,1,1} in NCHW).
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:2, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   1557 
// Test MaxPool handling for batch-wise pooling (NHWC)
// No rewrite should take place in such case
// (ksize pools over the batch dimension: {2,1,1,1} in NHWC).
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 2, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   1575 
// Test MaxPool handling for batch-wise pooling (NHWC)
// No rewrite should take place in such case
// (strides step over the batch dimension: {2,1,1,1} in NHWC).
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 2, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   1593 
// Test MaxPool handling for depth-wise pooling (NHWC)
// No rewrite should take place in such case
// (ksize pools over the channel dimension: {1,1,1,2} in NHWC).
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:2} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   1611 
// Test MaxPool handling for depth-wise pooling (NHWC)
// No rewrite should take place in such case
// (strides step over the channel dimension: {1,1,1,2} in NHWC).
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   1629 
   1630 /////////////////////////////////////////////////////////////////////
   1631 
// Single Conv2D Op on GPU device
// No rewrite should happen
// (nodes are placed on kGPUDevice, so the MKL pass must leave them alone).
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'C'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1");
}
   1651 
// On GPU, an existing _MklConv2DWithBias followed by BiasAddGrad must not be
// merged/rewritten: the expected graph below is identical to the input graph.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'O' op: '_MklInput'}"
      "node { name: 'D' op: '_MklConv2DWithBias'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
      "node { name: 'E' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['D', 'A']}"
      "node { name: 'F' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['E'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
            "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
            "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;"
            "M->D:3;N->D:4;O->D:5");
}
   1681 
// Conv2DBackpropFilter on GPU: no MKL rewrite expected — the op name stays
// Conv2DBackpropFilter and no DMT nodes appear.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Int32Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Conv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|"
            "A->D;A->E;B->D:1;C->D:2;D->E:1");
}
   1701 
// Relu on GPU: no MKL rewrite expected; graph must pass through unchanged.
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Relu'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1");
}
   1714 
// ReluGrad on GPU: no MKL rewrite expected; graph must pass through unchanged.
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'ReluGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'C'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
}
   1728 
// MaxPool on GPU: no MKL rewrite expected; graph must pass through unchanged.
TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   1745 
// AvgPool on a GPU-assigned device: no rewrite expected.
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'AvgPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1");
}
   1762 
   1763 // Concat Op test: Concat with no Mkl layer feeding it
// Concat on a GPU-assigned device: stays a plain Concat (no _MklConcat).
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'InputList'"
      " attr { key: 'N'                value { i: 2 } }}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Concat'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['A', 'B:0', 'B:1']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;"
            "B->D:1;B:1->D:2;C->E;D->E:1");
}
   1785 
// ConcatV2 (axis as the last input) on a GPU-assigned device: no rewrite.
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'InputList'"
      " attr { key: 'N'                value { i: 2 } }}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'ConcatV2'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['B:0', 'B:1', 'A']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
            "A->D:2;B->D;B:1->D:1;C->E;D->E:1");
}
   1808 
// FusedBatchNorm on a GPU-assigned device: no rewrite expected.
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'FusedBatchNorm'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'epsilon'      value { f: 0.0001 } }"
      " attr { key: 'is_training'  value { b: true } }"
      " input: ['A', 'B', 'C', 'D', 'E'] }"
      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'F'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);E(Input);"
            "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
            "E->F:4;F->G:1");
}
   1830 
// _MklConv2D + BiasAdd on a GPU-assigned device: the expected graph keeps
// both nodes separate, i.e. no Conv2DWithBias merge happens.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
            "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;"
            "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
   1860 
   1861 /////////////////////////////////////////////////////////////////////
   1862 
   1863 static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
   1864   testing::StopTiming();
   1865   string s;
   1866   for (int in = 0; in < 10; in++) {
   1867     s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
   1868   }
   1869   random::PhiloxRandom philox(301, 17);
   1870   random::SimplePhilox rnd(&philox);
   1871   for (int op = 0; op < op_nodes; op++) {
   1872     s += strings::Printf(
   1873         "node { name: 'op%04d' op: 'Zeta' attr { key: 'T' value { "
   1874         "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
   1875         op, rnd.Uniform(10), rnd.Uniform(10));
   1876   }
   1877 
   1878   bool first = true;
   1879   while (iters > 0) {
   1880     Graph* graph = new Graph(OpRegistry::Global());
   1881     InitGraph(s, graph);
   1882     int N = graph->num_node_ids();
   1883     if (first) {
   1884       testing::SetLabel(strings::StrCat("Per graph node.  Nodes: ", N));
   1885       first = false;
   1886     }
   1887     {
   1888       testing::StartTiming();
   1889       std::unique_ptr<Graph> ug(graph);
   1890       RunMklLayoutRewritePass(&ug);
   1891       testing::StopTiming();
   1892     }
   1893     iters -= N;  // Our benchmark units are individual graph nodes,
   1894                  // not whole graphs
   1895     // delete graph;
   1896   }
   1897 }
   1898 BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
   1899 
   1900 }  // namespace
   1901 
   1902 #else  // INTEL_MKL_ML
   1903 
   1904 namespace {
   1905 
// Fully specified device names used to pin every node of a test graph;
// the *_DeviceTest cases use kGPUDevice to check that no rewrite occurs.
const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0";
const char kGPUDevice[] = "/job:a/replica:0/task:0/device:GPU:0";
   1908 
   1909 static void InitGraph(const string& s, Graph* graph,
   1910                       const string& device = kCPUDevice) {
   1911   GraphDef graph_def;
   1912 
   1913   auto parser = protobuf::TextFormat::Parser();
   1914   //  parser.AllowRelaxedWhitespace(true);
   1915   CHECK(parser.MergeFromString(s, &graph_def)) << s;
   1916   GraphConstructorOptions opts;
   1917   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
   1918 
   1919   for (Node* node : graph->nodes()) {
   1920     node->set_assigned_device_name(device);
   1921   }
   1922 }
   1923 
// Test fixture: owns a Graph, builds it from text-format GraphDef, and
// renders it as a canonical "<nodes>|<edges>" string for exact comparison.
class MklLayoutPassTest : public ::testing::Test {
 public:
  MklLayoutPassTest() : graph_(OpRegistry::Global()) {}

  // Parses `s` into graph_ (pinning nodes to `device`) and records the
  // canonical form of the graph before any rewrite runs.
  void InitGraph(const string& s, const string& device = kCPUDevice) {
    ::tensorflow::InitGraph(s, &graph_, device);
    original_ = CanonicalGraphString(&graph_);
  }

  // Only real ops are included in the canonical string (skips non-op nodes
  // such as the graph's source/sink).
  static bool IncludeNode(const Node* n) { return n->IsOp(); }

  // Formats one edge endpoint: "name" for slot 0, "name:control" for
  // control edges, "name:<slot>" otherwise.
  static string EdgeId(const Node* n, int index) {
    if (index == 0) {
      return n->name();
    } else if (index == Graph::kControlSlot) {
      return strings::StrCat(n->name(), ":control");
    } else {
      return strings::StrCat(n->name(), ":", index);
    }
  }

  // Returns "<sorted node list>|<sorted edge list>", making graph
  // comparison independent of construction/iteration order.
  string CanonicalGraphString(Graph* g) {
    std::vector<string> nodes;
    std::vector<string> edges;
    for (const Node* n : g->nodes()) {
      if (IncludeNode(n)) {
        nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
      }
    }
    for (const Edge* e : g->edges()) {
      if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
        edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
                                        EdgeId(e->dst(), e->dst_input())));
      }
    }
    // Canonicalize
    std::sort(nodes.begin(), nodes.end());
    std::sort(edges.begin(), edges.end());
    return strings::StrCat(str_util::Join(nodes, ";"), "|",
                           str_util::Join(edges, ";"));
  }

  // Runs the MKL layout rewrite pass over graph_ in place and returns the
  // resulting canonical graph string.
  string DoMklLayoutOptimizationPass() {
    string before = CanonicalGraphString(&graph_);
    LOG(ERROR) << "Before MKL layout rewrite pass: " << before;

    // NOTE: this unique_ptr is heap-allocated and never deleted on purpose.
    // graph_ is a fixture member (not heap-owned), so letting the
    // unique_ptr destruct would delete a pointer it does not own; leaking
    // it sidesteps that while satisfying the pass's unique_ptr* interface.
    std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
    RunMklLayoutRewritePass(ug);

    string result = CanonicalGraphString(&graph_);
    LOG(ERROR) << "After MKL layout rewrite pass:  " << result;
    return result;
  }

  const string& OriginalGraph() const { return original_; }

  Graph graph_;     // Graph under test; rewritten in place.
  string original_; // Canonical string captured at InitGraph time.
};
   1983 
// Minimal stateful dummy ops used to build test graphs: plain inputs of
// various dtypes, plus _MklInput* ops producing the uint8 outputs that
// stand in for MKL metadata tensors.
REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("_MklInput2")
    .Output("o: uint8")
    .Output("o1: uint8")
    .SetIsStateful();
   1993 
   1994 /////////////////////////////////////////////////////////////////////
//  Unit tests related to node merge optimization
   1996 /////////////////////////////////////////////////////////////////////
   1997 
// Sanity check: a graph containing only ops the pass does not touch (Zeta)
// comes out unchanged.
TEST_F(MklLayoutPassTest, Basic) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Zeta);D(Zeta)|"
            "A->C;A->D;B->C:1;B->D:1");
}
   2010 
   2011 // Test set 1: Conv2D + AddBias
   2012 
   2013 // C=Conv2D(A,B); E=BiasAdd(C,D); Z=Zeta(E,Y)
// Conv2D whose sole consumer is BiasAdd: the pair is merged into one
// _MklConv2DWithBias node (E), and dummy MKL metadata tensors (DMT/_*
// Consts, control-dependent on A) are wired to its metadata inputs.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Zeta)|A->E;"
            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;"
            "DMT/_2->E:5;E->Z;Y->Z:1");
}
   2042 
   2043 // Graph contains only Conv2D, no AddBias.
// No BiasAdd present: no merge; Conv2D is only rewritten to _MklConv2D
// with two dummy metadata (DMT) inputs.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);DMT/_0(Const);DMT/_1(Const)|"
            "A->C;A:control->DMT/_0:control;A:control->DMT/_1:control;B->C:1;"
            "DMT/_0->C:2;DMT/_1->C:3");
}
   2060 
   2061 // Conv2D output does not go to BiasAdd.
// BiasAdd consumes D and E, not the Conv2D output, so no merge occurs;
// Conv2D is only rewritten to _MklConv2D.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D', 'E'] }");  // Output of Conv2D does not go to BiasAdd.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Input);DMT/_0(Const);"
            "DMT/_1(Const);E(Input);F(BiasAdd)|A->C;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->C:1;D->F;DMT/_0->C:2;DMT/_1->C:3;"
            "E->F:1");
}
   2085 
   2086 // Conv2D has two outgoing edges: BiasAdd and some other dummy node (Zeta).
   2087 // Merge should not be done in such case.
// Conv2D output also feeds Zeta (G), so BiasAdd is not its sole consumer;
// no merge, only the _MklConv2D rewrite.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D', 'E'] }"  // Conv2D has two outputs.
                              // No merge should happen.
      "node { name: 'G' op: 'Zeta'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['C', 'E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Input);DMT/_0(Const);"
            "DMT/_1(Const);E(Input);F(BiasAdd);G(Zeta)|A->C;"
            "A:control->DMT/_0:control;A:control->DMT/_1:control;B->C:1;C->G;"
            "D->F;DMT/_0->C:2;DMT/_1->C:3;E->F:1;E->G:1");
}
   2115 
   2116 // data_format attribute value mismatch. Merge should not be done
   2117 // in such case.
// BiasAdd's data_format ('NHCW' — deliberately different) does not match
// Conv2D's 'NCHW', so the merge is rejected.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NHCW' } }"
      " input: ['C', 'D'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Input);DMT/_0(Const);"
            "DMT/_1(Const);E(BiasAdd)|A->C;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->C:1;C->E;D->E:1;DMT/_0->C:2;"
            "DMT/_1->C:3");
}
   2140 
   2141 // Test set 2: BiasAddGrad + Conv2DBackpropFilter fusion tests
   2142 
// BiasAddGrad's input (C) matches Conv2DBackpropFilter's 3rd input, with
// matching data_format: the two fuse into _MklConv2DBackpropFilterWithBias.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Int32Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Conv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C'] }"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['C'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Int32Input);C(Input);"
            "D(_MklConv2DBackpropFilterWithBias);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const)|A->D;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;"
            "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
   2166 
   2167 // BiasAddGrad fusion in the presence of BackpropFilter. But nodes do not match
   2168 // criteria for rewrite. So rewrite should not happen. 3rd input of
   2169 // Conv2DBackpropFilter is different than input to BiasAddGrad.
// BiasAddGrad reads A, not the 3rd input (C) of Conv2DBackpropFilter:
// no fusion; only the _MklConv2DBackpropFilter rewrite happens.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Int32Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Conv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C'] }"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['A'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Int32Input);C(Input);"
            "D(_MklConv2DBackpropFilter);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);E(BiasAddGrad)|A->D;A->E;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;"
            "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
   2193 
   2194 // BiasAddGrad fusion, but nodes do not match criteria for fusion.
   2195 // Different input formats.
// data_format mismatch (Conv2DBackpropFilter: NCHW vs BiasAddGrad: NHWC):
// fusion is rejected; only the _MklConv2DBackpropFilter rewrite happens.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Int32Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Conv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C'] }"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NHWC' } }"
      " input: ['A'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Int32Input);C(Input);"
            "D(_MklConv2DBackpropFilter);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);E(BiasAddGrad)|A->D;A->E;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;"
            "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
   2219 
   2220 // BiasAddGrad fusion in the presence of BackpropFilter only. Fusion is done
   2221 // before node rewrite. Check this ordering.
// The graph already contains _Mkl* nodes (post-rewrite form); BiasAddGrad
// is not fused with the already-rewritten _MklConv2DBackpropFilter, and
// the whole graph is expected to pass through unchanged.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative3) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'O' op: '_MklInput'}"
      "node { name: 'D' op: '_MklConv2DWithBias'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
      "node { name: 'E' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['D', 'A']}"
      "node { name: 'F' op: 'Int32Input'}"
      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
      "node { name: 'H' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
            "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
            "C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
            "O->G:5");
}
   2259 
   2260 // C=Conv2D(A,B); E=BiasAdd(C,D); Y=Zeta(E,X);
   2261 // G=Conv2DBackpropInput(F,B,E)
   2262 // This is a case of node rewrite followed by node merge followed by connecting
   2263 // filter output of Conv2DWithBias to filter input of Conv2DBackpropInput.
// After the merge, the filter output of _MklConv2DWithBias (E:1) is wired
// straight into the filter input of _MklConv2DBackpropInput (G:1), along
// with its metadata outputs (E:2->G:5, E:3->G:4).
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_ConvBpropInput_FilterFwd) {
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'X' op: 'Input'}"
      "node { name: 'Y' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['E', 'X']}"
      "node { name: 'F' op: 'Int32Input'}"
      "node { name: 'G' op: 'Conv2DBackpropInput'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['F', 'B', 'E']}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['G', 'X']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);E(_MklConv2DWithBias);F(Int32Input);"
            "G(_MklConv2DBackpropInput);X(Input);Y(Zeta);Z(Zeta)|"
            "A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;"
            "DMT/_1->E:4;DMT/_2->E:5;DMT/_3->G:3;E->G:2;E->Y;E:1->G:1;E:2->G:5;"
            "E:3->G:4;F->G;F:control->DMT/_3:control;G->Z;X->Y:1;X->Z:1");
}
   2305 
   2306 /////////////////////////////////////////////////////////////////////
   2307 //  Unit tests related to rewriting node to Mkl node
   2308 /////////////////////////////////////////////////////////////////////
   2309 
   2310 // Single Conv2D Op; No Mkl layer on the input and on the output.
   2311 // We will generate dummy Mkl tensor as 2nd input of Conv2D.
// Plain Conv2D is rewritten to _MklConv2D; two dummy MKL tensors (DMT/_*)
// are created, control-dependent on A, and wired as metadata inputs.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'C'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Zeta);DMT/_0(Const);"
            "DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
            "DMT/_1->C:3");
}
   2331 
   2332 // 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will
   2333 // have 2 outputs, both of which will be inputs to next Conv2D.
// Two chained Conv2Ds: both rewritten; the first's MKL metadata output
// (C:2) feeds the second's metadata input (D:3), so only one DMT is
// needed for D's remaining metadata slot.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'C']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->C;A->D;"
            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;B->C:1;C->D:1;C->E;"
            "C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
}
   2361 
// Conv2D with DT_HALF, which is not supported by Mkl; the op must be left
// unrewritten.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
  // Conv2D with T = DT_HALF: the pass must not rewrite it.
  InitGraph(
      "node { name: 'A' op: 'HalfInput'}"
      "node { name: 'B' op: 'HalfInput'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_HALF } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
      " input: ['B', 'C'] }");
  // Expected: graph is unchanged — C stays a plain Conv2D, no DMT nodes.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(HalfInput);B(HalfInput);C(Conv2D);D(Zeta)|"
            "A->C;B->C:1;B->D;C->D:1");
}
   2380 
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
  // Conv2DBackpropFilter with float inputs should be rewritten.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Int32Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Conv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }");
  // Expected: D becomes _MklConv2DBackpropFilter with three dummy Mkl
  // metadata inputs (one per original data input), control-tied to A.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);"
            "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
            "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
            "DMT/_1->D:4;DMT/_2->D:5");
}
   2402 
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) {
  // Conv2DBackpropInput with float inputs should be rewritten. Note the
  // Int32Input (shape) is this op's first input.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Int32Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Conv2DBackpropInput'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['B', 'A', 'C']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }");
  // Expected: D becomes _MklConv2DBackpropInput with three dummies; since
  // B is the op's first input, the dummies' control edges come from B.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);"
            "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
            "A->D:1;A->E;B->D;B:control->DMT/_0:control;"
            "B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;"
            "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
   2424 
   2425 // Check that we never rewrite BiasAddGrad.
TEST_F(MklLayoutPassTest, NodeRewrite_BiasAddGrad_Positive) {
  // BiasAddGrad fed (indirectly) by a non-Mkl op (Polygamma) must be left
  // unrewritten.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Polygamma'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D'] }");
  // Expected: graph is unchanged.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}
   2444 
   2445 // Check that we never rewrite BiasAddGrad.
TEST_F(MklLayoutPassTest, NodeRewrite_BiasAddGrad_Positive1) {
  // BiasAddGrad fed (indirectly) by MatMul must also be left unrewritten.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'MatMul'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'transpose_a'      value { b: false } }"
      " attr { key: 'transpose_b'      value { b: false } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D'] }");
  // Expected: graph is unchanged.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}
   2466 
   2467 // Check that we never rewrite BiasAddGrad.
TEST_F(MklLayoutPassTest, NodeRewrite_BiasAddGrad_Positive2) {
  // Even when BiasAddGrad is (indirectly) downstream of an already-Mkl
  // convolution (_MklConv2D), it must not be rewritten.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}"
      "node { name: 'D' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['D'] }");
  // Expected: graph is unchanged (E stays BiasAddGrad).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Zeta);E(BiasAddGrad);"
            "M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
            "M->C:2;N->C:3");
}
   2493 
   2494 // Concat Op test: Concat with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
  // Concat (axis input first) fed only by non-Mkl producers.
  InitGraph(
      "node { name: 'A' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'InputList'"
      " attr { key: 'N'                value { i: 2 } }}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Concat'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['A', 'B:0', 'B:1']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }");
  // Expected: D becomes _MklConcat; all three Mkl metadata inputs are
  // dummies (DMT/_0..2) since no producer supplies Mkl metadata.
  EXPECT_EQ(
      DoMklLayoutOptimizationPass(),
      "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
      "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
      "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
      "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
   2518 
   2519 // Concat with 2 Mkl layers feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
  // Concat whose two data inputs (E, F) are both Conv2D ops that get
  // rewritten to _MklConv2D.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'F' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['C', 'D']}"
      "node { name: 'G' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'H' op: 'Concat'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['G', 'E', 'F']}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'H'] }");
  // Expected: H becomes _MklConcat and consumes the Mkl metadata outputs of
  // E and F (E:2->H:4, F:2->H:5); only the axis input G needs a dummy
  // (DMT/_4). E and F each get two dummies of their own.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
            "F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
            "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
            "DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;"
            "G:control->DMT/_4:control;H->I:1");
}
   2561 
   2562 // Concat with 1 Mkl and 1 non-Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
  // Concat with one Mkl producer (Conv2D E) and one non-Mkl producer
  // (Zeta F).
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D']}"
      "node { name: 'G' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'H' op: 'Concat'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['G', 'E', 'F']}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'H'] }");
  // Expected: H becomes _MklConcat; E's Mkl metadata feeds H (E:2->H:4)
  // while the axis G and non-Mkl input F get dummies (DMT/_2, DMT/_3).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);"
            "H(_MklConcat);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
            "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;"
            "G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1");
}
   2597 
   2598 // ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
  // ConcatV2 (axis input last) fed only by non-Mkl producers.
  InitGraph(
      "node { name: 'A' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'InputList'"
      " attr { key: 'N'                value { i: 2 } }}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'ConcatV2'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['B:0', 'B:1', 'A']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }");
  // Expected: D becomes _MklConcatV2 with three dummy Mkl inputs; control
  // edges come from B, the op's first input.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D:2;B->D;B:1->D:1;"
            "B:control->DMT/_0:control;B:control->DMT/_1:control;"
            "B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;"
            "DMT/_1->D:4;DMT/_2->D:5");
}
   2623 
   2624 // ConcatV2 with 2 Mkl layers feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
  // ConcatV2 whose two data inputs (E, F) are both Conv2D ops rewritten to
  // _MklConv2D.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'F' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['C', 'D']}"
      "node { name: 'G' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'H' op: 'ConcatV2'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['E', 'F', 'G']}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'H'] }");
  // Expected: H becomes _MklConcatV2 consuming E's and F's Mkl metadata
  // (E:2->H:3, F:2->H:4); only the trailing axis input G needs a dummy
  // (DMT/_4), control-tied to H's first input E.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
            "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
            "C:control->DMT/_0:control;C:control->DMT/_1:control;"
            "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
            "DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;"
            "F:2->H:4;G->H:2;H->I:1");
}
   2667 
   2668 // ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
  // ConcatV2 with one Mkl producer (Conv2D E) and one non-Mkl producer
  // (Zeta F).
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D']}"
      "node { name: 'G' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'H' op: 'ConcatV2'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['E', 'F', 'G']}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'H'] }");
  // Expected: H becomes _MklConcatV2; E's Mkl metadata feeds H (E:2->H:3)
  // while F and the axis G get dummies (DMT/_2, DMT/_3) control-tied to E.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);"
            "H(_MklConcatV2);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
            "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;"
            "E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;"
            "G->H:2;H->I:1");
}
   2705 
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) {
  // Relu on float input is rewritten to _MklRelu with one dummy Mkl input.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Relu'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  // Expected: B(_MklRelu) with dummy DMT/_0 as its 2nd input.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;"
            "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
   2718 
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) {
  // ReluGrad on float inputs is rewritten to _MklReluGrad with two dummies.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'ReluGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'C'] }");
  // Expected: C(_MklReluGrad) with dummies DMT/_0, DMT/_1 as inputs 3 and 4.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklReluGrad);D(Zeta);DMT/_0(Const);"
            "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}
   2733 
TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) {
  // Relu feeding ReluGrad: both are rewritten and the grad consumes the
  // forward op's Mkl metadata.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Relu'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'ReluGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'C'] }");
  // Expected: B(_MklRelu) passes its metadata to C(_MklReluGrad) via
  // B:1->C:3; only one dummy per op remains (DMT/_0->B:1, DMT/_1->C:2).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklRelu);C(_MklReluGrad);D(Zeta);DMT/_0(Const);"
            "DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
            "DMT/_1->C:2");
}
   2751 
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
  // AvgPool on float input is rewritten to _MklAvgPool with one dummy.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'AvgPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  // Expected: B(_MklAvgPool) with dummy DMT/_0 as its 2nd input.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklAvgPool);C(Zeta);DMT/_0(Const)|A->B;A->C;"
            "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
   2768 
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) {
  // AvgPoolGrad (shape, grad) on float is rewritten to _MklAvgPoolGrad.
  InitGraph(
      "node { name: 'A' op: 'Int32Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'AvgPoolGrad' "
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'C'] }");
  // Expected: C(_MklAvgPoolGrad) with dummies DMT/_0, DMT/_1 as inputs 3/4,
  // control-tied to the op's first input A.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);"
            "DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
            "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
            "DMT/_1->C:3");
}
   2788 
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) {
  // AvgPool feeding AvgPoolGrad: both rewritten, grad consumes the forward
  // op's Mkl metadata.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'I' op: 'Int32Input'}"
      "node { name: 'B' op: 'AvgPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'AvgPoolGrad' "
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['I', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'C'] }");
  // Expected: B(_MklAvgPool)'s metadata feeds C(_MklAvgPoolGrad) via
  // B:1->C:3; remaining Mkl slots are dummies (DMT/_0->B:1, DMT/_1->C:2).
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);"
            "DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;"
            "B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;"
            "I:control->DMT/_1:control");
}
   2815 
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) {
  // FusedBatchNormGrad (5 float inputs A..E) is rewritten.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'FusedBatchNormGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'epsilon'      value { f: 0.0001 } }"
      " attr { key: 'is_training'  value { b: true } }"
      " input: ['A', 'B', 'C', 'D', 'E'] }"
      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'F'] }");
  // Expected: F(_MklFusedBatchNormGrad) with five dummy Mkl metadata
  // inputs (slots 5..9), all control-tied to A.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
            "F(_MklFusedBatchNormGrad);G(Zeta)|A->F;A->G;"
            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
            "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
            "E->F:4;F->G:1");
}
   2841 
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) {
  // FusedBatchNorm (5 float inputs A..E) is rewritten; mirrors the Grad
  // test above.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'FusedBatchNorm'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'epsilon'      value { f: 0.0001 } }"
      " attr { key: 'is_training'  value { b: true } }"
      " input: ['A', 'B', 'C', 'D', 'E'] }"
      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'F'] }");
  // Expected: F(_MklFusedBatchNorm) with five dummy Mkl metadata inputs
  // (slots 5..9), all control-tied to A.
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
            "F(_MklFusedBatchNorm);G(Zeta)|A->F;A->G;"
            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
            "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
            "E->F:4;F->G:1");
}
   2867 
   2868 /////////////////////////////////////////////////////////////////////
   2869 //  Unit tests related to rewriting node for workspace edges
   2870 /////////////////////////////////////////////////////////////////////
   2871 
   2872 /* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. */
TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
  // Chain LRN(B) -> MaxPool(C) -> MaxPoolGrad(E) -> LRNGrad(G): all four
  // get rewritten, with workspace edges wired from each forward op to its
  // matching grad op.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'LRN'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['B'] }"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'MaxPoolGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['B', 'C', 'D'] }"
      "node { name: 'F' op: 'Input'}"
      "node { name: 'G' op: 'LRNGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['E', 'F', 'B'] }"
      "node { name: 'H' op: 'Input'}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['H', 'G'] }");
  // Expected: _MklLRN/_MklMaxPool produce extra workspace + metadata
  // outputs (e.g. B:1->G:3, C:1->E:3) consumed by _MklMaxPoolGrad and
  // _MklLRNGrad; remaining Mkl slots are filled with dummies DMT/_0..2.
  EXPECT_EQ(
      DoMklLayoutOptimizationPass(),
      "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
      "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
      "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
      "B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;"
      "C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;"
      "E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I");
}
   2920 
/* Test LRN->LRNGrad replacement by workspace nodes. */
// Both the forward LRN (B) and its gradient LRNGrad (E) are present, so the
// pass rewrites both to _MklLRN/_MklLRNGrad and wires them together with
// workspace edges (B:1->E:3 plus the corresponding Mkl-metadata edges).
// DMT/_* Const nodes are the dummy Mkl tensors generated for non-Mkl inputs.
TEST_F(MklLayoutPassTest, LRN_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'LRN'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'LRNGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['C', 'D', 'B'] }"
      "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);E(_MklLRNGrad);F(Zeta)|"
            "A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;"
            "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
            "D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
}
   2952 
/* Test LRN->LRNGrad replacement when only the forward LRN is present. */
// B is still rewritten to _MklLRN (with a dummy Mkl tensor DMT/_0), but no
// workspace edge is created because there is no matching LRNGrad consumer.
TEST_F(MklLayoutPassTest, LRN_Negative1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'LRN'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklLRN);C(Zeta);DMT/_0(Const)|"
            "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
   2971 
/* Test LRN->LRNGrad replacement when only the gradient LRNGrad is present. */
// D is rewritten to _MklLRNGrad even without a forward _MklLRN producer;
// dummy Const tensors (DMT/_0..DMT/_4) are generated for the workspace and
// Mkl-metadata inputs it would otherwise receive from the forward op.
TEST_F(MklLayoutPassTest, LRN_Negative2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'LRNGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['A', 'B', 'C'] }"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
            "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
            "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
}
   2996 
/* Test LRN->LRNGrad negative case, where single LRN feeds
   2 LRNGrad nodes at different slots. */
// E consumes B at the expected slot (input 2) and gets real workspace edges;
// F consumes B at slot 1, so it instead receives dummy workspace tensors
// (DMT/_1..DMT/_4). All three nodes are still rewritten to Mkl variants.
TEST_F(MklLayoutPassTest, LRN_Negative3) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'LRN'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'LRNGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['C', 'D', 'B'] }"
      "node { name: 'F' op: 'LRNGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'alpha'        value { f: 0.001 } }"
      " attr { key: 'beta'         value { f: 0.75 } }"
      " attr { key: 'bias'         value { f: 1.0 } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'depth_radius' value { i: 2 } }"
      " input: ['C', 'B', 'D'] }"
      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['E', 'F'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
            "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Zeta)|A->B;"
            "A:control->DMT/_0:control;B->E:2;"
            "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;"
            "C:control->DMT/_1:control;C:control->DMT/_2:control;"
            "C:control->DMT/_3:control;C:control->DMT/_4:control;"
            "C:control->DMT/_5:control;C:control->DMT/_6:control;"
            "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
            "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
}
   3042 
/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
// MaxPool (B) and MaxPoolGrad (E) are both present, so both become Mkl ops
// connected via workspace edges (B:1->E:3); DMT/_* are dummy Mkl tensors.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'MaxPoolGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['C', 'B', 'D'] }"
      "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Zeta)|"
            "A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;"
            "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
            "D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
}
   3072 
// Test MaxPool->MaxPoolGrad replacement when only one of them is present.
// In this case, we will rewrite MaxPool node but workspace edges will not
// be present.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklMaxPool);C(Zeta);DMT/_0(Const)|"
            "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
   3092 
// Test MaxPoolGrad replacement when only one of them is present.
// In this case, we will rewrite MaxPoolGrad and for workspace tensor and
// its Mkl part, we will generate dummy tensor (DMT/_* Const nodes below).
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'MaxPoolGrad'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A', 'B', 'C'] }"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
            "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
            "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
}
   3118 
// Test MaxPool handling for batch-wise pooling (NCHW)
// ksize is 2 in the batch dimension (index 0 of NCHW), so
// no rewrite should take place in such case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 2, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   3136 
// Test MaxPool handling for batch-wise pooling (NCHW)
// stride is 2 in the batch dimension (index 0 of NCHW), so
// no rewrite should take place in such case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 2, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   3154 
// Test MaxPool handling for depth-wise pooling (NCHW)
// (Previous comment said NHWC, but the test uses data_format NCHW with
// ksize 2 in the channel dimension, index 1 of NCHW.)
// No rewrite should take place in such case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:2, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   3172 
// Test MaxPool handling for depth-wise pooling (NCHW)
// stride is 2 in the channel dimension (index 1 of NCHW), so
// no rewrite should take place in such case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:2, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   3190 
// Test MaxPool handling for batch-wise pooling (NHWC)
// ksize is 2 in the batch dimension (index 0 of NHWC), so
// no rewrite should take place in such case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 2, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   3208 
// Test MaxPool handling for batch-wise pooling (NHWC)
// stride is 2 in the batch dimension (index 0 of NHWC), so
// no rewrite should take place in such case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 2, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   3226 
// Test MaxPool handling for depth-wise pooling (NHWC)
// ksize is 2 in the channel dimension (index 3 of NHWC), so
// no rewrite should take place in such case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:2} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   3244 
// Test MaxPool handling for depth-wise pooling (NHWC)
// stride is 2 in the channel dimension (index 3 of NHWC), so
// no rewrite should take place in such case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   3262 
   3263 /////////////////////////////////////////////////////////////////////
   3264 
// Single Conv2D Op on GPU device
// No rewrite should happen: the expected graph keeps the plain Conv2D op.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'C'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1");
}
   3284 
// _MklConv2DWithBias + BiasAddGrad on GPU device: the BiasAddGrad (F) is
// left untouched in the expected graph (no merge/rewrite happens on GPU).
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'O' op: '_MklInput'}"
      "node { name: 'D' op: '_MklConv2DWithBias'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
      "node { name: 'E' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['D', 'A']}"
      "node { name: 'F' op: 'BiasAddGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['E'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
            "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
            "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;"
            "M->D:3;N->D:4;O->D:5");
}
   3314 
// Conv2DBackpropFilter on GPU device: no Mkl rewrite expected; the op name
// is unchanged in the expected graph.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Int32Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Conv2DBackpropFilter'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'C']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|"
            "A->D;A->E;B->D:1;C->D:2;D->E:1");
}
   3334 
// Relu on GPU device: no Mkl rewrite expected.
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Relu'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1");
}
   3347 
// ReluGrad on GPU device: no Mkl rewrite expected.
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'ReluGrad'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'C'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
}
   3361 
// MaxPool on GPU device: no Mkl rewrite expected, even though the pooling
// attrs would be rewritable on CPU.
TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
   3378 
// AvgPool on GPU device: no Mkl rewrite expected.
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'AvgPool'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NHWC' } }"
      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'      value { s: 'VALID' } }"
      " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1");
}
   3395 
// Concat Op test: Concat with no Mkl layer feeding it, placed on GPU device.
// The Concat op is expected to remain unrewritten.
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'InputList'"
      " attr { key: 'N'                value { i: 2 } }}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Concat'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['A', 'B:0', 'B:1']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;"
            "B->D:1;B:1->D:2;C->E;D->E:1");
}
   3418 
// ConcatV2 (axis as the last input) on GPU device: no Mkl rewrite expected.
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value { "
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'InputList'"
      " attr { key: 'N'                value { i: 2 } }}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'ConcatV2'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
      " attr { key: 'N'                value { i: 2 } }"
      " input: ['B:0', 'B:1', 'A']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
            "A->D:2;B->D;B:1->D:1;C->E;D->E:1");
}
   3441 
// FusedBatchNorm on GPU device: no Mkl rewrite expected.
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'FusedBatchNorm'"
      " attr { key: 'T'            value { type: DT_FLOAT } }"
      " attr { key: 'data_format'  value { s: 'NCHW' } }"
      " attr { key: 'epsilon'      value { f: 0.0001 } }"
      " attr { key: 'is_training'  value { b: true } }"
      " input: ['A', 'B', 'C', 'D', 'E'] }"
      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'F'] }",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(Input);E(Input);"
            "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
            "E->F:4;F->G:1");
}
   3463 
// _MklConv2D + BiasAdd on GPU device: the BiasAdd (E) is NOT merged into a
// Conv2DWithBias; the expected graph keeps both ops unchanged.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
  // This test layout assumes interleaved/contiguous ordering of Mkl tensors.
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding'          value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T'                value { type: DT_FLOAT } }"
      " attr { key: 'data_format'      value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T'                 value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}",
      kGPUDevice);
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
            "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;"
            "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
   3493 
   3494 /////////////////////////////////////////////////////////////////////
   3495 
   3496 static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
   3497   testing::StopTiming();
   3498   string s;
   3499   for (int in = 0; in < 10; in++) {
   3500     s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
   3501   }
   3502   random::PhiloxRandom philox(301, 17);
   3503   random::SimplePhilox rnd(&philox);
   3504   for (int op = 0; op < op_nodes; op++) {
   3505     s += strings::Printf(
   3506         "node { name: 'op%04d' op: 'Zeta' attr { key: 'T' value { "
   3507         "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
   3508         op, rnd.Uniform(10), rnd.Uniform(10));
   3509   }
   3510 
   3511   bool first = true;
   3512   while (iters > 0) {
   3513     Graph* graph = new Graph(OpRegistry::Global());
   3514     InitGraph(s, graph);
   3515     int N = graph->num_node_ids();
   3516     if (first) {
   3517       testing::SetLabel(strings::StrCat("Per graph node.  Nodes: ", N));
   3518       first = false;
   3519     }
   3520     {
   3521       testing::StartTiming();
   3522       std::unique_ptr<Graph> ug(graph);
   3523       RunMklLayoutRewritePass(&ug);
   3524       testing::StopTiming();
   3525     }
   3526     iters -= N;  // Our benchmark units are individual graph nodes,
   3527                  // not whole graphs
   3528     // delete graph;
   3529   }
   3530 }
   3531 BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
   3532 
   3533 }  // namespace
   3534 
   3535 #endif  // INTEL_MKL_ML
   3536 
   3537 }  // namespace tensorflow
   3538 
   3539 #endif /* INTEL_MKL */
   3540