/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"

#include <algorithm>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

#ifdef INTEL_MKL_ML

namespace {

const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0";
const char kGPUDevice[] = "/job:a/replica:0/task:0/device:GPU:0";

static void InitGraph(const string& s, Graph* graph,
                      const string& device = kCPUDevice) {
  GraphDef graph_def;

  auto parser = protobuf::TextFormat::Parser();
  // parser.AllowRelaxedWhitespace(true);
  CHECK(parser.MergeFromString(s, &graph_def)) << s;
  GraphConstructorOptions opts;
  TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));

  for (Node* node : graph->nodes()) {
    node->set_assigned_device_name(device);
  }
}

class MklLayoutPassTest : public ::testing::Test {
 public:
  MklLayoutPassTest() : graph_(OpRegistry::Global()) {}

  void InitGraph(const string& s, const string& device = kCPUDevice) {
    ::tensorflow::InitGraph(s, &graph_, device);
    original_ = CanonicalGraphString(&graph_);
  }

  static bool IncludeNode(const Node* n) { return n->IsOp(); }

  static string EdgeId(const Node* n, int index) {
    if (index == 0) {
      return n->name();
    } else if (index == Graph::kControlSlot) {
      return strings::StrCat(n->name(), ":control");
    } else {
      return strings::StrCat(n->name(), ":", index);
    }
  }

  string CanonicalGraphString(Graph* g) {
    std::vector<string> nodes;
    std::vector<string> edges;
    for (const Node* n : g->nodes()) {
      if (IncludeNode(n)) {
        nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
      }
    }
    for (const Edge* e : g->edges()) {
      if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
        edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
                                        EdgeId(e->dst(), e->dst_input())));
      }
    }
    // Canonicalize
    std::sort(nodes.begin(), nodes.end());
    std::sort(edges.begin(), edges.end());
    return strings::StrCat(str_util::Join(nodes, ";"), "|",
                           str_util::Join(edges, ";"));
  }
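  // Runs the MKL layout rewrite pass on graph_ and returns the result in the
  // canonical string form produced by CanonicalGraphString():
  //   "<name>(<op>);...|<src>[:<slot>]-><dst>[:<slot>];..."
  // Node entries and edge entries are each sorted, and control edges are shown
  // with the ":control" suffix. All EXPECT_EQ checks below compare against
  // this canonical string.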
  string DoMklLayoutOptimizationPass() {
    string before = CanonicalGraphString(&graph_);
    LOG(ERROR) << "Before MKL layout rewrite pass: " << before;

    std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
    RunMklLayoutRewritePass(ug);

    string result = CanonicalGraphString(&graph_);
    LOG(ERROR) << "After MKL layout rewrite pass: " << result;
    return result;
  }

  const string& OriginalGraph() const { return original_; }

  Graph graph_;
  string original_;
};

REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("_MklInput2")
    .Output("o: uint8")
    .Output("o1: uint8")
    .SetIsStateful();

/////////////////////////////////////////////////////////////////////
// Unit tests related to node merge optimization
/////////////////////////////////////////////////////////////////////

TEST_F(MklLayoutPassTest, Basic) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Zeta);D(Zeta)|"
            "A->C;A->D;B->C:1;B->D:1");
}

// Test set 1: Conv2D + AddBias

// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved ordering)
// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous ordering)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
            "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->E;"
            "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;"
            "N->E:4;Y->Z:1");
}

// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved)
// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous)
// Test for correct output slots selected
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput2'}"
      "node { name: 'N' op: '_MklInput2'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B', 'M:1', 'N:1']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
            "M(_MklInput2);N(_MklInput2);Y(Input);Z(Zeta)|A->E;"
            "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;"
            "M:1->E:3;N:1->E:4;Y->Z:1");
}

// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Zeta(E,Y);
// This is a case of node rewrite followed by node merge.
// We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D
// with BiasAdd to produce _MklConv2DWithBias.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
            "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Zeta)|"
            "A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;"
            "DMT/_2->E:5;E->Z;Y->Z:1");
}

// Graph contains only _MklConv2D, no AddBias.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: '_MklInput'}"
      "node { name: 'N' op: '_MklInput'}"
      "node { name: 'C' op: '_MklConv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B', 'M', 'N']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(_MklConv2D);M(_MklInput);N(_MklInput)|"
            "A->C;B->C:1;M->C:2;N->C:3");
}

// _MklConv2D output does not go to BiasAdd.
270 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) { 271 InitGraph( 272 "node { name: 'A' op: 'Input'}" 273 "node { name: 'B' op: 'Input'}" 274 "node { name: 'M' op: '_MklInput'}" 275 "node { name: 'N' op: '_MklInput'}" 276 "node { name: 'C' op: '_MklConv2D'" 277 " attr { key: 'T' value { type: DT_FLOAT } }" 278 " attr { key: 'data_format' value { s: 'NCHW' } }" 279 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 280 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 281 " attr { key: 'padding' value { s: 'SAME' } }" 282 " input: ['A', 'B', 'M', 'N']}" 283 "node { name: 'D' op: 'Input'}" 284 "node { name: 'E' op: 'Input'}" 285 "node { name: 'F' op: 'BiasAdd'" 286 " attr { key: 'T' value { type: DT_FLOAT } }" 287 " attr { key: 'data_format' value { s: 'NCHW' } }" 288 " input: ['D', 'E'] }"); // Output of _MklConv2D does not go to BiasAdd. 289 EXPECT_EQ(DoMklLayoutOptimizationPass(), 290 "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);" 291 "M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3"); 292 } 293 294 // _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Zeta). 295 // Merge should not be done in such case. 296 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) { 297 InitGraph( 298 "node { name: 'A' op: 'Input'}" 299 "node { name: 'B' op: 'Input'}" 300 "node { name: 'M' op: '_MklInput'}" 301 "node { name: 'N' op: '_MklInput'}" 302 "node { name: 'C' op: '_MklConv2D'" 303 " attr { key: 'T' value { type: DT_FLOAT } }" 304 " attr { key: 'data_format' value { s: 'NCHW' } }" 305 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 306 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 307 " attr { key: 'padding' value { s: 'SAME' } }" 308 " input: ['A', 'B', 'M', 'N']}" 309 "node { name: 'D' op: 'Input'}" 310 "node { name: 'E' op: 'Input'}" 311 "node { name: 'F' op: 'BiasAdd'" 312 " attr { key: 'T' value { type: DT_FLOAT } }" 313 " attr { key: 'data_format' value { s: 'NCHW' } }" 314 " input: ['D', 'E'] }" // Conv2D has two outputs. 315 // No merge should happen. 316 "node { name: 'G' op: 'Zeta'" 317 " attr { key: 'T' value { type: DT_FLOAT } }" 318 " input: ['C', 'E'] }"); 319 EXPECT_EQ(DoMklLayoutOptimizationPass(), 320 "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);" 321 "G(Zeta);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;" 322 "E->F:1;E->G:1;M->C:2;N->C:3"); 323 } 324 325 // data_format attribute value mismatch. Merge should not be done 326 // in such case. 
327 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) { 328 InitGraph( 329 "node { name: 'A' op: 'Input'}" 330 "node { name: 'B' op: 'Input'}" 331 "node { name: 'M' op: '_MklInput'}" 332 "node { name: 'N' op: '_MklInput'}" 333 "node { name: 'C' op: '_MklConv2D'" 334 " attr { key: 'T' value { type: DT_FLOAT } }" 335 " attr { key: 'data_format' value { s: 'NCHW' } }" 336 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 337 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 338 " attr { key: 'padding' value { s: 'SAME' } }" 339 " input: ['A', 'B', 'M', 'N']}" 340 "node { name: 'D' op: 'Input'}" 341 "node { name: 'E' op: 'BiasAdd'" 342 " attr { key: 'T' value { type: DT_FLOAT } }" 343 " attr { key: 'data_format' value { s: 'NHCW' } }" 344 " input: ['C', 'D'] }"); 345 EXPECT_EQ(DoMklLayoutOptimizationPass(), 346 "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);M(_MklInput);" 347 "N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3"); 348 } 349 350 // Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias 351 // rewrite tests 352 353 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter 354 // and BackpropInput 355 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) { 356 InitGraph( 357 "node { name: 'A' op: 'Input'}" 358 "node { name: 'B' op: 'Input'}" 359 "node { name: 'C' op: 'Input'}" 360 "node { name: 'M' op: '_MklInput'}" 361 "node { name: 'N' op: '_MklInput'}" 362 "node { name: 'O' op: '_MklInput'}" 363 "node { name: 'D' op: '_MklConv2DWithBias'" 364 " attr { key: 'T' value { type: DT_FLOAT } }" 365 " attr { key: 'data_format' value { s: 'NCHW' } }" 366 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 367 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 368 " attr { key: 'padding' value { s: 'SAME' } }" 369 " input: ['A', 'B', 'C', 'M', 'N', 'O']}" 370 "node { name: 'E' op: 'Zeta'" 371 " attr {key: 'T' value { type: DT_FLOAT } }" 372 " input: ['D', 'A']}" 373 "node { name: 'F' op: 'Int32Input'}" 374 "node { name: 'G' op: '_MklConv2DBackpropFilter'" 375 " attr { key: 'T' value { type: DT_FLOAT } }" 376 " attr { key: 'data_format' value { s: 'NCHW' } }" 377 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 378 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 379 " attr { key: 'padding' value { s: 'SAME' } }" 380 " input: ['A', 'F', 'E', 'M', 'N', 'O'] }" 381 "node { name: 'H' op: 'Int32Input'}" 382 "node { name: 'I' op: '_MklConv2DBackpropInput'" 383 " attr { key: 'T' value { type: DT_FLOAT } }" 384 " attr { key: 'data_format' value { s: 'NCHW' } }" 385 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 386 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 387 " attr { key: 'padding' value { s: 'SAME' } }" 388 " input: ['H', 'B', 'E', 'M', 'N', 'O']}" 389 "node { name: 'J' op: 'BiasAddGrad'" 390 " attr { key: 'T' value { type: DT_FLOAT } }" 391 " attr { key: 'data_format' value { s: 'NCHW' } }" 392 " input: ['E'] }"); 393 EXPECT_EQ(DoMklLayoutOptimizationPass(), 394 "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);" 395 "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" 396 "I(_MklConv2DBackpropInput);J(_MklConv2DWithBiasBackpropBias);" 397 "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G;B->D:1;" 398 "B->I:1;C->D:2;D->E;DMT/_0->J:1;E->G:2;E->I:2;E->J;" 399 "E:control->DMT/_0:control;F->G:1;H->I;M->D:3;M->G:3;M->I:3;" 400 "N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5"); 401 } 402 403 // BiasAddGrad rewrite 
to BackpropBias in the presence of BackpropFilter 404 // and BackpropInput. But nodes do not match criteria for rewrite. So 405 // rewrite should not happen. 406 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) { 407 InitGraph( 408 "node { name: 'A' op: 'Input'}" 409 "node { name: 'B' op: 'Input'}" 410 "node { name: 'C' op: 'Input'}" 411 "node { name: 'M' op: '_MklInput'}" 412 "node { name: 'N' op: '_MklInput'}" 413 "node { name: 'O' op: '_MklInput'}" 414 "node { name: 'D' op: '_MklConv2DWithBias'" 415 " attr { key: 'T' value { type: DT_FLOAT } }" 416 " attr { key: 'data_format' value { s: 'NCHW' } }" 417 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 418 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 419 " attr { key: 'padding' value { s: 'SAME' } }" 420 " input: ['A', 'B', 'C', 'M', 'N', 'O']}" 421 "node { name: 'E' op: 'Zeta'" 422 " attr {key: 'T' value { type: DT_FLOAT } }" 423 " input: ['D', 'A']}" 424 "node { name: 'F' op: 'Int32Input'}" 425 "node { name: 'G' op: '_MklConv2DBackpropFilter'" 426 " attr { key: 'T' value { type: DT_FLOAT } }" 427 " attr { key: 'data_format' value { s: 'NCHW' } }" 428 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 429 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 430 " attr { key: 'padding' value { s: 'SAME' } }" 431 " input: ['E', 'F', 'A', 'M', 'N', 'O'] }" 432 "node { name: 'H' op: 'Int32Input'}" 433 "node { name: 'I' op: '_MklConv2DBackpropInput'" 434 " attr { key: 'T' value { type: DT_FLOAT } }" 435 " attr { key: 'data_format' value { s: 'NCHW' } }" 436 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 437 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 438 " attr { key: 'padding' value { s: 'SAME' } }" 439 " input: ['H', 'B', 'E', 'M', 'N', 'O']}" 440 "node { name: 'J' op: 'BiasAddGrad'" 441 " attr { key: 'T' value { type: DT_FLOAT } }" 442 " attr { key: 'data_format' value { s: 'NCHW' } }" 443 " input: ['E'] }"); 444 EXPECT_EQ(DoMklLayoutOptimizationPass(), 445 "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" 446 "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" 447 "I(_MklConv2DBackpropInput);J(BiasAddGrad);" 448 "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;" 449 "B->I:1;C->D:2;D->E;E->G;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;" 450 "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5"); 451 } 452 453 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter 454 // and BackpropInput. But nodes do not match criteria for rewrite. So 455 // rewrite should not happen. 
456 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) { 457 InitGraph( 458 "node { name: 'A' op: 'Input'}" 459 "node { name: 'B' op: 'Input'}" 460 "node { name: 'C' op: 'Input'}" 461 "node { name: 'M' op: '_MklInput'}" 462 "node { name: 'N' op: '_MklInput'}" 463 "node { name: 'O' op: '_MklInput'}" 464 "node { name: 'D' op: '_MklConv2DWithBias'" 465 " attr { key: 'T' value { type: DT_FLOAT } }" 466 " attr { key: 'data_format' value { s: 'NCHW' } }" 467 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 468 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 469 " attr { key: 'padding' value { s: 'SAME' } }" 470 " input: ['B', 'A', 'C', 'M', 'N', 'O']}" 471 "node { name: 'E' op: 'Zeta'" 472 " attr {key: 'T' value { type: DT_FLOAT } }" 473 " input: ['D', 'A']}" 474 "node { name: 'F' op: 'Int32Input'}" 475 "node { name: 'G' op: '_MklConv2DBackpropFilter'" 476 " attr { key: 'T' value { type: DT_FLOAT } }" 477 " attr { key: 'data_format' value { s: 'NCHW' } }" 478 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 479 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 480 " attr { key: 'padding' value { s: 'SAME' } }" 481 " input: ['A', 'F', 'E', 'M', 'N', 'O'] }" 482 "node { name: 'H' op: 'Int32Input'}" 483 "node { name: 'I' op: '_MklConv2DBackpropInput'" 484 " attr { key: 'T' value { type: DT_FLOAT } }" 485 " attr { key: 'data_format' value { s: 'NCHW' } }" 486 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 487 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 488 " attr { key: 'padding' value { s: 'SAME' } }" 489 " input: ['H', 'B', 'E', 'M', 'N', 'O']}" 490 "node { name: 'J' op: 'BiasAddGrad'" 491 " attr { key: 'T' value { type: DT_FLOAT } }" 492 " attr { key: 'data_format' value { s: 'NCHW' } }" 493 " input: ['E'] }"); 494 EXPECT_EQ(DoMklLayoutOptimizationPass(), 495 "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" 496 "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" 497 "I(_MklConv2DBackpropInput);J(BiasAddGrad);" 498 "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;" 499 "B->I:1;C->D:2;D->E;E->G:2;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;" 500 "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5"); 501 } 502 503 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only 504 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) { 505 InitGraph( 506 "node { name: 'A' op: 'Input'}" 507 "node { name: 'B' op: 'Input'}" 508 "node { name: 'C' op: 'Input'}" 509 "node { name: 'M' op: '_MklInput'}" 510 "node { name: 'N' op: '_MklInput'}" 511 "node { name: 'O' op: '_MklInput'}" 512 "node { name: 'D' op: '_MklConv2DWithBias'" 513 " attr { key: 'T' value { type: DT_FLOAT } }" 514 " attr { key: 'data_format' value { s: 'NCHW' } }" 515 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 516 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 517 " attr { key: 'padding' value { s: 'SAME' } }" 518 " input: ['A', 'B', 'C', 'M', 'N', 'O']}" 519 "node { name: 'E' op: 'Zeta'" 520 " attr {key: 'T' value { type: DT_FLOAT } }" 521 " input: ['D', 'A']}" 522 "node { name: 'F' op: 'Int32Input'}" 523 "node { name: 'G' op: '_MklConv2DBackpropFilter'" 524 " attr { key: 'T' value { type: DT_FLOAT } }" 525 " attr { key: 'data_format' value { s: 'NCHW' } }" 526 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 527 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 528 " attr { key: 'padding' value { s: 'SAME' } }" 529 " input: ['A', 'F', 'E', 'M', 
'N', 'O'] }" 530 "node { name: 'H' op: 'BiasAddGrad'" 531 " attr { key: 'T' value { type: DT_FLOAT } }" 532 " attr { key: 'data_format' value { s: 'NCHW' } }" 533 " input: ['E'] }"); 534 EXPECT_EQ(DoMklLayoutOptimizationPass(), 535 "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);" 536 "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);" 537 "H(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);" 538 "O(_MklInput)|A->D;A->E:1;A->G;B->D:1;C->D:2;D->E;DMT/_0->H:1;" 539 "E->G:2;E->H;E:control->DMT/_0:control;F->G:1;M->D:3;M->G:3;" 540 "N->D:4;N->G:4;O->D:5;O->G:5"); 541 } 542 543 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only 544 // But BackpropFilter node inputs do not satisfy criteria for rewrite. 545 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) { 546 InitGraph( 547 "node { name: 'A' op: 'Input'}" 548 "node { name: 'B' op: 'Input'}" 549 "node { name: 'C' op: 'Input'}" 550 "node { name: 'M' op: '_MklInput'}" 551 "node { name: 'N' op: '_MklInput'}" 552 "node { name: 'O' op: '_MklInput'}" 553 "node { name: 'D' op: '_MklConv2DWithBias'" 554 " attr { key: 'T' value { type: DT_FLOAT } }" 555 " attr { key: 'data_format' value { s: 'NCHW' } }" 556 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 557 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 558 " attr { key: 'padding' value { s: 'SAME' } }" 559 " input: ['A', 'B', 'C', 'M', 'N', 'O']}" 560 "node { name: 'E' op: 'Zeta'" 561 " attr {key: 'T' value { type: DT_FLOAT } }" 562 " input: ['D', 'A']}" 563 "node { name: 'F' op: 'Int32Input'}" 564 "node { name: 'G' op: '_MklConv2DBackpropFilter'" 565 " attr { key: 'T' value { type: DT_FLOAT } }" 566 " attr { key: 'data_format' value { s: 'NCHW' } }" 567 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 568 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 569 " attr { key: 'padding' value { s: 'SAME' } }" 570 " input: ['E', 'F', 'A', 'M', 'N', 'O'] }" 571 "node { name: 'H' op: 'BiasAddGrad'" 572 " attr { key: 'T' value { type: DT_FLOAT } }" 573 " attr { key: 'data_format' value { s: 'NCHW' } }" 574 " input: ['E'] }"); 575 EXPECT_EQ(DoMklLayoutOptimizationPass(), 576 "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" 577 "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" 578 "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;" 579 "C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;" 580 "O->G:5"); 581 } 582 583 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only 584 // But BackpropFilter node inputs do not satisfy criteria for rewrite. 
585 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) { 586 InitGraph( 587 "node { name: 'A' op: 'Input'}" 588 "node { name: 'B' op: 'Input'}" 589 "node { name: 'C' op: 'Input'}" 590 "node { name: 'M' op: '_MklInput'}" 591 "node { name: 'N' op: '_MklInput'}" 592 "node { name: 'O' op: '_MklInput'}" 593 "node { name: 'D' op: '_MklConv2DWithBias'" 594 " attr { key: 'T' value { type: DT_FLOAT } }" 595 " attr { key: 'data_format' value { s: 'NCHW' } }" 596 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 597 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 598 " attr { key: 'padding' value { s: 'SAME' } }" 599 " input: ['B', 'A', 'C', 'M', 'N', 'O']}" 600 "node { name: 'E' op: 'Zeta'" 601 " attr {key: 'T' value { type: DT_FLOAT } }" 602 " input: ['D', 'A']}" 603 "node { name: 'F' op: 'Int32Input'}" 604 "node { name: 'G' op: '_MklConv2DBackpropFilter'" 605 " attr { key: 'T' value { type: DT_FLOAT } }" 606 " attr { key: 'data_format' value { s: 'NCHW' } }" 607 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 608 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 609 " attr { key: 'padding' value { s: 'SAME' } }" 610 " input: ['A', 'F', 'E', 'M', 'N', 'O'] }" 611 "node { name: 'H' op: 'BiasAddGrad'" 612 " attr { key: 'T' value { type: DT_FLOAT } }" 613 " attr { key: 'data_format' value { s: 'NCHW' } }" 614 " input: ['E'] }"); 615 EXPECT_EQ(DoMklLayoutOptimizationPass(), 616 "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" 617 "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" 618 "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;" 619 "C->D:2;D->E;E->G:2;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;" 620 "O->G:5"); 621 } 622 623 // No _MklConv2DWithBias in context, but _MklConv2D in context. 624 // No rewrite for BiasAddGrad should happen. 625 // C=_MklConv2D(A,M,B,N); D=Zeta(C,A); E=BiasAddGrad(D) (for interleaved) 626 // C=_MklConv2D(A,B,M,N); D=Zeta(C,A); E=BiasAddGrad(D) (for contiguous) 627 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) { 628 InitGraph( 629 "node { name: 'A' op: 'Input'}" 630 "node { name: 'B' op: 'Input'}" 631 "node { name: 'M' op: '_MklInput'}" 632 "node { name: 'N' op: '_MklInput'}" 633 "node { name: 'C' op: '_MklConv2D'" 634 " attr { key: 'T' value { type: DT_FLOAT } }" 635 " attr { key: 'data_format' value { s: 'NCHW' } }" 636 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 637 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 638 " attr { key: 'padding' value { s: 'SAME' } }" 639 " input: ['A', 'B', 'M', 'N']}" 640 "node { name: 'D' op: 'Zeta'" 641 " attr {key: 'T' value { type: DT_FLOAT } }" 642 " input: ['C', 'A']}" 643 "node { name: 'E' op: 'BiasAddGrad'" 644 " attr { key: 'T' value { type: DT_FLOAT } }" 645 " attr { key: 'data_format' value { s: 'NCHW' } }" 646 " input: ['D'] }"); 647 EXPECT_EQ(DoMklLayoutOptimizationPass(), 648 "A(Input);B(Input);C(_MklConv2D);D(Zeta);E(BiasAddGrad);" 649 "M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;" 650 "M->C:2;N->C:3"); 651 } 652 653 // No Conv2D in the context for BiasAddGrad. No rewrite should happen. 
654 // C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D) 655 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D) { 656 InitGraph( 657 "node { name: 'A' op: 'Input'}" 658 "node { name: 'B' op: 'Input'}" 659 "node { name: 'C' op: 'Polygamma'" 660 " attr { key: 'T' value { type: DT_FLOAT } }" 661 " input: ['A', 'B']}" 662 "node { name: 'D' op: 'Zeta'" 663 " attr {key: 'T' value { type: DT_FLOAT } }" 664 " input: ['C', 'A']}" 665 "node { name: 'E' op: 'BiasAddGrad'" 666 " attr { key: 'T' value { type: DT_FLOAT } }" 667 " attr { key: 'data_format' value { s: 'NCHW' } }" 668 " input: ['D'] }"); 669 EXPECT_EQ(DoMklLayoutOptimizationPass(), 670 "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|" 671 "A->C;A->D:1;B->C:1;C->D;D->E"); 672 } 673 674 // No Conv2D in the context for BiasAddGrad, but MatMul in context. 675 // Rewrite should happen, but name of BiasAddGrad does not change. 676 // C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D) 677 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) { 678 InitGraph( 679 "node { name: 'A' op: 'Input'}" 680 "node { name: 'B' op: 'Input'}" 681 "node { name: 'C' op: 'MatMul'" 682 " attr { key: 'T' value { type: DT_FLOAT } }" 683 " attr { key: 'transpose_a' value { b: false } }" 684 " attr { key: 'transpose_b' value { b: false } }" 685 " input: ['A', 'B']}" 686 "node { name: 'D' op: 'Zeta'" 687 " attr {key: 'T' value { type: DT_FLOAT } }" 688 " input: ['C', 'A']}" 689 "node { name: 'E' op: 'BiasAddGrad'" 690 " attr { key: 'T' value { type: DT_FLOAT } }" 691 " attr { key: 'data_format' value { s: 'NCHW' } }" 692 " input: ['D'] }"); 693 EXPECT_EQ(DoMklLayoutOptimizationPass(), 694 "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|" 695 "A->C;A->D:1;B->C:1;C->D;D->E"); 696 } 697 698 // Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests 699 // C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D) 700 TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) { 701 InitGraph( 702 "node { name: 'A' op: 'Input'}" 703 "node { name: 'B' op: 'Input'}" 704 "node { name: 'C' op: 'MatMul'" 705 " attr { key: 'T' value { type: DT_FLOAT } }" 706 " attr { key: 'transpose_a' value { b: false } }" 707 " attr { key: 'transpose_b' value { b: false } }" 708 " input: ['A', 'B']}" 709 "node { name: 'D' op: 'Zeta'" 710 " attr {key: 'T' value { type: DT_FLOAT } }" 711 " input: ['C', 'A']}" 712 "node { name: 'E' op: 'BiasAddGrad'" 713 " attr { key: 'T' value { type: DT_FLOAT } }" 714 " attr { key: 'data_format' value { s: 'NCHW' } }" 715 " input: ['D'] }"); 716 EXPECT_EQ(DoMklLayoutOptimizationPass(), 717 "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|" 718 "A->C;A->D:1;B->C:1;C->D;D->E"); 719 } 720 721 // No MatMul in the context for BiasAddGrad. No rewrite should happen. 
722 // C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D) 723 TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Negative_NoMatMul) { 724 InitGraph( 725 "node { name: 'A' op: 'Input'}" 726 "node { name: 'B' op: 'Input'}" 727 "node { name: 'C' op: 'Polygamma'" 728 " attr { key: 'T' value { type: DT_FLOAT } }" 729 " input: ['A', 'B']}" 730 "node { name: 'D' op: 'Zeta'" 731 " attr {key: 'T' value { type: DT_FLOAT } }" 732 " input: ['C', 'A']}" 733 "node { name: 'E' op: 'BiasAddGrad'" 734 " attr { key: 'T' value { type: DT_FLOAT } }" 735 " attr { key: 'data_format' value { s: 'NCHW' } }" 736 " input: ['D'] }"); 737 EXPECT_EQ(DoMklLayoutOptimizationPass(), 738 "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|" 739 "A->C;A->D:1;B->C:1;C->D;D->E"); 740 } 741 742 ///////////////////////////////////////////////////////////////////// 743 // Unit tests related to rewriting node to Mkl node 744 ///////////////////////////////////////////////////////////////////// 745 746 // Single Conv2D Op; No Mkl layer on the input and on the output. 747 // We will generate dummy Mkl tensor as 2nd input of Conv2D. 748 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) { 749 InitGraph( 750 "node { name: 'A' op: 'Input'}" 751 "node { name: 'B' op: 'Input'}" 752 "node { name: 'C' op: 'Conv2D'" 753 " attr { key: 'T' value { type: DT_FLOAT } }" 754 " attr { key: 'data_format' value { s: 'NCHW' } }" 755 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 756 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 757 " attr { key: 'padding' value { s: 'SAME' } }" 758 " input: ['A', 'B']}" 759 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 760 " input: ['B', 'C'] }"); 761 EXPECT_EQ(DoMklLayoutOptimizationPass(), 762 "A(Input);B(Input);C(_MklConv2D);D(Zeta);DMT/_0(Const);" 763 "DMT/_1(Const)|A->C;A:control->DMT/_0:control;" 764 "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" 765 "DMT/_1->C:3"); 766 } 767 768 // 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will 769 // have 2 outputs, both of which will be inputs to next Conv2D. 
770 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) { 771 InitGraph( 772 "node { name: 'A' op: 'Input'}" 773 "node { name: 'B' op: 'Input'}" 774 "node { name: 'C' op: 'Conv2D'" 775 " attr { key: 'T' value { type: DT_FLOAT } }" 776 " attr { key: 'data_format' value { s: 'NCHW' } }" 777 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 778 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 779 " attr { key: 'padding' value { s: 'SAME' } }" 780 " input: ['A', 'B']}" 781 "node { name: 'D' op: 'Conv2D'" 782 " attr { key: 'T' value { type: DT_FLOAT } }" 783 " attr { key: 'data_format' value { s: 'NCHW' } }" 784 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 785 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 786 " attr { key: 'padding' value { s: 'SAME' } }" 787 " input: ['A', 'C']}" 788 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 789 " input: ['C', 'D'] }"); 790 EXPECT_EQ(DoMklLayoutOptimizationPass(), 791 "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);" 792 "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->C;A->D;" 793 "A:control->DMT/_0:control;A:control->DMT/_1:control;" 794 "A:control->DMT/_2:control;B->C:1;C->D:1;C->E;" 795 "C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2"); 796 } 797 798 // Conv2D with INT32 which is not supported by Mkl 799 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) { 800 InitGraph( 801 "node { name: 'A' op: 'HalfInput'}" 802 "node { name: 'B' op: 'HalfInput'}" 803 "node { name: 'C' op: 'Conv2D'" 804 " attr { key: 'T' value { type: DT_HALF } }" 805 " attr { key: 'data_format' value { s: 'NCHW' } }" 806 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 807 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 808 " attr { key: 'padding' value { s: 'SAME' } }" 809 " input: ['A', 'B']}" 810 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }" 811 " input: ['B', 'C'] }"); 812 EXPECT_EQ(DoMklLayoutOptimizationPass(), 813 "A(HalfInput);B(HalfInput);C(Conv2D);D(Zeta)|" 814 "A->C;B->C:1;B->D;C->D:1"); 815 } 816 817 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) { 818 InitGraph( 819 "node { name: 'A' op: 'Input'}" 820 "node { name: 'B' op: 'Int32Input'}" 821 "node { name: 'C' op: 'Input'}" 822 "node { name: 'D' op: 'Conv2DBackpropFilter'" 823 " attr { key: 'T' value { type: DT_FLOAT } }" 824 " attr { key: 'data_format' value { s: 'NCHW' } }" 825 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 826 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 827 " attr { key: 'padding' value { s: 'SAME' } }" 828 " input: ['A', 'B', 'C']}" 829 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 830 " input: ['A', 'D'] }"); 831 EXPECT_EQ(DoMklLayoutOptimizationPass(), 832 "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);" 833 "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" 834 "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" 835 "A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" 836 "DMT/_1->D:4;DMT/_2->D:5"); 837 } 838 839 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) { 840 InitGraph( 841 "node { name: 'A' op: 'Input'}" 842 "node { name: 'B' op: 'Int32Input'}" 843 "node { name: 'C' op: 'Input'}" 844 "node { name: 'D' op: 'Conv2DBackpropInput'" 845 " attr { key: 'T' value { type: DT_FLOAT } }" 846 " attr { key: 'data_format' value { s: 'NCHW' } }" 847 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 848 " attr { key: 
'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 849 " attr { key: 'padding' value { s: 'SAME' } }" 850 " input: ['B', 'A', 'C']}" 851 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 852 " input: ['A', 'D'] }"); 853 EXPECT_EQ(DoMklLayoutOptimizationPass(), 854 "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);" 855 "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" 856 "A->D:1;A->E;B->D;B:control->DMT/_0:control;" 857 "B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;" 858 "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); 859 } 860 861 // Concat Op test: Concat with no Mkl layer feeding it 862 TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { 863 InitGraph( 864 "node { name: 'A' op: 'Const' " 865 " attr { key: 'dtype' value { type: DT_INT32 } }" 866 " attr { key: 'value' value { " 867 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 868 " int_val: 0 } } } }" 869 "node { name: 'B' op: 'InputList'" 870 " attr { key: 'N' value { i: 2 } }}" 871 "node { name: 'C' op: 'Input'}" 872 "node { name: 'D' op: 'Concat'" 873 " attr { key: 'T' value { type: DT_FLOAT } }" 874 " attr { key: 'N' value { i: 2 } }" 875 " input: ['A', 'B:0', 'B:1']}" 876 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 877 " input: ['C', 'D'] }"); 878 EXPECT_EQ( 879 DoMklLayoutOptimizationPass(), 880 "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" 881 "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" 882 "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" 883 "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); 884 } 885 886 // Concat with 2 Mkl layers feeding it 887 TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) { 888 InitGraph( 889 "node { name: 'A' op: 'Input'}" 890 "node { name: 'B' op: 'Input'}" 891 "node { name: 'C' op: 'Input'}" 892 "node { name: 'D' op: 'Input'}" 893 "node { name: 'E' op: 'Conv2D'" 894 " attr { key: 'T' value { type: DT_FLOAT } }" 895 " attr { key: 'data_format' value { s: 'NCHW' } }" 896 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 897 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 898 " attr { key: 'padding' value { s: 'SAME' } }" 899 " input: ['A', 'B']}" 900 "node { name: 'F' op: 'Conv2D'" 901 " attr { key: 'T' value { type: DT_FLOAT } }" 902 " attr { key: 'data_format' value { s: 'NCHW' } }" 903 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 904 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 905 " attr { key: 'padding' value { s: 'SAME' } }" 906 " input: ['C', 'D']}" 907 "node { name: 'G' op: 'Const' " 908 " attr { key: 'dtype' value { type: DT_INT32 } }" 909 " attr { key: 'value' value { " 910 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 911 " int_val: 0 } } } }" 912 "node { name: 'H' op: 'Concat'" 913 " attr { key: 'T' value { type: DT_FLOAT } }" 914 " attr { key: 'N' value { i: 2 } }" 915 " input: ['G', 'E', 'F']}" 916 "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 917 " input: ['A', 'H'] }"); 918 EXPECT_EQ(DoMklLayoutOptimizationPass(), 919 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 920 "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" 921 "F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;" 922 "A:control->DMT/_2:control;A:control->DMT/_3:control;" 923 "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;" 924 "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" 925 
"DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;" 926 "G:control->DMT/_4:control;H->I:1"); 927 } 928 929 // Concat with 1 Mkl and 1 non-Mkl layer feeding it 930 TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) { 931 InitGraph( 932 "node { name: 'A' op: 'Input'}" 933 "node { name: 'B' op: 'Input'}" 934 "node { name: 'C' op: 'Input'}" 935 "node { name: 'D' op: 'Input'}" 936 "node { name: 'E' op: 'Conv2D'" 937 " attr { key: 'T' value { type: DT_FLOAT } }" 938 " attr { key: 'data_format' value { s: 'NCHW' } }" 939 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 940 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 941 " attr { key: 'padding' value { s: 'SAME' } }" 942 " input: ['A', 'B']}" 943 "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 944 " input: ['C', 'D']}" 945 "node { name: 'G' op: 'Const' " 946 " attr { key: 'dtype' value { type: DT_INT32 } }" 947 " attr { key: 'value' value { " 948 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 949 " int_val: 0 } } } }" 950 "node { name: 'H' op: 'Concat'" 951 " attr { key: 'T' value { type: DT_FLOAT } }" 952 " attr { key: 'N' value { i: 2 } }" 953 " input: ['G', 'E', 'F']}" 954 "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 955 " input: ['A', 'H'] }"); 956 EXPECT_EQ(DoMklLayoutOptimizationPass(), 957 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 958 "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);" 959 "H(_MklConcat);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;" 960 "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" 961 "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;" 962 "G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1"); 963 } 964 965 // ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it 966 TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) { 967 InitGraph( 968 "node { name: 'A' op: 'Const' " 969 " attr { key: 'dtype' value { type: DT_INT32 } }" 970 " attr { key: 'value' value { " 971 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 972 " int_val: 0 } } } }" 973 "node { name: 'B' op: 'InputList'" 974 " attr { key: 'N' value { i: 2 } }}" 975 "node { name: 'C' op: 'Input'}" 976 "node { name: 'D' op: 'ConcatV2'" 977 " attr { key: 'T' value { type: DT_FLOAT } }" 978 " attr { key: 'Tidx' value { type: DT_INT32 } }" 979 " attr { key: 'N' value { i: 2 } }" 980 " input: ['B:0', 'B:1', 'A']}" 981 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 982 " input: ['C', 'D'] }"); 983 EXPECT_EQ(DoMklLayoutOptimizationPass(), 984 "A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);" 985 "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D:2;B->D;B:1->D:1;" 986 "B:control->DMT/_0:control;B:control->DMT/_1:control;" 987 "B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;" 988 "DMT/_1->D:4;DMT/_2->D:5"); 989 } 990 991 // ConcatV2 with 2 Mkl layers feeding it 992 TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) { 993 InitGraph( 994 "node { name: 'A' op: 'Input'}" 995 "node { name: 'B' op: 'Input'}" 996 "node { name: 'C' op: 'Input'}" 997 "node { name: 'D' op: 'Input'}" 998 "node { name: 'E' op: 'Conv2D'" 999 " attr { key: 'T' value { type: DT_FLOAT } }" 1000 " attr { key: 'data_format' value { s: 'NCHW' } }" 1001 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 1002 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1003 " attr { key: 'padding' value { s: 'SAME' } }" 1004 " input: ['A', 'B']}" 1005 "node { name: 'F' op: 
'Conv2D'" 1006 " attr { key: 'T' value { type: DT_FLOAT } }" 1007 " attr { key: 'data_format' value { s: 'NCHW' } }" 1008 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 1009 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1010 " attr { key: 'padding' value { s: 'SAME' } }" 1011 " input: ['C', 'D']}" 1012 "node { name: 'G' op: 'Const' " 1013 " attr { key: 'dtype' value { type: DT_INT32 } }" 1014 " attr { key: 'value' value { " 1015 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 1016 " int_val: 0 } } } }" 1017 "node { name: 'H' op: 'ConcatV2'" 1018 " attr { key: 'T' value { type: DT_FLOAT } }" 1019 " attr { key: 'Tidx' value { type: DT_INT32 } }" 1020 " attr { key: 'N' value { i: 2 } }" 1021 " input: ['E', 'F', 'G']}" 1022 "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1023 " input: ['A', 'H'] }"); 1024 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1025 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 1026 "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" 1027 "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;" 1028 "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;" 1029 "C:control->DMT/_0:control;C:control->DMT/_1:control;" 1030 "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" 1031 "DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;" 1032 "F:2->H:4;G->H:2;H->I:1"); 1033 } 1034 1035 // ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it 1036 TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) { 1037 InitGraph( 1038 "node { name: 'A' op: 'Input'}" 1039 "node { name: 'B' op: 'Input'}" 1040 "node { name: 'C' op: 'Input'}" 1041 "node { name: 'D' op: 'Input'}" 1042 "node { name: 'E' op: 'Conv2D'" 1043 " attr { key: 'T' value { type: DT_FLOAT } }" 1044 " attr { key: 'data_format' value { s: 'NCHW' } }" 1045 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 1046 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1047 " attr { key: 'padding' value { s: 'SAME' } }" 1048 " input: ['A', 'B']}" 1049 "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1050 " input: ['C', 'D']}" 1051 "node { name: 'G' op: 'Const' " 1052 " attr { key: 'dtype' value { type: DT_INT32 } }" 1053 " attr { key: 'value' value { " 1054 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 1055 " int_val: 0 } } } }" 1056 "node { name: 'H' op: 'ConcatV2'" 1057 " attr { key: 'T' value { type: DT_FLOAT } }" 1058 " attr { key: 'Tidx' value { type: DT_INT32 } }" 1059 " attr { key: 'N' value { i: 2 } }" 1060 " input: ['E', 'F', 'G']}" 1061 "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1062 " input: ['A', 'H'] }"); 1063 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1064 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 1065 "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);" 1066 "H(_MklConcatV2);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;" 1067 "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" 1068 "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;" 1069 "E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;" 1070 "G->H:2;H->I:1"); 1071 } 1072 1073 TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) { 1074 InitGraph( 1075 "node { name: 'A' op: 'Input'}" 1076 "node { name: 'B' op: 'Relu'" 1077 " attr { key: 'T' value { type: DT_FLOAT } }" 1078 " input: ['A'] }" 1079 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1080 " input: ['A', 'B'] }"); 1081 
EXPECT_EQ(DoMklLayoutOptimizationPass(), 1082 "A(Input);B(_MklRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;" 1083 "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); 1084 } 1085 1086 TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) { 1087 InitGraph( 1088 "node { name: 'A' op: 'Input'}" 1089 "node { name: 'B' op: 'Input'}" 1090 "node { name: 'C' op: 'ReluGrad'" 1091 " attr { key: 'T' value { type: DT_FLOAT } }" 1092 " input: ['A', 'B'] }" 1093 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1094 " input: ['A', 'C'] }"); 1095 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1096 "A(Input);B(Input);C(_MklReluGrad);D(Zeta);DMT/_0(Const);" 1097 "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" 1098 "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); 1099 } 1100 1101 TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) { 1102 InitGraph( 1103 "node { name: 'A' op: 'Input'}" 1104 "node { name: 'B' op: 'Relu'" 1105 " attr { key: 'T' value { type: DT_FLOAT } }" 1106 " input: ['A'] }" 1107 "node { name: 'C' op: 'ReluGrad'" 1108 " attr { key: 'T' value { type: DT_FLOAT } }" 1109 " input: ['A', 'B'] }" 1110 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1111 " input: ['A', 'C'] }"); 1112 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1113 "A(Input);B(_MklRelu);C(_MklReluGrad);D(Zeta);DMT/_0(Const);" 1114 "DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;" 1115 "A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;" 1116 "DMT/_1->C:2"); 1117 } 1118 1119 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) { 1120 InitGraph( 1121 "node { name: 'A' op: 'Input'}" 1122 "node { name: 'B' op: 'AvgPool'" 1123 " attr { key: 'T' value { type: DT_FLOAT } }" 1124 " attr { key: 'data_format' value { s: 'NCHW' } }" 1125 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 1126 " attr { key: 'padding' value { s: 'VALID' } }" 1127 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 1128 " input: ['A'] }" 1129 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1130 " input: ['A', 'B'] }"); 1131 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1132 "A(Input);B(_MklAvgPool);C(Zeta);DMT/_0(Const)|A->B;A->C;" 1133 "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); 1134 } 1135 1136 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) { 1137 InitGraph( 1138 "node { name: 'A' op: 'Int32Input'}" 1139 "node { name: 'B' op: 'Input'}" 1140 "node { name: 'C' op: 'AvgPoolGrad' " 1141 " attr { key: 'T' value { type: DT_FLOAT } }" 1142 " attr { key: 'data_format' value { s: 'NCHW' } }" 1143 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 1144 " attr { key: 'padding' value { s: 'VALID' } }" 1145 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 1146 " input: ['A', 'B'] }" 1147 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1148 " input: ['B', 'C'] }"); 1149 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1150 "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);" 1151 "DMT/_1(Const)|A->C;A:control->DMT/_0:control;" 1152 "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" 1153 "DMT/_1->C:3"); 1154 } 1155 1156 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) { 1157 InitGraph( 1158 "node { name: 'A' op: 'Input'}" 1159 "node { name: 'I' op: 'Int32Input'}" 1160 "node { name: 'B' op: 'AvgPool'" 1161 " attr { key: 'T' value { type: DT_FLOAT } }" 1162 " attr { key: 'data_format' value { s: 'NCHW' } }" 1163 " attr { key: 'ksize' value { 
list: {i: 1, i:1, i:3, i:3} } }" 1164 " attr { key: 'padding' value { s: 'VALID' } }" 1165 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 1166 " input: ['A'] }" 1167 "node { name: 'C' op: 'AvgPoolGrad' " 1168 " attr { key: 'T' value { type: DT_FLOAT } }" 1169 " attr { key: 'data_format' value { s: 'NCHW' } }" 1170 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 1171 " attr { key: 'padding' value { s: 'VALID' } }" 1172 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 1173 " input: ['I', 'B'] }" 1174 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1175 " input: ['A', 'C'] }"); 1176 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1177 "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);" 1178 "DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;" 1179 "B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;" 1180 "I:control->DMT/_1:control"); 1181 } 1182 1183 TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) { 1184 InitGraph( 1185 "node { name: 'A' op: 'Input'}" 1186 "node { name: 'B' op: 'Input'}" 1187 "node { name: 'C' op: 'Input'}" 1188 "node { name: 'D' op: 'Input'}" 1189 "node { name: 'E' op: 'Input'}" 1190 "node { name: 'F' op: 'FusedBatchNormGrad'" 1191 " attr { key: 'T' value { type: DT_FLOAT } }" 1192 " attr { key: 'data_format' value { s: 'NCHW' } }" 1193 " attr { key: 'epsilon' value { f: 0.0001 } }" 1194 " attr { key: 'is_training' value { b: true } }" 1195 " input: ['A', 'B', 'C', 'D', 'E'] }" 1196 "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1197 " input: ['A', 'F'] }"); 1198 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1199 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 1200 "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);" 1201 "F(_MklFusedBatchNormGrad);G(Zeta)|A->F;A->G;" 1202 "A:control->DMT/_0:control;A:control->DMT/_1:control;" 1203 "A:control->DMT/_2:control;A:control->DMT/_3:control;" 1204 "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" 1205 "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" 1206 "E->F:4;F->G:1"); 1207 } 1208 1209 TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) { 1210 InitGraph( 1211 "node { name: 'A' op: 'Input'}" 1212 "node { name: 'B' op: 'Input'}" 1213 "node { name: 'C' op: 'Input'}" 1214 "node { name: 'D' op: 'Input'}" 1215 "node { name: 'E' op: 'Input'}" 1216 "node { name: 'F' op: 'FusedBatchNorm'" 1217 " attr { key: 'T' value { type: DT_FLOAT } }" 1218 " attr { key: 'data_format' value { s: 'NCHW' } }" 1219 " attr { key: 'epsilon' value { f: 0.0001 } }" 1220 " attr { key: 'is_training' value { b: true } }" 1221 " input: ['A', 'B', 'C', 'D', 'E'] }" 1222 "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1223 " input: ['A', 'F'] }"); 1224 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1225 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 1226 "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);" 1227 "F(_MklFusedBatchNorm);G(Zeta)|A->F;A->G;" 1228 "A:control->DMT/_0:control;A:control->DMT/_1:control;" 1229 "A:control->DMT/_2:control;A:control->DMT/_3:control;" 1230 "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" 1231 "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" 1232 "E->F:4;F->G:1"); 1233 } 1234 1235 ///////////////////////////////////////////////////////////////////// 1236 // Unit tests related to rewriting node for workspace edges 1237 ///////////////////////////////////////////////////////////////////// 1238 
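// In the tests below, the forward op (LRN, MaxPool) is rewritten to emit an
// additional workspace output (plus its Mkl metadata tensor), and the matching
// gradient op (LRNGrad, MaxPoolGrad) is rewritten to consume it; the extra
// edges such as B:1->E:3 in the expected strings are these workspace edges.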
1239 /* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. */ 1240 TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { 1241 InitGraph( 1242 "node { name: 'A' op: 'Input'}" 1243 "node { name: 'B' op: 'LRN'" 1244 " attr { key: 'T' value { type: DT_FLOAT } }" 1245 " attr { key: 'alpha' value { f: 0.001 } }" 1246 " attr { key: 'beta' value { f: 0.75 } }" 1247 " attr { key: 'bias' value { f: 1.0 } }" 1248 " attr { key: 'data_format' value { s: 'NCHW' } }" 1249 " attr { key: 'depth_radius' value { i: 2 } }" 1250 " input: ['A'] }" 1251 "node { name: 'C' op: 'MaxPool'" 1252 " attr { key: 'T' value { type: DT_FLOAT } }" 1253 " attr { key: 'data_format' value { s: 'NCHW' } }" 1254 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 1255 " attr { key: 'padding' value { s: 'VALID' } }" 1256 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 1257 " input: ['B'] }" 1258 "node { name: 'D' op: 'Input'}" 1259 "node { name: 'E' op: 'MaxPoolGrad'" 1260 " attr { key: 'T' value { type: DT_FLOAT } }" 1261 " attr { key: 'data_format' value { s: 'NCHW' } }" 1262 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 1263 " attr { key: 'padding' value { s: 'VALID' } }" 1264 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 1265 " input: ['B', 'C', 'D'] }" 1266 "node { name: 'F' op: 'Input'}" 1267 "node { name: 'G' op: 'LRNGrad'" 1268 " attr { key: 'T' value { type: DT_FLOAT } }" 1269 " attr { key: 'alpha' value { f: 0.001 } }" 1270 " attr { key: 'beta' value { f: 0.75 } }" 1271 " attr { key: 'bias' value { f: 1.0 } }" 1272 " attr { key: 'data_format' value { s: 'NCHW' } }" 1273 " attr { key: 'depth_radius' value { i: 2 } }" 1274 " input: ['E', 'F', 'B'] }" 1275 "node { name: 'H' op: 'Input'}" 1276 "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1277 " input: ['H', 'G'] }"); 1278 EXPECT_EQ( 1279 DoMklLayoutOptimizationPass(), 1280 "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" 1281 "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" 1282 "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" 1283 "B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;" 1284 "C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;" 1285 "E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I"); 1286 } 1287 1288 /* Test LRN->LRNGrad replacement by workspace nodes. 
*/ 1289 TEST_F(MklLayoutPassTest, LRN_Positive) { 1290 InitGraph( 1291 "node { name: 'A' op: 'Input'}" 1292 "node { name: 'B' op: 'LRN'" 1293 " attr { key: 'T' value { type: DT_FLOAT } }" 1294 " attr { key: 'alpha' value { f: 0.001 } }" 1295 " attr { key: 'beta' value { f: 0.75 } }" 1296 " attr { key: 'bias' value { f: 1.0 } }" 1297 " attr { key: 'data_format' value { s: 'NCHW' } }" 1298 " attr { key: 'depth_radius' value { i: 2 } }" 1299 " input: ['A'] }" 1300 "node { name: 'C' op: 'Input'}" 1301 "node { name: 'D' op: 'Input'}" 1302 "node { name: 'E' op: 'LRNGrad'" 1303 " attr { key: 'T' value { type: DT_FLOAT } }" 1304 " attr { key: 'alpha' value { f: 0.001 } }" 1305 " attr { key: 'beta' value { f: 0.75 } }" 1306 " attr { key: 'bias' value { f: 1.0 } }" 1307 " attr { key: 'data_format' value { s: 'NCHW' } }" 1308 " attr { key: 'depth_radius' value { i: 2 } }" 1309 " input: ['C', 'D', 'B'] }" 1310 "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1311 " input: ['C', 'E'] }"); 1312 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1313 "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 1314 "DMT/_2(Const);E(_MklLRNGrad);F(Zeta)|" 1315 "A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;" 1316 "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;" 1317 "D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1"); 1318 } 1319 1320 /* Test LRN->LRNGrad replacement when only one of them is present. */ 1321 TEST_F(MklLayoutPassTest, LRN_Negative1) { 1322 InitGraph( 1323 "node { name: 'A' op: 'Input'}" 1324 "node { name: 'B' op: 'LRN'" 1325 " attr { key: 'T' value { type: DT_FLOAT } }" 1326 " attr { key: 'alpha' value { f: 0.001 } }" 1327 " attr { key: 'beta' value { f: 0.75 } }" 1328 " attr { key: 'bias' value { f: 1.0 } }" 1329 " attr { key: 'data_format' value { s: 'NCHW' } }" 1330 " attr { key: 'depth_radius' value { i: 2 } }" 1331 " input: ['A'] }" 1332 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1333 " input: ['A', 'B'] }"); 1334 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1335 "A(Input);B(_MklLRN);C(Zeta);DMT/_0(Const)|" 1336 "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); 1337 } 1338 1339 /* Test LRN->LRNGrad replacement when only one of them is present. */ 1340 TEST_F(MklLayoutPassTest, LRN_Negative2) { 1341 InitGraph( 1342 "node { name: 'A' op: 'Input'}" 1343 "node { name: 'B' op: 'Input'}" 1344 "node { name: 'C' op: 'Input'}" 1345 "node { name: 'D' op: 'LRNGrad'" 1346 " attr { key: 'T' value { type: DT_FLOAT } }" 1347 " attr { key: 'alpha' value { f: 0.001 } }" 1348 " attr { key: 'beta' value { f: 0.75 } }" 1349 " attr { key: 'bias' value { f: 1.0 } }" 1350 " attr { key: 'data_format' value { s: 'NCHW' } }" 1351 " attr { key: 'depth_radius' value { i: 2 } }" 1352 " input: ['A', 'B', 'C'] }" 1353 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1354 " input: ['A', 'D'] }"); 1355 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1356 "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);" 1357 "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|" 1358 "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" 1359 "A:control->DMT/_2:control;A:control->DMT/_3:control;" 1360 "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" 1361 "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); 1362 } 1363 1364 /* Test LRN->LRNGrad negative case, where single LRN feeds 1365 2 LRNGrad nodes at different slots. 
*/ 1366 TEST_F(MklLayoutPassTest, LRN_Negative3) { 1367 InitGraph( 1368 "node { name: 'A' op: 'Input'}" 1369 "node { name: 'B' op: 'LRN'" 1370 " attr { key: 'T' value { type: DT_FLOAT } }" 1371 " attr { key: 'alpha' value { f: 0.001 } }" 1372 " attr { key: 'beta' value { f: 0.75 } }" 1373 " attr { key: 'bias' value { f: 1.0 } }" 1374 " attr { key: 'data_format' value { s: 'NCHW' } }" 1375 " attr { key: 'depth_radius' value { i: 2 } }" 1376 " input: ['A'] }" 1377 "node { name: 'C' op: 'Input'}" 1378 "node { name: 'D' op: 'Input'}" 1379 "node { name: 'E' op: 'LRNGrad'" 1380 " attr { key: 'T' value { type: DT_FLOAT } }" 1381 " attr { key: 'alpha' value { f: 0.001 } }" 1382 " attr { key: 'beta' value { f: 0.75 } }" 1383 " attr { key: 'bias' value { f: 1.0 } }" 1384 " attr { key: 'data_format' value { s: 'NCHW' } }" 1385 " attr { key: 'depth_radius' value { i: 2 } }" 1386 " input: ['C', 'D', 'B'] }" 1387 "node { name: 'F' op: 'LRNGrad'" 1388 " attr { key: 'T' value { type: DT_FLOAT } }" 1389 " attr { key: 'alpha' value { f: 0.001 } }" 1390 " attr { key: 'beta' value { f: 0.75 } }" 1391 " attr { key: 'bias' value { f: 1.0 } }" 1392 " attr { key: 'data_format' value { s: 'NCHW' } }" 1393 " attr { key: 'depth_radius' value { i: 2 } }" 1394 " input: ['C', 'B', 'D'] }" 1395 "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1396 " input: ['E', 'F'] }"); 1397 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1398 "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 1399 "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);" 1400 "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Zeta)|A->B;" 1401 "A:control->DMT/_0:control;B->E:2;" 1402 "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;" 1403 "C:control->DMT/_1:control;C:control->DMT/_2:control;" 1404 "C:control->DMT/_3:control;C:control->DMT/_4:control;" 1405 "C:control->DMT/_5:control;C:control->DMT/_6:control;" 1406 "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;" 1407 "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1"); 1408 } 1409 1410 /* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. 
*/ 1411 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) { 1412 InitGraph( 1413 "node { name: 'A' op: 'Input'}" 1414 "node { name: 'B' op: 'MaxPool'" 1415 " attr { key: 'T' value { type: DT_FLOAT } }" 1416 " attr { key: 'data_format' value { s: 'NCHW' } }" 1417 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 1418 " attr { key: 'padding' value { s: 'VALID' } }" 1419 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 1420 " input: ['A'] }" 1421 "node { name: 'C' op: 'Input'}" 1422 "node { name: 'D' op: 'Input'}" 1423 "node { name: 'E' op: 'MaxPoolGrad'" 1424 " attr { key: 'T' value { type: DT_FLOAT } }" 1425 " attr { key: 'data_format' value { s: 'NCHW' } }" 1426 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 1427 " attr { key: 'padding' value { s: 'VALID' } }" 1428 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 1429 " input: ['C', 'B', 'D'] }" 1430 "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1431 " input: ['C', 'E'] }"); 1432 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1433 "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);" 1434 "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Zeta)|" 1435 "A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;" 1436 "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;" 1437 "D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1"); 1438 } 1439 1440 // Test MaxPool->MaxPoolGrad replacement when only one of them is present. 1441 // In this case, we will rewrite the MaxPool node, but workspace edges will not 1442 // be present. 1443 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) { 1444 InitGraph( 1445 "node { name: 'A' op: 'Input'}" 1446 "node { name: 'B' op: 'MaxPool'" 1447 " attr { key: 'T' value { type: DT_FLOAT } }" 1448 " attr { key: 'data_format' value { s: 'NCHW' } }" 1449 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 1450 " attr { key: 'padding' value { s: 'VALID' } }" 1451 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 1452 " input: ['A'] }" 1453 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1454 " input: ['A', 'B'] }"); 1455 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1456 "A(Input);B(_MklMaxPool);C(Zeta);DMT/_0(Const)|" 1457 "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); 1458 } 1459 1460 // Test MaxPoolGrad replacement when only one of them is present. 1461 // In this case, we will rewrite MaxPoolGrad, and we will generate dummy 1462 // tensors for the workspace tensor and its Mkl part.
1463 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) { 1464 InitGraph( 1465 "node { name: 'A' op: 'Input'}" 1466 "node { name: 'B' op: 'Input'}" 1467 "node { name: 'C' op: 'Input'}" 1468 "node { name: 'D' op: 'MaxPoolGrad'" 1469 " attr { key: 'T' value { type: DT_FLOAT } }" 1470 " attr { key: 'data_format' value { s: 'NCHW' } }" 1471 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 1472 " attr { key: 'padding' value { s: 'VALID' } }" 1473 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 1474 " input: ['A', 'B', 'C'] }" 1475 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1476 " input: ['A', 'D'] }"); 1477 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1478 "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);" 1479 "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|" 1480 "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" 1481 "A:control->DMT/_2:control;A:control->DMT/_3:control;" 1482 "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" 1483 "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); 1484 } 1485 1486 // Test MaxPool handling for batch-wise pooling (NCHW) 1487 // No rewrite should take place in such case 1488 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) { 1489 InitGraph( 1490 "node { name: 'A' op: 'Input'}" 1491 "node { name: 'B' op: 'MaxPool'" 1492 " attr { key: 'T' value { type: DT_FLOAT } }" 1493 " attr { key: 'data_format' value { s: 'NCHW' } }" 1494 " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }" 1495 " attr { key: 'padding' value { s: 'VALID' } }" 1496 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1497 " input: ['A'] }" 1498 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1499 " input: ['A', 'B'] }"); 1500 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1501 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 1502 } 1503 1504 // Test MaxPool handling for batch-wise pooling (NCHW) 1505 // No rewrite should take place in such case 1506 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) { 1507 InitGraph( 1508 "node { name: 'A' op: 'Input'}" 1509 "node { name: 'B' op: 'MaxPool'" 1510 " attr { key: 'T' value { type: DT_FLOAT } }" 1511 " attr { key: 'data_format' value { s: 'NCHW' } }" 1512 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 1513 " attr { key: 'padding' value { s: 'VALID' } }" 1514 " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" 1515 " input: ['A'] }" 1516 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1517 " input: ['A', 'B'] }"); 1518 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1519 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 1520 } 1521 1522 // Test MaxPool handling for depth-wise pooling (NHWC) 1523 // No rewrite should take place in such case 1524 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) { 1525 InitGraph( 1526 "node { name: 'A' op: 'Input'}" 1527 "node { name: 'B' op: 'MaxPool'" 1528 " attr { key: 'T' value { type: DT_FLOAT } }" 1529 " attr { key: 'data_format' value { s: 'NCHW' } }" 1530 " attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }" 1531 " attr { key: 'padding' value { s: 'VALID' } }" 1532 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1533 " input: ['A'] }" 1534 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1535 " input: ['A', 'B'] }"); 1536 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1537 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 1538 } 1539 1540 
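// The remaining MaxPool negative cases set a ksize or stride larger than 1
// in the batch (N) or depth (C) dimension; such non-spatial pooling is
// expected to be left as plain MaxPool, with no Mkl rewrite.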
// Test MaxPool handling for depth-wise pooling (NCHW) 1541 // No rewrite should take place in such case 1542 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) { 1543 InitGraph( 1544 "node { name: 'A' op: 'Input'}" 1545 "node { name: 'B' op: 'MaxPool'" 1546 " attr { key: 'T' value { type: DT_FLOAT } }" 1547 " attr { key: 'data_format' value { s: 'NCHW' } }" 1548 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 1549 " attr { key: 'padding' value { s: 'VALID' } }" 1550 " attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }" 1551 " input: ['A'] }" 1552 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1553 " input: ['A', 'B'] }"); 1554 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1555 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 1556 } 1557 1558 // Test MaxPool handling for batch-wise pooling (NHWC) 1559 // No rewrite should take place in such case 1560 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) { 1561 InitGraph( 1562 "node { name: 'A' op: 'Input'}" 1563 "node { name: 'B' op: 'MaxPool'" 1564 " attr { key: 'T' value { type: DT_FLOAT } }" 1565 " attr { key: 'data_format' value { s: 'NHWC' } }" 1566 " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }" 1567 " attr { key: 'padding' value { s: 'VALID' } }" 1568 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1569 " input: ['A'] }" 1570 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1571 " input: ['A', 'B'] }"); 1572 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1573 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 1574 } 1575 1576 // Test MaxPool handling for batch-wise pooling (NHWC) 1577 // No rewrite should take place in such case 1578 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) { 1579 InitGraph( 1580 "node { name: 'A' op: 'Input'}" 1581 "node { name: 'B' op: 'MaxPool'" 1582 " attr { key: 'T' value { type: DT_FLOAT } }" 1583 " attr { key: 'data_format' value { s: 'NHWC' } }" 1584 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 1585 " attr { key: 'padding' value { s: 'VALID' } }" 1586 " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" 1587 " input: ['A'] }" 1588 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1589 " input: ['A', 'B'] }"); 1590 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1591 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 1592 } 1593 1594 // Test MaxPool handling for depth-wise pooling (NHWC) 1595 // No rewrite should take place in such case 1596 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) { 1597 InitGraph( 1598 "node { name: 'A' op: 'Input'}" 1599 "node { name: 'B' op: 'MaxPool'" 1600 " attr { key: 'T' value { type: DT_FLOAT } }" 1601 " attr { key: 'data_format' value { s: 'NHWC' } }" 1602 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }" 1603 " attr { key: 'padding' value { s: 'VALID' } }" 1604 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1605 " input: ['A'] }" 1606 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1607 " input: ['A', 'B'] }"); 1608 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1609 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 1610 } 1611 1612 // Test MaxPool handling for depth-wise pooling (NHWC) 1613 // No rewrite should take place in such case 1614 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) { 1615 InitGraph( 1616 "node { name: 'A' op: 'Input'}" 1617 "node { name: 'B' op: 'MaxPool'" 1618 " attr { key: 'T' value { type: DT_FLOAT } }" 1619 " 
attr { key: 'data_format' value { s: 'NHWC' } }" 1620 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 1621 " attr { key: 'padding' value { s: 'VALID' } }" 1622 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }" 1623 " input: ['A'] }" 1624 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1625 " input: ['A', 'B'] }"); 1626 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1627 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 1628 } 1629 1630 ///////////////////////////////////////////////////////////////////// 1631 1632 // Single Conv2D Op on GPU device 1633 // No rewrite should happen 1634 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { 1635 InitGraph( 1636 "node { name: 'A' op: 'Input'}" 1637 "node { name: 'B' op: 'Input'}" 1638 "node { name: 'C' op: 'Conv2D'" 1639 " attr { key: 'T' value { type: DT_FLOAT } }" 1640 " attr { key: 'data_format' value { s: 'NCHW' } }" 1641 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 1642 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1643 " attr { key: 'padding' value { s: 'SAME' } }" 1644 " input: ['A', 'B']}" 1645 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1646 " input: ['B', 'C'] }", 1647 kGPUDevice); 1648 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1649 "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); 1650 } 1651 1652 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { 1653 InitGraph( 1654 "node { name: 'A' op: 'Input'}" 1655 "node { name: 'B' op: 'Input'}" 1656 "node { name: 'C' op: 'Input'}" 1657 "node { name: 'M' op: '_MklInput'}" 1658 "node { name: 'N' op: '_MklInput'}" 1659 "node { name: 'O' op: '_MklInput'}" 1660 "node { name: 'D' op: '_MklConv2DWithBias'" 1661 " attr { key: 'T' value { type: DT_FLOAT } }" 1662 " attr { key: 'data_format' value { s: 'NCHW' } }" 1663 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 1664 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1665 " attr { key: 'padding' value { s: 'SAME' } }" 1666 " input: ['A', 'B', 'C', 'M', 'N', 'O']}" 1667 "node { name: 'E' op: 'Zeta'" 1668 " attr {key: 'T' value { type: DT_FLOAT } }" 1669 " input: ['D', 'A']}" 1670 "node { name: 'F' op: 'BiasAddGrad'" 1671 " attr { key: 'T' value { type: DT_FLOAT } }" 1672 " attr { key: 'data_format' value { s: 'NCHW' } }" 1673 " input: ['E'] }", 1674 kGPUDevice); 1675 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1676 "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" 1677 "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" 1678 "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;" 1679 "M->D:3;N->D:4;O->D:5"); 1680 } 1681 1682 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { 1683 InitGraph( 1684 "node { name: 'A' op: 'Input'}" 1685 "node { name: 'B' op: 'Int32Input'}" 1686 "node { name: 'C' op: 'Input'}" 1687 "node { name: 'D' op: 'Conv2DBackpropFilter'" 1688 " attr { key: 'T' value { type: DT_FLOAT } }" 1689 " attr { key: 'data_format' value { s: 'NCHW' } }" 1690 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 1691 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1692 " attr { key: 'padding' value { s: 'SAME' } }" 1693 " input: ['A', 'B', 'C']}" 1694 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1695 " input: ['A', 'D'] }", 1696 kGPUDevice); 1697 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1698 "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" 1699 "A->D;A->E;B->D:1;C->D:2;D->E:1"); 1700 } 1701 1702 
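// The *_DeviceTest cases in this block assign every node to kGPUDevice; the
// Mkl layout pass only rewrites nodes placed on CPU, so these graphs are
// expected to come out of the pass unchanged.
//
// A minimal additional sketch of the same check, using OriginalGraph() to
// assert that a GPU-assigned graph is left untouched. The test name here is
// illustrative, and the sketch assumes, as the surrounding device tests do,
// that no GPU-assigned node is rewritten.
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest_Sketch) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Relu'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A'] }",
      kGPUDevice);
  // With no CPU-assigned nodes, the canonical graph string should be
  // identical before and after the pass.
  EXPECT_EQ(DoMklLayoutOptimizationPass(), OriginalGraph());
}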
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { 1703 InitGraph( 1704 "node { name: 'A' op: 'Input'}" 1705 "node { name: 'B' op: 'Relu'" 1706 " attr { key: 'T' value { type: DT_FLOAT } }" 1707 " input: ['A'] }" 1708 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1709 " input: ['A', 'B'] }", 1710 kGPUDevice); 1711 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1712 "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); 1713 } 1714 1715 TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { 1716 InitGraph( 1717 "node { name: 'A' op: 'Input'}" 1718 "node { name: 'B' op: 'Input'}" 1719 "node { name: 'C' op: 'ReluGrad'" 1720 " attr { key: 'T' value { type: DT_FLOAT } }" 1721 " input: ['A', 'B'] }" 1722 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1723 " input: ['A', 'C'] }", 1724 kGPUDevice); 1725 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1726 "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); 1727 } 1728 1729 TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { 1730 InitGraph( 1731 "node { name: 'A' op: 'Input'}" 1732 "node { name: 'B' op: 'MaxPool'" 1733 " attr { key: 'T' value { type: DT_FLOAT } }" 1734 " attr { key: 'data_format' value { s: 'NHWC' } }" 1735 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 1736 " attr { key: 'padding' value { s: 'VALID' } }" 1737 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1738 " input: ['A'] }" 1739 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1740 " input: ['A', 'B'] }", 1741 kGPUDevice); 1742 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1743 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 1744 } 1745 1746 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { 1747 InitGraph( 1748 "node { name: 'A' op: 'Input'}" 1749 "node { name: 'B' op: 'AvgPool'" 1750 " attr { key: 'T' value { type: DT_FLOAT } }" 1751 " attr { key: 'data_format' value { s: 'NHWC' } }" 1752 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 1753 " attr { key: 'padding' value { s: 'VALID' } }" 1754 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1755 " input: ['A'] }" 1756 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1757 " input: ['A', 'B'] }", 1758 kGPUDevice); 1759 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1760 "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); 1761 } 1762 1763 // Concat Op test: Concat with no Mkl layer feeding it 1764 TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { 1765 InitGraph( 1766 "node { name: 'A' op: 'Const' " 1767 " attr { key: 'dtype' value { type: DT_INT32 } }" 1768 " attr { key: 'value' value { " 1769 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 1770 " int_val: 0 } } } }" 1771 "node { name: 'B' op: 'InputList'" 1772 " attr { key: 'N' value { i: 2 } }}" 1773 "node { name: 'C' op: 'Input'}" 1774 "node { name: 'D' op: 'Concat'" 1775 " attr { key: 'T' value { type: DT_FLOAT } }" 1776 " attr { key: 'N' value { i: 2 } }" 1777 " input: ['A', 'B:0', 'B:1']}" 1778 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1779 " input: ['C', 'D'] }", 1780 kGPUDevice); 1781 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1782 "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" 1783 "B->D:1;B:1->D:2;C->E;D->E:1"); 1784 } 1785 1786 TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { 1787 InitGraph( 1788 "node { name: 'A' op: 'Const' " 1789 " attr { key: 'dtype' value { type: DT_INT32 } }" 1790 " attr { key: 'value' value { " 1791 
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 1792 " int_val: 0 } } } }" 1793 "node { name: 'B' op: 'InputList'" 1794 " attr { key: 'N' value { i: 2 } }}" 1795 "node { name: 'C' op: 'Input'}" 1796 "node { name: 'D' op: 'ConcatV2'" 1797 " attr { key: 'T' value { type: DT_FLOAT } }" 1798 " attr { key: 'Tidx' value { type: DT_INT32 } }" 1799 " attr { key: 'N' value { i: 2 } }" 1800 " input: ['B:0', 'B:1', 'A']}" 1801 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1802 " input: ['C', 'D'] }", 1803 kGPUDevice); 1804 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1805 "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" 1806 "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); 1807 } 1808 1809 TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { 1810 InitGraph( 1811 "node { name: 'A' op: 'Input'}" 1812 "node { name: 'B' op: 'Input'}" 1813 "node { name: 'C' op: 'Input'}" 1814 "node { name: 'D' op: 'Input'}" 1815 "node { name: 'E' op: 'Input'}" 1816 "node { name: 'F' op: 'FusedBatchNorm'" 1817 " attr { key: 'T' value { type: DT_FLOAT } }" 1818 " attr { key: 'data_format' value { s: 'NCHW' } }" 1819 " attr { key: 'epsilon' value { f: 0.0001 } }" 1820 " attr { key: 'is_training' value { b: true } }" 1821 " input: ['A', 'B', 'C', 'D', 'E'] }" 1822 "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 1823 " input: ['A', 'F'] }", 1824 kGPUDevice); 1825 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1826 "A(Input);B(Input);C(Input);D(Input);E(Input);" 1827 "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" 1828 "E->F:4;F->G:1"); 1829 } 1830 1831 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { 1832 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); 1833 InitGraph( 1834 "node { name: 'A' op: 'Input'}" 1835 "node { name: 'B' op: 'Input'}" 1836 "node { name: 'M' op: '_MklInput'}" 1837 "node { name: 'N' op: '_MklInput'}" 1838 "node { name: 'C' op: '_MklConv2D'" 1839 " attr { key: 'T' value { type: DT_FLOAT } }" 1840 " attr { key: 'data_format' value { s: 'NCHW' } }" 1841 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 1842 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 1843 " attr { key: 'padding' value { s: 'SAME' } }" 1844 " input: ['A', 'B', 'M', 'N']}" 1845 "node { name: 'D' op: 'Input'}" 1846 "node { name: 'E' op: 'BiasAdd'" 1847 " attr { key: 'T' value { type: DT_FLOAT } }" 1848 " attr { key: 'data_format' value { s: 'NCHW' } }" 1849 " input: ['C', 'D'] }" 1850 "node { name: 'Y' op: 'Input'}" 1851 "node { name: 'Z' op: 'Zeta'" 1852 " attr {key: 'T' value { type: DT_FLOAT } }" 1853 " input: ['E', 'Y']}", 1854 kGPUDevice); 1855 EXPECT_EQ(DoMklLayoutOptimizationPass(), 1856 "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" 1857 "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" 1858 "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1"); 1859 } 1860 1861 ///////////////////////////////////////////////////////////////////// 1862 1863 static void BM_MklLayoutRewritePass(int iters, int op_nodes) { 1864 testing::StopTiming(); 1865 string s; 1866 for (int in = 0; in < 10; in++) { 1867 s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in); 1868 } 1869 random::PhiloxRandom philox(301, 17); 1870 random::SimplePhilox rnd(&philox); 1871 for (int op = 0; op < op_nodes; op++) { 1872 s += strings::Printf( 1873 "node { name: 'op%04d' op: 'Zeta' attr { key: 'T' value { " 1874 "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }", 1875 op, rnd.Uniform(10), rnd.Uniform(10)); 1876 } 1877 1878 bool first 
= true; 1879 while (iters > 0) { 1880 Graph* graph = new Graph(OpRegistry::Global()); 1881 InitGraph(s, graph); 1882 int N = graph->num_node_ids(); 1883 if (first) { 1884 testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N)); 1885 first = false; 1886 } 1887 { 1888 testing::StartTiming(); 1889 std::unique_ptr<Graph> ug(graph); 1890 RunMklLayoutRewritePass(&ug); 1891 testing::StopTiming(); 1892 } 1893 iters -= N; // Our benchmark units are individual graph nodes, 1894 // not whole graphs 1895 // delete graph; 1896 } 1897 } 1898 BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); 1899 1900 } // namespace 1901 1902 #else // INTEL_MKL_ML 1903 1904 namespace { 1905 1906 const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0"; 1907 const char kGPUDevice[] = "/job:a/replica:0/task:0/device:GPU:0"; 1908 1909 static void InitGraph(const string& s, Graph* graph, 1910 const string& device = kCPUDevice) { 1911 GraphDef graph_def; 1912 1913 auto parser = protobuf::TextFormat::Parser(); 1914 // parser.AllowRelaxedWhitespace(true); 1915 CHECK(parser.MergeFromString(s, &graph_def)) << s; 1916 GraphConstructorOptions opts; 1917 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); 1918 1919 for (Node* node : graph->nodes()) { 1920 node->set_assigned_device_name(device); 1921 } 1922 } 1923 1924 class MklLayoutPassTest : public ::testing::Test { 1925 public: 1926 MklLayoutPassTest() : graph_(OpRegistry::Global()) {} 1927 1928 void InitGraph(const string& s, const string& device = kCPUDevice) { 1929 ::tensorflow::InitGraph(s, &graph_, device); 1930 original_ = CanonicalGraphString(&graph_); 1931 } 1932 1933 static bool IncludeNode(const Node* n) { return n->IsOp(); } 1934 1935 static string EdgeId(const Node* n, int index) { 1936 if (index == 0) { 1937 return n->name(); 1938 } else if (index == Graph::kControlSlot) { 1939 return strings::StrCat(n->name(), ":control"); 1940 } else { 1941 return strings::StrCat(n->name(), ":", index); 1942 } 1943 } 1944 1945 string CanonicalGraphString(Graph* g) { 1946 std::vector<string> nodes; 1947 std::vector<string> edges; 1948 for (const Node* n : g->nodes()) { 1949 if (IncludeNode(n)) { 1950 nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")")); 1951 } 1952 } 1953 for (const Edge* e : g->edges()) { 1954 if (IncludeNode(e->src()) && IncludeNode(e->dst())) { 1955 edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->", 1956 EdgeId(e->dst(), e->dst_input()))); 1957 } 1958 } 1959 // Canonicalize 1960 std::sort(nodes.begin(), nodes.end()); 1961 std::sort(edges.begin(), edges.end()); 1962 return strings::StrCat(str_util::Join(nodes, ";"), "|", 1963 str_util::Join(edges, ";")); 1964 } 1965 1966 string DoMklLayoutOptimizationPass() { 1967 string before = CanonicalGraphString(&graph_); 1968 LOG(ERROR) << "Before MKL layout rewrite pass: " << before; 1969 1970 std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_); 1971 RunMklLayoutRewritePass(ug); 1972 1973 string result = CanonicalGraphString(&graph_); 1974 LOG(ERROR) << "After MKL layout rewrite pass: " << result; 1975 return result; 1976 } 1977 1978 const string& OriginalGraph() const { return original_; } 1979 1980 Graph graph_; 1981 string original_; 1982 }; 1983 1984 REGISTER_OP("Input").Output("o: float").SetIsStateful(); 1985 REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful(); 1986 REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); 1987 REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); 1988 
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); 1989 REGISTER_OP("_MklInput2") 1990 .Output("o: uint8") 1991 .Output("o1: uint8") 1992 .SetIsStateful(); 1993 1994 ///////////////////////////////////////////////////////////////////// 1995 // Unit tests related to node merge optiimization 1996 ///////////////////////////////////////////////////////////////////// 1997 1998 TEST_F(MklLayoutPassTest, Basic) { 1999 InitGraph( 2000 "node { name: 'A' op: 'Input'}" 2001 "node { name: 'B' op: 'Input'}" 2002 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2003 " input: ['A', 'B'] }" 2004 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2005 " input: ['A', 'B'] }"); 2006 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2007 "A(Input);B(Input);C(Zeta);D(Zeta)|" 2008 "A->C;A->D;B->C:1;B->D:1"); 2009 } 2010 2011 // Test set 1: Conv2D + AddBias 2012 2013 // C=Conv2D(A,B); E=BiasAdd(C,D); Z=Zeta(E,Y) 2014 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) { 2015 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); 2016 InitGraph( 2017 "node { name: 'A' op: 'Input'}" 2018 "node { name: 'B' op: 'Input'}" 2019 "node { name: 'C' op: 'Conv2D'" 2020 " attr { key: 'T' value { type: DT_FLOAT } }" 2021 " attr { key: 'data_format' value { s: 'NCHW' } }" 2022 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2023 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2024 " attr { key: 'padding' value { s: 'SAME' } }" 2025 " input: ['A', 'B']}" 2026 "node { name: 'D' op: 'Input'}" 2027 "node { name: 'E' op: 'BiasAdd'" 2028 " attr { key: 'T' value { type: DT_FLOAT } }" 2029 " attr { key: 'data_format' value { s: 'NCHW' } }" 2030 " input: ['C', 'D'] }" 2031 "node { name: 'Y' op: 'Input'}" 2032 "node { name: 'Z' op: 'Zeta'" 2033 " attr {key: 'T' value { type: DT_FLOAT } }" 2034 " input: ['E', 'Y']}"); 2035 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2036 "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 2037 "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Zeta)|A->E;" 2038 "A:control->DMT/_0:control;A:control->DMT/_1:control;" 2039 "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;" 2040 "DMT/_2->E:5;E->Z;Y->Z:1"); 2041 } 2042 2043 // Graph contains only Conv2D, no AddBias. 2044 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) { 2045 InitGraph( 2046 "node { name: 'A' op: 'Input'}" 2047 "node { name: 'B' op: 'Input'}" 2048 "node { name: 'C' op: 'Conv2D'" 2049 " attr { key: 'T' value { type: DT_FLOAT } }" 2050 " attr { key: 'data_format' value { s: 'NCHW' } }" 2051 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2052 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2053 " attr { key: 'padding' value { s: 'SAME' } }" 2054 " input: ['A', 'B']}"); 2055 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2056 "A(Input);B(Input);C(_MklConv2D);DMT/_0(Const);DMT/_1(Const)|" 2057 "A->C;A:control->DMT/_0:control;A:control->DMT/_1:control;B->C:1;" 2058 "DMT/_0->C:2;DMT/_1->C:3"); 2059 } 2060 2061 // Conv2D output does not go to BiasAdd. 
2062 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) { 2063 InitGraph( 2064 "node { name: 'A' op: 'Input'}" 2065 "node { name: 'B' op: 'Input'}" 2066 "node { name: 'C' op: 'Conv2D'" 2067 " attr { key: 'T' value { type: DT_FLOAT } }" 2068 " attr { key: 'data_format' value { s: 'NCHW' } }" 2069 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2070 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2071 " attr { key: 'padding' value { s: 'SAME' } }" 2072 " input: ['A', 'B']}" 2073 "node { name: 'D' op: 'Input'}" 2074 "node { name: 'E' op: 'Input'}" 2075 "node { name: 'F' op: 'BiasAdd'" 2076 " attr { key: 'T' value { type: DT_FLOAT } }" 2077 " attr { key: 'data_format' value { s: 'NCHW' } }" 2078 " input: ['D', 'E'] }"); // Output of _MklConv2D does not go to BiasAdd. 2079 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2080 "A(Input);B(Input);C(_MklConv2D);D(Input);DMT/_0(Const);" 2081 "DMT/_1(Const);E(Input);F(BiasAdd)|A->C;A:control->DMT/_0:control;" 2082 "A:control->DMT/_1:control;B->C:1;D->F;DMT/_0->C:2;DMT/_1->C:3;" 2083 "E->F:1"); 2084 } 2085 2086 // Conv2D has two outgoing edges: BiasAdd and some other dummy node (Zeta). 2087 // Merge should not be done in such case. 2088 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) { 2089 InitGraph( 2090 "node { name: 'A' op: 'Input'}" 2091 "node { name: 'B' op: 'Input'}" 2092 "node { name: 'C' op: 'Conv2D'" 2093 " attr { key: 'T' value { type: DT_FLOAT } }" 2094 " attr { key: 'data_format' value { s: 'NCHW' } }" 2095 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2096 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2097 " attr { key: 'padding' value { s: 'SAME' } }" 2098 " input: ['A', 'B']}" 2099 "node { name: 'D' op: 'Input'}" 2100 "node { name: 'E' op: 'Input'}" 2101 "node { name: 'F' op: 'BiasAdd'" 2102 " attr { key: 'T' value { type: DT_FLOAT } }" 2103 " attr { key: 'data_format' value { s: 'NCHW' } }" 2104 " input: ['D', 'E'] }" // Conv2D has two outputs. 2105 // No merge should happen. 2106 "node { name: 'G' op: 'Zeta'" 2107 " attr { key: 'T' value { type: DT_FLOAT } }" 2108 " input: ['C', 'E'] }"); 2109 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2110 "A(Input);B(Input);C(_MklConv2D);D(Input);DMT/_0(Const);" 2111 "DMT/_1(Const);E(Input);F(BiasAdd);G(Zeta)|A->C;" 2112 "A:control->DMT/_0:control;A:control->DMT/_1:control;B->C:1;C->G;" 2113 "D->F;DMT/_0->C:2;DMT/_1->C:3;E->F:1;E->G:1"); 2114 } 2115 2116 // data_format attribute value mismatch. Merge should not be done 2117 // in such case. 
2118 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) { 2119 InitGraph( 2120 "node { name: 'A' op: 'Input'}" 2121 "node { name: 'B' op: 'Input'}" 2122 "node { name: 'C' op: 'Conv2D'" 2123 " attr { key: 'T' value { type: DT_FLOAT } }" 2124 " attr { key: 'data_format' value { s: 'NCHW' } }" 2125 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2126 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2127 " attr { key: 'padding' value { s: 'SAME' } }" 2128 " input: ['A', 'B']}" 2129 "node { name: 'D' op: 'Input'}" 2130 "node { name: 'E' op: 'BiasAdd'" 2131 " attr { key: 'T' value { type: DT_FLOAT } }" 2132 " attr { key: 'data_format' value { s: 'NHCW' } }" 2133 " input: ['C', 'D'] }"); 2134 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2135 "A(Input);B(Input);C(_MklConv2D);D(Input);DMT/_0(Const);" 2136 "DMT/_1(Const);E(BiasAdd)|A->C;A:control->DMT/_0:control;" 2137 "A:control->DMT/_1:control;B->C:1;C->E;D->E:1;DMT/_0->C:2;" 2138 "DMT/_1->C:3"); 2139 } 2140 2141 // Test set 2: BiasAddGrad + Conv2DBackpropFilter fusion tests 2142 2143 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Positive) { 2144 InitGraph( 2145 "node { name: 'A' op: 'Input'}" 2146 "node { name: 'B' op: 'Int32Input'}" 2147 "node { name: 'C' op: 'Input'}" 2148 "node { name: 'D' op: 'Conv2DBackpropFilter'" 2149 " attr { key: 'T' value { type: DT_FLOAT } }" 2150 " attr { key: 'data_format' value { s: 'NCHW' } }" 2151 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2152 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2153 " attr { key: 'padding' value { s: 'SAME' } }" 2154 " input: ['A', 'B', 'C'] }" 2155 "node { name: 'E' op: 'BiasAddGrad'" 2156 " attr { key: 'T' value { type: DT_FLOAT } }" 2157 " attr { key: 'data_format' value { s: 'NCHW' } }" 2158 " input: ['C'] }"); 2159 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2160 "A(Input);B(Int32Input);C(Input);" 2161 "D(_MklConv2DBackpropFilterWithBias);DMT/_0(Const);DMT/_1(Const);" 2162 "DMT/_2(Const)|A->D;A:control->DMT/_0:control;" 2163 "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;" 2164 "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); 2165 } 2166 2167 // BiasAddGrad fusion in the presence of BackpropFilter. But nodes do not match 2168 // criteria for rewrite. So rewrite should not happen. 3rd input of 2169 // Conv2DBackpropFilter is different than input to BiasAddGrad. 
2170 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative1) { 2171 InitGraph( 2172 "node { name: 'A' op: 'Input'}" 2173 "node { name: 'B' op: 'Int32Input'}" 2174 "node { name: 'C' op: 'Input'}" 2175 "node { name: 'D' op: 'Conv2DBackpropFilter'" 2176 " attr { key: 'T' value { type: DT_FLOAT } }" 2177 " attr { key: 'data_format' value { s: 'NCHW' } }" 2178 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2179 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2180 " attr { key: 'padding' value { s: 'SAME' } }" 2181 " input: ['A', 'B', 'C'] }" 2182 "node { name: 'E' op: 'BiasAddGrad'" 2183 " attr { key: 'T' value { type: DT_FLOAT } }" 2184 " attr { key: 'data_format' value { s: 'NCHW' } }" 2185 " input: ['A'] }"); 2186 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2187 "A(Input);B(Int32Input);C(Input);" 2188 "D(_MklConv2DBackpropFilter);DMT/_0(Const);DMT/_1(Const);" 2189 "DMT/_2(Const);E(BiasAddGrad)|A->D;A->E;A:control->DMT/_0:control;" 2190 "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;" 2191 "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); 2192 } 2193 2194 // BiasAddGrad fusion, but nodes do not match criteria for fusion. 2195 // Different input formats. 2196 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative2) { 2197 InitGraph( 2198 "node { name: 'A' op: 'Input'}" 2199 "node { name: 'B' op: 'Int32Input'}" 2200 "node { name: 'C' op: 'Input'}" 2201 "node { name: 'D' op: 'Conv2DBackpropFilter'" 2202 " attr { key: 'T' value { type: DT_FLOAT } }" 2203 " attr { key: 'data_format' value { s: 'NCHW' } }" 2204 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2205 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2206 " attr { key: 'padding' value { s: 'SAME' } }" 2207 " input: ['A', 'B', 'C'] }" 2208 "node { name: 'E' op: 'BiasAddGrad'" 2209 " attr { key: 'T' value { type: DT_FLOAT } }" 2210 " attr { key: 'data_format' value { s: 'NHWC' } }" 2211 " input: ['A'] }"); 2212 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2213 "A(Input);B(Int32Input);C(Input);" 2214 "D(_MklConv2DBackpropFilter);DMT/_0(Const);DMT/_1(Const);" 2215 "DMT/_2(Const);E(BiasAddGrad)|A->D;A->E;A:control->DMT/_0:control;" 2216 "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;" 2217 "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); 2218 } 2219 2220 // BiasAddGrad fusion in the presence of BackpropFilter only. Fusion is done 2221 // before node rewrite. Check this ordering. 
2222 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative3) { 2223 InitGraph( 2224 "node { name: 'A' op: 'Input'}" 2225 "node { name: 'B' op: 'Input'}" 2226 "node { name: 'C' op: 'Input'}" 2227 "node { name: 'M' op: '_MklInput'}" 2228 "node { name: 'N' op: '_MklInput'}" 2229 "node { name: 'O' op: '_MklInput'}" 2230 "node { name: 'D' op: '_MklConv2DWithBias'" 2231 " attr { key: 'T' value { type: DT_FLOAT } }" 2232 " attr { key: 'data_format' value { s: 'NCHW' } }" 2233 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2234 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2235 " attr { key: 'padding' value { s: 'SAME' } }" 2236 " input: ['A', 'B', 'C', 'M', 'N', 'O']}" 2237 "node { name: 'E' op: 'Zeta'" 2238 " attr {key: 'T' value { type: DT_FLOAT } }" 2239 " input: ['D', 'A']}" 2240 "node { name: 'F' op: 'Int32Input'}" 2241 "node { name: 'G' op: '_MklConv2DBackpropFilter'" 2242 " attr { key: 'T' value { type: DT_FLOAT } }" 2243 " attr { key: 'data_format' value { s: 'NCHW' } }" 2244 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2245 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2246 " attr { key: 'padding' value { s: 'SAME' } }" 2247 " input: ['E', 'F', 'A', 'M', 'N', 'O'] }" 2248 "node { name: 'H' op: 'BiasAddGrad'" 2249 " attr { key: 'T' value { type: DT_FLOAT } }" 2250 " attr { key: 'data_format' value { s: 'NCHW' } }" 2251 " input: ['E'] }"); 2252 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2253 "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" 2254 "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" 2255 "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;" 2256 "C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;" 2257 "O->G:5"); 2258 } 2259 2260 // C=Conv2D(A,B); E=BiasAdd(C,D); Y=Zeta(E,X); 2261 // G=Conv2DBackpropInput(F,B,E) 2262 // This is a case of node rewrite followed by node merge followed by connecting 2263 // filter output of Conv2DWithBias to filter input of Conv2DBackpropInput. 
2264 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_ConvBpropInput_FilterFwd) { 2265 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); 2266 InitGraph( 2267 "node { name: 'A' op: 'Input'}" 2268 "node { name: 'B' op: 'Input'}" 2269 "node { name: 'C' op: 'Conv2D'" 2270 " attr { key: 'T' value { type: DT_FLOAT } }" 2271 " attr { key: 'data_format' value { s: 'NCHW' } }" 2272 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2273 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2274 " attr { key: 'padding' value { s: 'SAME' } }" 2275 " input: ['A', 'B']}" 2276 "node { name: 'D' op: 'Input'}" 2277 "node { name: 'E' op: 'BiasAdd'" 2278 " attr { key: 'T' value { type: DT_FLOAT } }" 2279 " attr { key: 'data_format' value { s: 'NCHW' } }" 2280 " input: ['C', 'D'] }" 2281 "node { name: 'X' op: 'Input'}" 2282 "node { name: 'Y' op: 'Zeta'" 2283 " attr {key: 'T' value { type: DT_FLOAT } }" 2284 " input: ['E', 'X']}" 2285 "node { name: 'F' op: 'Int32Input'}" 2286 "node { name: 'G' op: 'Conv2DBackpropInput'" 2287 " attr { key: 'T' value { type: DT_FLOAT } }" 2288 " attr { key: 'data_format' value { s: 'NCHW' } }" 2289 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2290 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2291 " attr { key: 'padding' value { s: 'SAME' } }" 2292 " input: ['F', 'B', 'E']}" 2293 "node { name: 'Z' op: 'Zeta'" 2294 " attr {key: 'T' value { type: DT_FLOAT } }" 2295 " input: ['G', 'X']}"); 2296 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2297 "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 2298 "DMT/_2(Const);DMT/_3(Const);E(_MklConv2DWithBias);F(Int32Input);" 2299 "G(_MklConv2DBackpropInput);X(Input);Y(Zeta);Z(Zeta)|" 2300 "A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" 2301 "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;" 2302 "DMT/_1->E:4;DMT/_2->E:5;DMT/_3->G:3;E->G:2;E->Y;E:1->G:1;E:2->G:5;" 2303 "E:3->G:4;F->G;F:control->DMT/_3:control;G->Z;X->Y:1;X->Z:1"); 2304 } 2305 2306 ///////////////////////////////////////////////////////////////////// 2307 // Unit tests related to rewriting node to Mkl node 2308 ///////////////////////////////////////////////////////////////////// 2309 2310 // Single Conv2D Op; No Mkl layer on the input and on the output. 2311 // We will generate dummy Mkl tensor as 2nd input of Conv2D. 2312 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) { 2313 InitGraph( 2314 "node { name: 'A' op: 'Input'}" 2315 "node { name: 'B' op: 'Input'}" 2316 "node { name: 'C' op: 'Conv2D'" 2317 " attr { key: 'T' value { type: DT_FLOAT } }" 2318 " attr { key: 'data_format' value { s: 'NCHW' } }" 2319 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2320 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2321 " attr { key: 'padding' value { s: 'SAME' } }" 2322 " input: ['A', 'B']}" 2323 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2324 " input: ['B', 'C'] }"); 2325 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2326 "A(Input);B(Input);C(_MklConv2D);D(Zeta);DMT/_0(Const);" 2327 "DMT/_1(Const)|A->C;A:control->DMT/_0:control;" 2328 "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" 2329 "DMT/_1->C:3"); 2330 } 2331 2332 // 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will 2333 // have 2 outputs, both of which will be inputs to next Conv2D. 
2334 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) { 2335 InitGraph( 2336 "node { name: 'A' op: 'Input'}" 2337 "node { name: 'B' op: 'Input'}" 2338 "node { name: 'C' op: 'Conv2D'" 2339 " attr { key: 'T' value { type: DT_FLOAT } }" 2340 " attr { key: 'data_format' value { s: 'NCHW' } }" 2341 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2342 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2343 " attr { key: 'padding' value { s: 'SAME' } }" 2344 " input: ['A', 'B']}" 2345 "node { name: 'D' op: 'Conv2D'" 2346 " attr { key: 'T' value { type: DT_FLOAT } }" 2347 " attr { key: 'data_format' value { s: 'NCHW' } }" 2348 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2349 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2350 " attr { key: 'padding' value { s: 'SAME' } }" 2351 " input: ['A', 'C']}" 2352 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2353 " input: ['C', 'D'] }"); 2354 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2355 "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);" 2356 "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->C;A->D;" 2357 "A:control->DMT/_0:control;A:control->DMT/_1:control;" 2358 "A:control->DMT/_2:control;B->C:1;C->D:1;C->E;" 2359 "C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2"); 2360 } 2361 2362 // Conv2D with INT32 which is not supported by Mkl 2363 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) { 2364 InitGraph( 2365 "node { name: 'A' op: 'HalfInput'}" 2366 "node { name: 'B' op: 'HalfInput'}" 2367 "node { name: 'C' op: 'Conv2D'" 2368 " attr { key: 'T' value { type: DT_HALF } }" 2369 " attr { key: 'data_format' value { s: 'NCHW' } }" 2370 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2371 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2372 " attr { key: 'padding' value { s: 'SAME' } }" 2373 " input: ['A', 'B']}" 2374 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }" 2375 " input: ['B', 'C'] }"); 2376 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2377 "A(HalfInput);B(HalfInput);C(Conv2D);D(Zeta)|" 2378 "A->C;B->C:1;B->D;C->D:1"); 2379 } 2380 2381 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) { 2382 InitGraph( 2383 "node { name: 'A' op: 'Input'}" 2384 "node { name: 'B' op: 'Int32Input'}" 2385 "node { name: 'C' op: 'Input'}" 2386 "node { name: 'D' op: 'Conv2DBackpropFilter'" 2387 " attr { key: 'T' value { type: DT_FLOAT } }" 2388 " attr { key: 'data_format' value { s: 'NCHW' } }" 2389 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2390 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2391 " attr { key: 'padding' value { s: 'SAME' } }" 2392 " input: ['A', 'B', 'C']}" 2393 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2394 " input: ['A', 'D'] }"); 2395 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2396 "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);" 2397 "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" 2398 "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" 2399 "A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" 2400 "DMT/_1->D:4;DMT/_2->D:5"); 2401 } 2402 2403 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) { 2404 InitGraph( 2405 "node { name: 'A' op: 'Input'}" 2406 "node { name: 'B' op: 'Int32Input'}" 2407 "node { name: 'C' op: 'Input'}" 2408 "node { name: 'D' op: 'Conv2DBackpropInput'" 2409 " attr { key: 'T' value { type: DT_FLOAT } }" 2410 " attr { key: 'data_format' value { s: 'NCHW' } }" 
2411 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2412 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2413 " attr { key: 'padding' value { s: 'SAME' } }" 2414 " input: ['B', 'A', 'C']}" 2415 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2416 " input: ['A', 'D'] }"); 2417 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2418 "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);" 2419 "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" 2420 "A->D:1;A->E;B->D;B:control->DMT/_0:control;" 2421 "B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;" 2422 "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); 2423 } 2424 2425 // Check that we never rewrite BiasAddGrad. 2426 TEST_F(MklLayoutPassTest, NodeRewrite_BiasAddGrad_Positive) { 2427 InitGraph( 2428 "node { name: 'A' op: 'Input'}" 2429 "node { name: 'B' op: 'Input'}" 2430 "node { name: 'C' op: 'Polygamma'" 2431 " attr { key: 'T' value { type: DT_FLOAT } }" 2432 " input: ['A', 'B']}" 2433 "node { name: 'D' op: 'Zeta'" 2434 " attr {key: 'T' value { type: DT_FLOAT } }" 2435 " input: ['C', 'A']}" 2436 "node { name: 'E' op: 'BiasAddGrad'" 2437 " attr { key: 'T' value { type: DT_FLOAT } }" 2438 " attr { key: 'data_format' value { s: 'NCHW' } }" 2439 " input: ['D'] }"); 2440 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2441 "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|" 2442 "A->C;A->D:1;B->C:1;C->D;D->E"); 2443 } 2444 2445 // Check that we never rewrite BiasAddGrad. 2446 TEST_F(MklLayoutPassTest, NodeRewrite_BiasAddGrad_Positive1) { 2447 InitGraph( 2448 "node { name: 'A' op: 'Input'}" 2449 "node { name: 'B' op: 'Input'}" 2450 "node { name: 'C' op: 'MatMul'" 2451 " attr { key: 'T' value { type: DT_FLOAT } }" 2452 " attr { key: 'transpose_a' value { b: false } }" 2453 " attr { key: 'transpose_b' value { b: false } }" 2454 " input: ['A', 'B']}" 2455 "node { name: 'D' op: 'Zeta'" 2456 " attr {key: 'T' value { type: DT_FLOAT } }" 2457 " input: ['C', 'A']}" 2458 "node { name: 'E' op: 'BiasAddGrad'" 2459 " attr { key: 'T' value { type: DT_FLOAT } }" 2460 " attr { key: 'data_format' value { s: 'NCHW' } }" 2461 " input: ['D'] }"); 2462 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2463 "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|" 2464 "A->C;A->D:1;B->C:1;C->D;D->E"); 2465 } 2466 2467 // Check that we never rewrite BiasAddGrad. 
2468 TEST_F(MklLayoutPassTest, NodeRewrite_BiasAddGrad_Positive2) { 2469 InitGraph( 2470 "node { name: 'A' op: 'Input'}" 2471 "node { name: 'B' op: 'Input'}" 2472 "node { name: 'M' op: '_MklInput'}" 2473 "node { name: 'N' op: '_MklInput'}" 2474 "node { name: 'C' op: '_MklConv2D'" 2475 " attr { key: 'T' value { type: DT_FLOAT } }" 2476 " attr { key: 'data_format' value { s: 'NCHW' } }" 2477 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2478 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2479 " attr { key: 'padding' value { s: 'SAME' } }" 2480 " input: ['A', 'B', 'M', 'N']}" 2481 "node { name: 'D' op: 'Zeta'" 2482 " attr {key: 'T' value { type: DT_FLOAT } }" 2483 " input: ['C', 'A']}" 2484 "node { name: 'E' op: 'BiasAddGrad'" 2485 " attr { key: 'T' value { type: DT_FLOAT } }" 2486 " attr { key: 'data_format' value { s: 'NCHW' } }" 2487 " input: ['D'] }"); 2488 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2489 "A(Input);B(Input);C(_MklConv2D);D(Zeta);E(BiasAddGrad);" 2490 "M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;" 2491 "M->C:2;N->C:3"); 2492 } 2493 2494 // Concat Op test: Concat with no Mkl layer feeding it 2495 TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { 2496 InitGraph( 2497 "node { name: 'A' op: 'Const' " 2498 " attr { key: 'dtype' value { type: DT_INT32 } }" 2499 " attr { key: 'value' value { " 2500 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 2501 " int_val: 0 } } } }" 2502 "node { name: 'B' op: 'InputList'" 2503 " attr { key: 'N' value { i: 2 } }}" 2504 "node { name: 'C' op: 'Input'}" 2505 "node { name: 'D' op: 'Concat'" 2506 " attr { key: 'T' value { type: DT_FLOAT } }" 2507 " attr { key: 'N' value { i: 2 } }" 2508 " input: ['A', 'B:0', 'B:1']}" 2509 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2510 " input: ['C', 'D'] }"); 2511 EXPECT_EQ( 2512 DoMklLayoutOptimizationPass(), 2513 "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" 2514 "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" 2515 "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" 2516 "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); 2517 } 2518 2519 // Concat with 2 Mkl layers feeding it 2520 TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) { 2521 InitGraph( 2522 "node { name: 'A' op: 'Input'}" 2523 "node { name: 'B' op: 'Input'}" 2524 "node { name: 'C' op: 'Input'}" 2525 "node { name: 'D' op: 'Input'}" 2526 "node { name: 'E' op: 'Conv2D'" 2527 " attr { key: 'T' value { type: DT_FLOAT } }" 2528 " attr { key: 'data_format' value { s: 'NCHW' } }" 2529 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2530 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2531 " attr { key: 'padding' value { s: 'SAME' } }" 2532 " input: ['A', 'B']}" 2533 "node { name: 'F' op: 'Conv2D'" 2534 " attr { key: 'T' value { type: DT_FLOAT } }" 2535 " attr { key: 'data_format' value { s: 'NCHW' } }" 2536 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2537 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2538 " attr { key: 'padding' value { s: 'SAME' } }" 2539 " input: ['C', 'D']}" 2540 "node { name: 'G' op: 'Const' " 2541 " attr { key: 'dtype' value { type: DT_INT32 } }" 2542 " attr { key: 'value' value { " 2543 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 2544 " int_val: 0 } } } }" 2545 "node { name: 'H' op: 'Concat'" 2546 " attr { key: 'T' value { type: DT_FLOAT } }" 2547 " attr { key: 'N' value { i: 2 } }" 2548 " input: ['G', 'E', 'F']}" 2549 
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2550 " input: ['A', 'H'] }"); 2551 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2552 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 2553 "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" 2554 "F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;" 2555 "A:control->DMT/_2:control;A:control->DMT/_3:control;" 2556 "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;" 2557 "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" 2558 "DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;" 2559 "G:control->DMT/_4:control;H->I:1"); 2560 } 2561 2562 // Concat with 1 Mkl and 1 non-Mkl layer feeding it 2563 TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) { 2564 InitGraph( 2565 "node { name: 'A' op: 'Input'}" 2566 "node { name: 'B' op: 'Input'}" 2567 "node { name: 'C' op: 'Input'}" 2568 "node { name: 'D' op: 'Input'}" 2569 "node { name: 'E' op: 'Conv2D'" 2570 " attr { key: 'T' value { type: DT_FLOAT } }" 2571 " attr { key: 'data_format' value { s: 'NCHW' } }" 2572 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2573 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2574 " attr { key: 'padding' value { s: 'SAME' } }" 2575 " input: ['A', 'B']}" 2576 "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2577 " input: ['C', 'D']}" 2578 "node { name: 'G' op: 'Const' " 2579 " attr { key: 'dtype' value { type: DT_INT32 } }" 2580 " attr { key: 'value' value { " 2581 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 2582 " int_val: 0 } } } }" 2583 "node { name: 'H' op: 'Concat'" 2584 " attr { key: 'T' value { type: DT_FLOAT } }" 2585 " attr { key: 'N' value { i: 2 } }" 2586 " input: ['G', 'E', 'F']}" 2587 "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2588 " input: ['A', 'H'] }"); 2589 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2590 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 2591 "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);" 2592 "H(_MklConcat);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;" 2593 "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" 2594 "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;" 2595 "G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1"); 2596 } 2597 2598 // ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it 2599 TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) { 2600 InitGraph( 2601 "node { name: 'A' op: 'Const' " 2602 " attr { key: 'dtype' value { type: DT_INT32 } }" 2603 " attr { key: 'value' value { " 2604 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 2605 " int_val: 0 } } } }" 2606 "node { name: 'B' op: 'InputList'" 2607 " attr { key: 'N' value { i: 2 } }}" 2608 "node { name: 'C' op: 'Input'}" 2609 "node { name: 'D' op: 'ConcatV2'" 2610 " attr { key: 'T' value { type: DT_FLOAT } }" 2611 " attr { key: 'Tidx' value { type: DT_INT32 } }" 2612 " attr { key: 'N' value { i: 2 } }" 2613 " input: ['B:0', 'B:1', 'A']}" 2614 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2615 " input: ['C', 'D'] }"); 2616 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2617 "A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);" 2618 "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D:2;B->D;B:1->D:1;" 2619 "B:control->DMT/_0:control;B:control->DMT/_1:control;" 2620 "B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;" 2621 "DMT/_1->D:4;DMT/_2->D:5"); 2622 } 2623 2624 // ConcatV2 with 2 Mkl layers 
feeding it 2625 TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) { 2626 InitGraph( 2627 "node { name: 'A' op: 'Input'}" 2628 "node { name: 'B' op: 'Input'}" 2629 "node { name: 'C' op: 'Input'}" 2630 "node { name: 'D' op: 'Input'}" 2631 "node { name: 'E' op: 'Conv2D'" 2632 " attr { key: 'T' value { type: DT_FLOAT } }" 2633 " attr { key: 'data_format' value { s: 'NCHW' } }" 2634 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2635 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2636 " attr { key: 'padding' value { s: 'SAME' } }" 2637 " input: ['A', 'B']}" 2638 "node { name: 'F' op: 'Conv2D'" 2639 " attr { key: 'T' value { type: DT_FLOAT } }" 2640 " attr { key: 'data_format' value { s: 'NCHW' } }" 2641 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2642 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2643 " attr { key: 'padding' value { s: 'SAME' } }" 2644 " input: ['C', 'D']}" 2645 "node { name: 'G' op: 'Const' " 2646 " attr { key: 'dtype' value { type: DT_INT32 } }" 2647 " attr { key: 'value' value { " 2648 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 2649 " int_val: 0 } } } }" 2650 "node { name: 'H' op: 'ConcatV2'" 2651 " attr { key: 'T' value { type: DT_FLOAT } }" 2652 " attr { key: 'Tidx' value { type: DT_INT32 } }" 2653 " attr { key: 'N' value { i: 2 } }" 2654 " input: ['E', 'F', 'G']}" 2655 "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2656 " input: ['A', 'H'] }"); 2657 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2658 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 2659 "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" 2660 "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;" 2661 "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;" 2662 "C:control->DMT/_0:control;C:control->DMT/_1:control;" 2663 "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" 2664 "DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;" 2665 "F:2->H:4;G->H:2;H->I:1"); 2666 } 2667 2668 // ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it 2669 TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) { 2670 InitGraph( 2671 "node { name: 'A' op: 'Input'}" 2672 "node { name: 'B' op: 'Input'}" 2673 "node { name: 'C' op: 'Input'}" 2674 "node { name: 'D' op: 'Input'}" 2675 "node { name: 'E' op: 'Conv2D'" 2676 " attr { key: 'T' value { type: DT_FLOAT } }" 2677 " attr { key: 'data_format' value { s: 'NCHW' } }" 2678 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 2679 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 2680 " attr { key: 'padding' value { s: 'SAME' } }" 2681 " input: ['A', 'B']}" 2682 "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2683 " input: ['C', 'D']}" 2684 "node { name: 'G' op: 'Const' " 2685 " attr { key: 'dtype' value { type: DT_INT32 } }" 2686 " attr { key: 'value' value { " 2687 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 2688 " int_val: 0 } } } }" 2689 "node { name: 'H' op: 'ConcatV2'" 2690 " attr { key: 'T' value { type: DT_FLOAT } }" 2691 " attr { key: 'Tidx' value { type: DT_INT32 } }" 2692 " attr { key: 'N' value { i: 2 } }" 2693 " input: ['E', 'F', 'G']}" 2694 "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2695 " input: ['A', 'H'] }"); 2696 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2697 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 2698 "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);" 2699 
"H(_MklConcatV2);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;" 2700 "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" 2701 "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;" 2702 "E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;" 2703 "G->H:2;H->I:1"); 2704 } 2705 2706 TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) { 2707 InitGraph( 2708 "node { name: 'A' op: 'Input'}" 2709 "node { name: 'B' op: 'Relu'" 2710 " attr { key: 'T' value { type: DT_FLOAT } }" 2711 " input: ['A'] }" 2712 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2713 " input: ['A', 'B'] }"); 2714 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2715 "A(Input);B(_MklRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;" 2716 "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); 2717 } 2718 2719 TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) { 2720 InitGraph( 2721 "node { name: 'A' op: 'Input'}" 2722 "node { name: 'B' op: 'Input'}" 2723 "node { name: 'C' op: 'ReluGrad'" 2724 " attr { key: 'T' value { type: DT_FLOAT } }" 2725 " input: ['A', 'B'] }" 2726 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2727 " input: ['A', 'C'] }"); 2728 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2729 "A(Input);B(Input);C(_MklReluGrad);D(Zeta);DMT/_0(Const);" 2730 "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" 2731 "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); 2732 } 2733 2734 TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) { 2735 InitGraph( 2736 "node { name: 'A' op: 'Input'}" 2737 "node { name: 'B' op: 'Relu'" 2738 " attr { key: 'T' value { type: DT_FLOAT } }" 2739 " input: ['A'] }" 2740 "node { name: 'C' op: 'ReluGrad'" 2741 " attr { key: 'T' value { type: DT_FLOAT } }" 2742 " input: ['A', 'B'] }" 2743 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2744 " input: ['A', 'C'] }"); 2745 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2746 "A(Input);B(_MklRelu);C(_MklReluGrad);D(Zeta);DMT/_0(Const);" 2747 "DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;" 2748 "A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;" 2749 "DMT/_1->C:2"); 2750 } 2751 2752 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) { 2753 InitGraph( 2754 "node { name: 'A' op: 'Input'}" 2755 "node { name: 'B' op: 'AvgPool'" 2756 " attr { key: 'T' value { type: DT_FLOAT } }" 2757 " attr { key: 'data_format' value { s: 'NCHW' } }" 2758 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 2759 " attr { key: 'padding' value { s: 'VALID' } }" 2760 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 2761 " input: ['A'] }" 2762 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2763 " input: ['A', 'B'] }"); 2764 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2765 "A(Input);B(_MklAvgPool);C(Zeta);DMT/_0(Const)|A->B;A->C;" 2766 "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); 2767 } 2768 2769 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) { 2770 InitGraph( 2771 "node { name: 'A' op: 'Int32Input'}" 2772 "node { name: 'B' op: 'Input'}" 2773 "node { name: 'C' op: 'AvgPoolGrad' " 2774 " attr { key: 'T' value { type: DT_FLOAT } }" 2775 " attr { key: 'data_format' value { s: 'NCHW' } }" 2776 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 2777 " attr { key: 'padding' value { s: 'VALID' } }" 2778 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 2779 " input: ['A', 'B'] }" 2780 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2781 " input: ['B', 'C'] 
}"); 2782 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2783 "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);" 2784 "DMT/_1(Const)|A->C;A:control->DMT/_0:control;" 2785 "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" 2786 "DMT/_1->C:3"); 2787 } 2788 2789 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) { 2790 InitGraph( 2791 "node { name: 'A' op: 'Input'}" 2792 "node { name: 'I' op: 'Int32Input'}" 2793 "node { name: 'B' op: 'AvgPool'" 2794 " attr { key: 'T' value { type: DT_FLOAT } }" 2795 " attr { key: 'data_format' value { s: 'NCHW' } }" 2796 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 2797 " attr { key: 'padding' value { s: 'VALID' } }" 2798 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 2799 " input: ['A'] }" 2800 "node { name: 'C' op: 'AvgPoolGrad' " 2801 " attr { key: 'T' value { type: DT_FLOAT } }" 2802 " attr { key: 'data_format' value { s: 'NCHW' } }" 2803 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 2804 " attr { key: 'padding' value { s: 'VALID' } }" 2805 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 2806 " input: ['I', 'B'] }" 2807 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2808 " input: ['A', 'C'] }"); 2809 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2810 "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);" 2811 "DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;" 2812 "B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;" 2813 "I:control->DMT/_1:control"); 2814 } 2815 2816 TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) { 2817 InitGraph( 2818 "node { name: 'A' op: 'Input'}" 2819 "node { name: 'B' op: 'Input'}" 2820 "node { name: 'C' op: 'Input'}" 2821 "node { name: 'D' op: 'Input'}" 2822 "node { name: 'E' op: 'Input'}" 2823 "node { name: 'F' op: 'FusedBatchNormGrad'" 2824 " attr { key: 'T' value { type: DT_FLOAT } }" 2825 " attr { key: 'data_format' value { s: 'NCHW' } }" 2826 " attr { key: 'epsilon' value { f: 0.0001 } }" 2827 " attr { key: 'is_training' value { b: true } }" 2828 " input: ['A', 'B', 'C', 'D', 'E'] }" 2829 "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2830 " input: ['A', 'F'] }"); 2831 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2832 "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 2833 "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);" 2834 "F(_MklFusedBatchNormGrad);G(Zeta)|A->F;A->G;" 2835 "A:control->DMT/_0:control;A:control->DMT/_1:control;" 2836 "A:control->DMT/_2:control;A:control->DMT/_3:control;" 2837 "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" 2838 "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" 2839 "E->F:4;F->G:1"); 2840 } 2841 2842 TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) { 2843 InitGraph( 2844 "node { name: 'A' op: 'Input'}" 2845 "node { name: 'B' op: 'Input'}" 2846 "node { name: 'C' op: 'Input'}" 2847 "node { name: 'D' op: 'Input'}" 2848 "node { name: 'E' op: 'Input'}" 2849 "node { name: 'F' op: 'FusedBatchNorm'" 2850 " attr { key: 'T' value { type: DT_FLOAT } }" 2851 " attr { key: 'data_format' value { s: 'NCHW' } }" 2852 " attr { key: 'epsilon' value { f: 0.0001 } }" 2853 " attr { key: 'is_training' value { b: true } }" 2854 " input: ['A', 'B', 'C', 'D', 'E'] }" 2855 "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2856 " input: ['A', 'F'] }"); 2857 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2858 
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 2859 "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);" 2860 "F(_MklFusedBatchNorm);G(Zeta)|A->F;A->G;" 2861 "A:control->DMT/_0:control;A:control->DMT/_1:control;" 2862 "A:control->DMT/_2:control;A:control->DMT/_3:control;" 2863 "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" 2864 "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" 2865 "E->F:4;F->G:1"); 2866 } 2867 2868 ///////////////////////////////////////////////////////////////////// 2869 // Unit tests related to rewriting node for workspace edges 2870 ///////////////////////////////////////////////////////////////////// 2871 2872 /* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. */ 2873 TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { 2874 InitGraph( 2875 "node { name: 'A' op: 'Input'}" 2876 "node { name: 'B' op: 'LRN'" 2877 " attr { key: 'T' value { type: DT_FLOAT } }" 2878 " attr { key: 'alpha' value { f: 0.001 } }" 2879 " attr { key: 'beta' value { f: 0.75 } }" 2880 " attr { key: 'bias' value { f: 1.0 } }" 2881 " attr { key: 'data_format' value { s: 'NCHW' } }" 2882 " attr { key: 'depth_radius' value { i: 2 } }" 2883 " input: ['A'] }" 2884 "node { name: 'C' op: 'MaxPool'" 2885 " attr { key: 'T' value { type: DT_FLOAT } }" 2886 " attr { key: 'data_format' value { s: 'NCHW' } }" 2887 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 2888 " attr { key: 'padding' value { s: 'VALID' } }" 2889 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 2890 " input: ['B'] }" 2891 "node { name: 'D' op: 'Input'}" 2892 "node { name: 'E' op: 'MaxPoolGrad'" 2893 " attr { key: 'T' value { type: DT_FLOAT } }" 2894 " attr { key: 'data_format' value { s: 'NCHW' } }" 2895 " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" 2896 " attr { key: 'padding' value { s: 'VALID' } }" 2897 " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" 2898 " input: ['B', 'C', 'D'] }" 2899 "node { name: 'F' op: 'Input'}" 2900 "node { name: 'G' op: 'LRNGrad'" 2901 " attr { key: 'T' value { type: DT_FLOAT } }" 2902 " attr { key: 'alpha' value { f: 0.001 } }" 2903 " attr { key: 'beta' value { f: 0.75 } }" 2904 " attr { key: 'bias' value { f: 1.0 } }" 2905 " attr { key: 'data_format' value { s: 'NCHW' } }" 2906 " attr { key: 'depth_radius' value { i: 2 } }" 2907 " input: ['E', 'F', 'B'] }" 2908 "node { name: 'H' op: 'Input'}" 2909 "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2910 " input: ['H', 'G'] }"); 2911 EXPECT_EQ( 2912 DoMklLayoutOptimizationPass(), 2913 "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" 2914 "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" 2915 "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" 2916 "B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;" 2917 "C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;" 2918 "E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I"); 2919 } 2920 2921 /* Test LRN->LRNGrad replacement by workspace nodes. 
*/ 2922 TEST_F(MklLayoutPassTest, LRN_Positive) { 2923 InitGraph( 2924 "node { name: 'A' op: 'Input'}" 2925 "node { name: 'B' op: 'LRN'" 2926 " attr { key: 'T' value { type: DT_FLOAT } }" 2927 " attr { key: 'alpha' value { f: 0.001 } }" 2928 " attr { key: 'beta' value { f: 0.75 } }" 2929 " attr { key: 'bias' value { f: 1.0 } }" 2930 " attr { key: 'data_format' value { s: 'NCHW' } }" 2931 " attr { key: 'depth_radius' value { i: 2 } }" 2932 " input: ['A'] }" 2933 "node { name: 'C' op: 'Input'}" 2934 "node { name: 'D' op: 'Input'}" 2935 "node { name: 'E' op: 'LRNGrad'" 2936 " attr { key: 'T' value { type: DT_FLOAT } }" 2937 " attr { key: 'alpha' value { f: 0.001 } }" 2938 " attr { key: 'beta' value { f: 0.75 } }" 2939 " attr { key: 'bias' value { f: 1.0 } }" 2940 " attr { key: 'data_format' value { s: 'NCHW' } }" 2941 " attr { key: 'depth_radius' value { i: 2 } }" 2942 " input: ['C', 'D', 'B'] }" 2943 "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2944 " input: ['C', 'E'] }"); 2945 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2946 "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 2947 "DMT/_2(Const);E(_MklLRNGrad);F(Zeta)|" 2948 "A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;" 2949 "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;" 2950 "D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1"); 2951 } 2952 2953 /* Test LRN->LRNGrad replacement when only one of them is present. */ 2954 TEST_F(MklLayoutPassTest, LRN_Negative1) { 2955 InitGraph( 2956 "node { name: 'A' op: 'Input'}" 2957 "node { name: 'B' op: 'LRN'" 2958 " attr { key: 'T' value { type: DT_FLOAT } }" 2959 " attr { key: 'alpha' value { f: 0.001 } }" 2960 " attr { key: 'beta' value { f: 0.75 } }" 2961 " attr { key: 'bias' value { f: 1.0 } }" 2962 " attr { key: 'data_format' value { s: 'NCHW' } }" 2963 " attr { key: 'depth_radius' value { i: 2 } }" 2964 " input: ['A'] }" 2965 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2966 " input: ['A', 'B'] }"); 2967 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2968 "A(Input);B(_MklLRN);C(Zeta);DMT/_0(Const)|" 2969 "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); 2970 } 2971 2972 /* Test LRN->LRNGrad replacement when only one of them is present. */ 2973 TEST_F(MklLayoutPassTest, LRN_Negative2) { 2974 InitGraph( 2975 "node { name: 'A' op: 'Input'}" 2976 "node { name: 'B' op: 'Input'}" 2977 "node { name: 'C' op: 'Input'}" 2978 "node { name: 'D' op: 'LRNGrad'" 2979 " attr { key: 'T' value { type: DT_FLOAT } }" 2980 " attr { key: 'alpha' value { f: 0.001 } }" 2981 " attr { key: 'beta' value { f: 0.75 } }" 2982 " attr { key: 'bias' value { f: 1.0 } }" 2983 " attr { key: 'data_format' value { s: 'NCHW' } }" 2984 " attr { key: 'depth_radius' value { i: 2 } }" 2985 " input: ['A', 'B', 'C'] }" 2986 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 2987 " input: ['A', 'D'] }"); 2988 EXPECT_EQ(DoMklLayoutOptimizationPass(), 2989 "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);" 2990 "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|" 2991 "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" 2992 "A:control->DMT/_2:control;A:control->DMT/_3:control;" 2993 "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" 2994 "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); 2995 } 2996 2997 /* Test LRN->LRNGrad negative case, where single LRN feeds 2998 2 LRNGrad nodes at different slots. 
*/ 2999 TEST_F(MklLayoutPassTest, LRN_Negative3) { 3000 InitGraph( 3001 "node { name: 'A' op: 'Input'}" 3002 "node { name: 'B' op: 'LRN'" 3003 " attr { key: 'T' value { type: DT_FLOAT } }" 3004 " attr { key: 'alpha' value { f: 0.001 } }" 3005 " attr { key: 'beta' value { f: 0.75 } }" 3006 " attr { key: 'bias' value { f: 1.0 } }" 3007 " attr { key: 'data_format' value { s: 'NCHW' } }" 3008 " attr { key: 'depth_radius' value { i: 2 } }" 3009 " input: ['A'] }" 3010 "node { name: 'C' op: 'Input'}" 3011 "node { name: 'D' op: 'Input'}" 3012 "node { name: 'E' op: 'LRNGrad'" 3013 " attr { key: 'T' value { type: DT_FLOAT } }" 3014 " attr { key: 'alpha' value { f: 0.001 } }" 3015 " attr { key: 'beta' value { f: 0.75 } }" 3016 " attr { key: 'bias' value { f: 1.0 } }" 3017 " attr { key: 'data_format' value { s: 'NCHW' } }" 3018 " attr { key: 'depth_radius' value { i: 2 } }" 3019 " input: ['C', 'D', 'B'] }" 3020 "node { name: 'F' op: 'LRNGrad'" 3021 " attr { key: 'T' value { type: DT_FLOAT } }" 3022 " attr { key: 'alpha' value { f: 0.001 } }" 3023 " attr { key: 'beta' value { f: 0.75 } }" 3024 " attr { key: 'bias' value { f: 1.0 } }" 3025 " attr { key: 'data_format' value { s: 'NCHW' } }" 3026 " attr { key: 'depth_radius' value { i: 2 } }" 3027 " input: ['C', 'B', 'D'] }" 3028 "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3029 " input: ['E', 'F'] }"); 3030 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3031 "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" 3032 "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);" 3033 "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Zeta)|A->B;" 3034 "A:control->DMT/_0:control;B->E:2;" 3035 "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;" 3036 "C:control->DMT/_1:control;C:control->DMT/_2:control;" 3037 "C:control->DMT/_3:control;C:control->DMT/_4:control;" 3038 "C:control->DMT/_5:control;C:control->DMT/_6:control;" 3039 "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;" 3040 "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1"); 3041 } 3042 3043 /* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. 
*/
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding' value { s: 'VALID' } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'MaxPoolGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding' value { s: 'VALID' } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['C', 'B', 'D'] }"
      "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'E'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Zeta)|"
            "A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;"
            "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
            "D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
}

// Test MaxPool->MaxPoolGrad replacement when only one of them is present.
// In this case, we will rewrite the MaxPool node, but workspace edges will
// not be present.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding' value { s: 'VALID' } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(_MklMaxPool);C(Zeta);DMT/_0(Const)|"
            "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}

// Test MaxPoolGrad rewrite when only MaxPoolGrad (and no matching MaxPool)
// is present. In this case, we will still rewrite MaxPoolGrad, and we will
// generate dummy tensors for the workspace tensor and its Mkl part.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'MaxPoolGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
      " attr { key: 'padding' value { s: 'VALID' } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
      " input: ['A', 'B', 'C'] }"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'D'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
            "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
            "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
            "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
}

// Test MaxPool handling for batch-wise pooling (NCHW)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'VALID' } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}

// Test MaxPool handling for batch-wise pooling (NCHW)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'VALID' } }"
      " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}

// Test MaxPool handling for depth-wise pooling (NCHW)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'MaxPool'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'VALID' } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}

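// As the Negative3 through Negative10 cases show, MaxPool is rewritten to
// _MklMaxPool only when pooling is purely spatial: ksize and strides must be
// 1 in both the batch ('N') and depth ('C') dimensions of the given
// data_format. A minimal sketch of such a check, assuming only NCHW and NHWC
// layouts (the helper name and signature below are illustrative assumptions,
// not the pass's actual API):
//
//   bool IsSpatialOnlyPooling(const std::vector<int32>& ksize,
//                             const std::vector<int32>& strides,
//                             const string& data_format) {
//     const int n = 0;                                // batch is dim 0 in both
//     const int c = (data_format == "NHWC") ? 3 : 1;  // depth dim differs
//     return ksize[n] == 1 && strides[n] == 1 &&
//            ksize[c] == 1 && strides[c] == 1;
//   }
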
// Test MaxPool handling for depth-wise pooling (NCHW) 3174 // No rewrite should take place in such case 3175 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) { 3176 InitGraph( 3177 "node { name: 'A' op: 'Input'}" 3178 "node { name: 'B' op: 'MaxPool'" 3179 " attr { key: 'T' value { type: DT_FLOAT } }" 3180 " attr { key: 'data_format' value { s: 'NCHW' } }" 3181 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 3182 " attr { key: 'padding' value { s: 'VALID' } }" 3183 " attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }" 3184 " input: ['A'] }" 3185 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3186 " input: ['A', 'B'] }"); 3187 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3188 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 3189 } 3190 3191 // Test MaxPool handling for batch-wise pooling (NHWC) 3192 // No rewrite should take place in such case 3193 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) { 3194 InitGraph( 3195 "node { name: 'A' op: 'Input'}" 3196 "node { name: 'B' op: 'MaxPool'" 3197 " attr { key: 'T' value { type: DT_FLOAT } }" 3198 " attr { key: 'data_format' value { s: 'NHWC' } }" 3199 " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }" 3200 " attr { key: 'padding' value { s: 'VALID' } }" 3201 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 3202 " input: ['A'] }" 3203 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3204 " input: ['A', 'B'] }"); 3205 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3206 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 3207 } 3208 3209 // Test MaxPool handling for batch-wise pooling (NHWC) 3210 // No rewrite should take place in such case 3211 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) { 3212 InitGraph( 3213 "node { name: 'A' op: 'Input'}" 3214 "node { name: 'B' op: 'MaxPool'" 3215 " attr { key: 'T' value { type: DT_FLOAT } }" 3216 " attr { key: 'data_format' value { s: 'NHWC' } }" 3217 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 3218 " attr { key: 'padding' value { s: 'VALID' } }" 3219 " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" 3220 " input: ['A'] }" 3221 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3222 " input: ['A', 'B'] }"); 3223 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3224 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 3225 } 3226 3227 // Test MaxPool handling for depth-wise pooling (NHWC) 3228 // No rewrite should take place in such case 3229 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) { 3230 InitGraph( 3231 "node { name: 'A' op: 'Input'}" 3232 "node { name: 'B' op: 'MaxPool'" 3233 " attr { key: 'T' value { type: DT_FLOAT } }" 3234 " attr { key: 'data_format' value { s: 'NHWC' } }" 3235 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }" 3236 " attr { key: 'padding' value { s: 'VALID' } }" 3237 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 3238 " input: ['A'] }" 3239 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3240 " input: ['A', 'B'] }"); 3241 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3242 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 3243 } 3244 3245 // Test MaxPool handling for depth-wise pooling (NHWC) 3246 // No rewrite should take place in such case 3247 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) { 3248 InitGraph( 3249 "node { name: 'A' op: 'Input'}" 3250 "node { name: 'B' op: 'MaxPool'" 3251 " attr { key: 'T' value { type: DT_FLOAT } }" 3252 " 
attr { key: 'data_format' value { s: 'NHWC' } }" 3253 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 3254 " attr { key: 'padding' value { s: 'VALID' } }" 3255 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }" 3256 " input: ['A'] }" 3257 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3258 " input: ['A', 'B'] }"); 3259 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3260 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 3261 } 3262 3263 ///////////////////////////////////////////////////////////////////// 3264 3265 // Single Conv2D Op on GPU device 3266 // No rewrite should happen 3267 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { 3268 InitGraph( 3269 "node { name: 'A' op: 'Input'}" 3270 "node { name: 'B' op: 'Input'}" 3271 "node { name: 'C' op: 'Conv2D'" 3272 " attr { key: 'T' value { type: DT_FLOAT } }" 3273 " attr { key: 'data_format' value { s: 'NCHW' } }" 3274 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 3275 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 3276 " attr { key: 'padding' value { s: 'SAME' } }" 3277 " input: ['A', 'B']}" 3278 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3279 " input: ['B', 'C'] }", 3280 kGPUDevice); 3281 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3282 "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); 3283 } 3284 3285 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { 3286 InitGraph( 3287 "node { name: 'A' op: 'Input'}" 3288 "node { name: 'B' op: 'Input'}" 3289 "node { name: 'C' op: 'Input'}" 3290 "node { name: 'M' op: '_MklInput'}" 3291 "node { name: 'N' op: '_MklInput'}" 3292 "node { name: 'O' op: '_MklInput'}" 3293 "node { name: 'D' op: '_MklConv2DWithBias'" 3294 " attr { key: 'T' value { type: DT_FLOAT } }" 3295 " attr { key: 'data_format' value { s: 'NCHW' } }" 3296 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 3297 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 3298 " attr { key: 'padding' value { s: 'SAME' } }" 3299 " input: ['A', 'B', 'C', 'M', 'N', 'O']}" 3300 "node { name: 'E' op: 'Zeta'" 3301 " attr {key: 'T' value { type: DT_FLOAT } }" 3302 " input: ['D', 'A']}" 3303 "node { name: 'F' op: 'BiasAddGrad'" 3304 " attr { key: 'T' value { type: DT_FLOAT } }" 3305 " attr { key: 'data_format' value { s: 'NCHW' } }" 3306 " input: ['E'] }", 3307 kGPUDevice); 3308 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3309 "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" 3310 "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" 3311 "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;" 3312 "M->D:3;N->D:4;O->D:5"); 3313 } 3314 3315 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { 3316 InitGraph( 3317 "node { name: 'A' op: 'Input'}" 3318 "node { name: 'B' op: 'Int32Input'}" 3319 "node { name: 'C' op: 'Input'}" 3320 "node { name: 'D' op: 'Conv2DBackpropFilter'" 3321 " attr { key: 'T' value { type: DT_FLOAT } }" 3322 " attr { key: 'data_format' value { s: 'NCHW' } }" 3323 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 3324 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 3325 " attr { key: 'padding' value { s: 'SAME' } }" 3326 " input: ['A', 'B', 'C']}" 3327 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3328 " input: ['A', 'D'] }", 3329 kGPUDevice); 3330 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3331 "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" 3332 "A->D;A->E;B->D:1;C->D:2;D->E:1"); 3333 } 3334 3335 
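// The *_DeviceTest cases in this group assign every node to kGPUDevice. The
// layout pass rewrites nodes only when they are placed on a CPU device, so
// these graphs are expected to come out of the pass unchanged (no _Mkl* ops
// and no DMT/_* dummy Mkl tensors). A minimal sketch of such a device gate,
// assuming a plain substring check on the assigned device name (the helper
// name is an illustrative assumption, not the pass's actual API):
//
//   bool AssignedToCpu(const Node* n) {
//     return n->assigned_device_name().find("CPU") != string::npos;
//   }
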
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { 3336 InitGraph( 3337 "node { name: 'A' op: 'Input'}" 3338 "node { name: 'B' op: 'Relu'" 3339 " attr { key: 'T' value { type: DT_FLOAT } }" 3340 " input: ['A'] }" 3341 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3342 " input: ['A', 'B'] }", 3343 kGPUDevice); 3344 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3345 "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); 3346 } 3347 3348 TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { 3349 InitGraph( 3350 "node { name: 'A' op: 'Input'}" 3351 "node { name: 'B' op: 'Input'}" 3352 "node { name: 'C' op: 'ReluGrad'" 3353 " attr { key: 'T' value { type: DT_FLOAT } }" 3354 " input: ['A', 'B'] }" 3355 "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3356 " input: ['A', 'C'] }", 3357 kGPUDevice); 3358 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3359 "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); 3360 } 3361 3362 TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { 3363 InitGraph( 3364 "node { name: 'A' op: 'Input'}" 3365 "node { name: 'B' op: 'MaxPool'" 3366 " attr { key: 'T' value { type: DT_FLOAT } }" 3367 " attr { key: 'data_format' value { s: 'NHWC' } }" 3368 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 3369 " attr { key: 'padding' value { s: 'VALID' } }" 3370 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 3371 " input: ['A'] }" 3372 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3373 " input: ['A', 'B'] }", 3374 kGPUDevice); 3375 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3376 "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); 3377 } 3378 3379 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { 3380 InitGraph( 3381 "node { name: 'A' op: 'Input'}" 3382 "node { name: 'B' op: 'AvgPool'" 3383 " attr { key: 'T' value { type: DT_FLOAT } }" 3384 " attr { key: 'data_format' value { s: 'NHWC' } }" 3385 " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" 3386 " attr { key: 'padding' value { s: 'VALID' } }" 3387 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 3388 " input: ['A'] }" 3389 "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3390 " input: ['A', 'B'] }", 3391 kGPUDevice); 3392 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3393 "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); 3394 } 3395 3396 // Concat Op test: Concat with no Mkl layer feeding it 3397 TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { 3398 InitGraph( 3399 "node { name: 'A' op: 'Const' " 3400 " attr { key: 'dtype' value { type: DT_INT32 } }" 3401 " attr { key: 'value' value { " 3402 " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 3403 " int_val: 0 } } } }" 3404 "node { name: 'B' op: 'InputList'" 3405 " attr { key: 'N' value { i: 2 } }}" 3406 "node { name: 'C' op: 'Input'}" 3407 "node { name: 'D' op: 'Concat'" 3408 " attr { key: 'T' value { type: DT_FLOAT } }" 3409 " attr { key: 'N' value { i: 2 } }" 3410 " input: ['A', 'B:0', 'B:1']}" 3411 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3412 " input: ['C', 'D'] }", 3413 kGPUDevice); 3414 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3415 "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" 3416 "B->D:1;B:1->D:2;C->E;D->E:1"); 3417 } 3418 3419 TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { 3420 InitGraph( 3421 "node { name: 'A' op: 'Const' " 3422 " attr { key: 'dtype' value { type: DT_INT32 } }" 3423 " attr { key: 'value' value { " 3424 
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " 3425 " int_val: 0 } } } }" 3426 "node { name: 'B' op: 'InputList'" 3427 " attr { key: 'N' value { i: 2 } }}" 3428 "node { name: 'C' op: 'Input'}" 3429 "node { name: 'D' op: 'ConcatV2'" 3430 " attr { key: 'T' value { type: DT_FLOAT } }" 3431 " attr { key: 'Tidx' value { type: DT_INT32 } }" 3432 " attr { key: 'N' value { i: 2 } }" 3433 " input: ['B:0', 'B:1', 'A']}" 3434 "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3435 " input: ['C', 'D'] }", 3436 kGPUDevice); 3437 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3438 "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" 3439 "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); 3440 } 3441 3442 TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { 3443 InitGraph( 3444 "node { name: 'A' op: 'Input'}" 3445 "node { name: 'B' op: 'Input'}" 3446 "node { name: 'C' op: 'Input'}" 3447 "node { name: 'D' op: 'Input'}" 3448 "node { name: 'E' op: 'Input'}" 3449 "node { name: 'F' op: 'FusedBatchNorm'" 3450 " attr { key: 'T' value { type: DT_FLOAT } }" 3451 " attr { key: 'data_format' value { s: 'NCHW' } }" 3452 " attr { key: 'epsilon' value { f: 0.0001 } }" 3453 " attr { key: 'is_training' value { b: true } }" 3454 " input: ['A', 'B', 'C', 'D', 'E'] }" 3455 "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" 3456 " input: ['A', 'F'] }", 3457 kGPUDevice); 3458 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3459 "A(Input);B(Input);C(Input);D(Input);E(Input);" 3460 "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" 3461 "E->F:4;F->G:1"); 3462 } 3463 3464 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { 3465 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); 3466 InitGraph( 3467 "node { name: 'A' op: 'Input'}" 3468 "node { name: 'B' op: 'Input'}" 3469 "node { name: 'M' op: '_MklInput'}" 3470 "node { name: 'N' op: '_MklInput'}" 3471 "node { name: 'C' op: '_MklConv2D'" 3472 " attr { key: 'T' value { type: DT_FLOAT } }" 3473 " attr { key: 'data_format' value { s: 'NCHW' } }" 3474 " attr { key: 'use_cudnn_on_gpu' value { b: false } }" 3475 " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" 3476 " attr { key: 'padding' value { s: 'SAME' } }" 3477 " input: ['A', 'B', 'M', 'N']}" 3478 "node { name: 'D' op: 'Input'}" 3479 "node { name: 'E' op: 'BiasAdd'" 3480 " attr { key: 'T' value { type: DT_FLOAT } }" 3481 " attr { key: 'data_format' value { s: 'NCHW' } }" 3482 " input: ['C', 'D'] }" 3483 "node { name: 'Y' op: 'Input'}" 3484 "node { name: 'Z' op: 'Zeta'" 3485 " attr {key: 'T' value { type: DT_FLOAT } }" 3486 " input: ['E', 'Y']}", 3487 kGPUDevice); 3488 EXPECT_EQ(DoMklLayoutOptimizationPass(), 3489 "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" 3490 "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" 3491 "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1"); 3492 } 3493 3494 ///////////////////////////////////////////////////////////////////// 3495 3496 static void BM_MklLayoutRewritePass(int iters, int op_nodes) { 3497 testing::StopTiming(); 3498 string s; 3499 for (int in = 0; in < 10; in++) { 3500 s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in); 3501 } 3502 random::PhiloxRandom philox(301, 17); 3503 random::SimplePhilox rnd(&philox); 3504 for (int op = 0; op < op_nodes; op++) { 3505 s += strings::Printf( 3506 "node { name: 'op%04d' op: 'Zeta' attr { key: 'T' value { " 3507 "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }", 3508 op, rnd.Uniform(10), rnd.Uniform(10)); 3509 } 3510 3511 bool first 
= true; 3512 while (iters > 0) { 3513 Graph* graph = new Graph(OpRegistry::Global()); 3514 InitGraph(s, graph); 3515 int N = graph->num_node_ids(); 3516 if (first) { 3517 testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N)); 3518 first = false; 3519 } 3520 { 3521 testing::StartTiming(); 3522 std::unique_ptr<Graph> ug(graph); 3523 RunMklLayoutRewritePass(&ug); 3524 testing::StopTiming(); 3525 } 3526 iters -= N; // Our benchmark units are individual graph nodes, 3527 // not whole graphs 3528 // delete graph; 3529 } 3530 } 3531 BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); 3532 3533 } // namespace 3534 3535 #endif // INTEL_MKL_ML 3536 3537 } // namespace tensorflow 3538 3539 #endif /* INTEL_MKL */ 3540