Home | History | Annotate | Download | only in graph
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifdef INTEL_MKL
     17 
     18 #include "tensorflow/core/graph/mkl_tfconversion_pass.h"
     19 #include "tensorflow/core/graph/mkl_graph_util.h"
     20 
     21 #include <algorithm>
     22 #include <string>
     23 #include <vector>
     24 
     25 #include "tensorflow/core/framework/op.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/graph/graph.h"
     28 #include "tensorflow/core/graph/graph_constructor.h"
     29 #include "tensorflow/core/graph/testlib.h"
     30 #include "tensorflow/core/kernels/ops_util.h"
     31 #include "tensorflow/core/lib/random/simple_philox.h"
     32 #include "tensorflow/core/lib/strings/str_util.h"
     33 #include "tensorflow/core/lib/strings/stringprintf.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 #include "tensorflow/core/platform/protobuf.h"
     36 #include "tensorflow/core/platform/test.h"
     37 #include "tensorflow/core/platform/test_benchmark.h"
     38 
     39 namespace tensorflow {
     40 namespace {
     41 
     42 class MklToTfConversionPass : public ::testing::Test {
     43  public:
     44   MklToTfConversionPass() : graph_(OpRegistry::Global()) {}
     45 
     46   static void InitGraph(const string& s, Graph* graph) {
     47     GraphDef graph_def;
     48 
     49     auto parser = protobuf::TextFormat::Parser();
     50     CHECK(parser.MergeFromString(s, &graph_def)) << s;
     51     GraphConstructorOptions opts;
     52     TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
     53   }
     54 
     55   void InitGraph(const string& s) {
     56     InitGraph(s, &graph_);
     57     original_ = CanonicalGraphString(&graph_);
     58   }
     59 
     60   static bool IncludeNode(const Node* n) { return n->IsOp(); }
     61 
     62   static string EdgeId(const Node* n, int index) {
     63     if (index == 0) {
     64       return n->name();
     65     } else if (index == Graph::kControlSlot) {
     66       return strings::StrCat(n->name(), ":control");
     67     } else {
     68       return strings::StrCat(n->name(), ":", index);
     69     }
     70   }
     71 
     72   string CanonicalGraphString(Graph* g) {
     73     std::vector<string> nodes;
     74     std::vector<string> edges;
     75     for (const Node* n : g->nodes()) {
     76       if (IncludeNode(n)) {
     77         nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
     78       }
     79     }
     80     for (const Edge* e : g->edges()) {
     81       if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
     82         edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
     83                                         EdgeId(e->dst(), e->dst_input())));
     84       }
     85     }
     86     // Canonicalize
     87     std::sort(nodes.begin(), nodes.end());
     88     std::sort(edges.begin(), edges.end());
     89     return strings::StrCat(str_util::Join(nodes, ";"), "|",
     90                            str_util::Join(edges, ";"));
     91   }
     92 
     93   string DoRunMklToTfConversionPass() {
     94     string before = CanonicalGraphString(&graph_);
     95     LOG(ERROR) << "Before MklToTf conversion pass: " << before;
     96 
     97     std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
     98     InsertMklToTfConversionNodes(ug);
     99 
    100     string result = CanonicalGraphString(&graph_);
    101     LOG(ERROR) << "After MklToTf conversion pass:  " << result;
    102     return result;
    103   }
    104 
    105   const string& OriginalGraph() const { return original_; }
    106 
    107   Graph graph_;
    108   string original_;
    109 };
    110 
    111 REGISTER_OP("Input").Output("o: float").SetIsStateful();
    112 REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
    113 REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
    114 
    115 TEST_F(MklToTfConversionPass, Basic) {
    116   InitGraph(
    117       "node { name: 'A' op: 'Input'}"
    118       "node { name: 'B' op: 'Input'}"
    119       "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
    120       " input: ['A', 'B'] }"
    121       "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
    122       " input: ['A', 'B'] }");
    123   EXPECT_EQ(DoRunMklToTfConversionPass(),
    124             "A(Input);B(Input);C(Mul);D(Mul)|"
    125             "A->C;A->D;B->C:1;B->D:1");
    126 }
    127 
    128 // MklConv2D followed by Non-Mkl layer
    129 // C=MklConv2D(A,M,B,N); E=Sub(C,D) (for interleaved ordering)
    130 // C=MklConv2D(A,B,M,N); E=Sub(C,D) (for contiguous ordering)
    131 TEST_F(MklToTfConversionPass, Positive) {
    132   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
    133     InitGraph(
    134         "node { name: 'A' op: 'Input'}"
    135         "node { name: 'M' op: '_MklInput'}"
    136         "node { name: 'B' op: 'Input'}"
    137         "node { name: 'N' op: '_MklInput'}"
    138         "node { name: 'C' op: '_MklConv2D'"
    139         " attr { key: 'T'                value { type: DT_FLOAT } }"
    140         " attr { key: 'data_format'      value { s: 'NCHW' } }"
    141         " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
    142         " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } "
    143         "}"
    144         " attr { key: 'padding'          value { s: 'SAME' } }"
    145         " input: ['A', 'M', 'B', 'N']}"
    146         "node { name: 'D' op: 'Input'}"
    147         "node { name: 'E' op: 'Sub'"
    148         " attr {key: 'T'                 value { type: DT_FLOAT } }"
    149         " input: ['C', 'D']}");
    150     EXPECT_EQ(DoRunMklToTfConversionPass(),
    151               "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
    152               "Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
    153               "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
    154   } else {
    155     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
    156     InitGraph(
    157         "node { name: 'A' op: 'Input'}"
    158         "node { name: 'B' op: 'Input'}"
    159         "node { name: 'M' op: '_MklInput'}"
    160         "node { name: 'N' op: '_MklInput'}"
    161         "node { name: 'C' op: '_MklConv2D'"
    162         " attr { key: 'T'                value { type: DT_FLOAT } }"
    163         " attr { key: 'data_format'      value { s: 'NCHW' } }"
    164         " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
    165         " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } "
    166         "}"
    167         " attr { key: 'padding'          value { s: 'SAME' } }"
    168         " input: ['A', 'B', 'M', 'N']}"
    169         "node { name: 'D' op: 'Input'}"
    170         "node { name: 'E' op: 'Sub'"
    171         " attr {key: 'T'                 value { type: DT_FLOAT } }"
    172         " input: ['C', 'D']}");
    173     EXPECT_EQ(DoRunMklToTfConversionPass(),
    174               "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
    175               "Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
    176               "C:2->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
    177   }
    178 }
    179 
    180 // MklConv2D followed by MklToTf op followed by Non-Mkl layer.
    181 // C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for interleaved)
    182 // C=MklConv2D(A,B,M,N); D=MklToTf(C:0, C:2) F=Sub(D,E) (for contiguous)
    183 // MklToTf node should not be inserted again.
    184 TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
    185   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
    186     InitGraph(
    187         "node { name: 'A' op: 'Input'}"
    188         "node { name: 'M' op: '_MklInput'}"
    189         "node { name: 'B' op: 'Input'}"
    190         "node { name: 'N' op: '_MklInput'}"
    191         "node { name: 'C' op: '_MklConv2D'"
    192         " attr { key: 'T'                value { type: DT_FLOAT } }"
    193         " attr { key: 'data_format'      value { s: 'NCHW' } }"
    194         " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
    195         " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } "
    196         "}"
    197         " attr { key: 'padding'          value { s: 'SAME' } }"
    198         " input: ['A', 'M', 'B', 'N']}"
    199         "node { name: 'D' op: '_MklToTf'"
    200         " attr { key: 'T'                value { type: DT_FLOAT } }"
    201         " attr { key: 'data_format'      value { s: 'NCHW' } }"
    202         " input: ['C:0', 'C:1']}"
    203         "node { name: 'E' op: 'Input'}"
    204         "node { name: 'F' op: 'Sub'"
    205         " attr {key: 'T'                 value { type: DT_FLOAT } }"
    206         " input: ['D', 'E']}");
    207     EXPECT_EQ(DoRunMklToTfConversionPass(),
    208               "A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
    209               "F(Sub);M(_MklInput);N(_MklInput)|"
    210               "A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
    211   } else {
    212     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
    213     InitGraph(
    214         "node { name: 'A' op: 'Input'}"
    215         "node { name: 'B' op: 'Input'}"
    216         "node { name: 'M' op: '_MklInput'}"
    217         "node { name: 'N' op: '_MklInput'}"
    218         "node { name: 'C' op: '_MklConv2D'"
    219         " attr { key: 'T'                value { type: DT_FLOAT } }"
    220         " attr { key: 'data_format'      value { s: 'NCHW' } }"
    221         " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
    222         " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } "
    223         "}"
    224         " attr { key: 'padding'          value { s: 'SAME' } }"
    225         " input: ['A', 'B', 'M', 'N']}"
    226         "node { name: 'D' op: '_MklToTf'"
    227         " attr { key: 'T'                value { type: DT_FLOAT } }"
    228         " attr { key: 'data_format'      value { s: 'NCHW' } }"
    229         " input: ['C:0', 'C:2']}"
    230         "node { name: 'E' op: 'Input'}"
    231         "node { name: 'F' op: 'Sub'"
    232         " attr {key: 'T'                 value { type: DT_FLOAT } }"
    233         " input: ['D', 'E']}");
    234     EXPECT_EQ(DoRunMklToTfConversionPass(),
    235               "A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
    236               "F(Sub);M(_MklInput);N(_MklInput)|"
    237               "A->C;B->C:1;C->D;C:2->D:1;D->F;E->F:1;M->C:2;N->C:3");
    238   }
    239 }
    240 
    241 // C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
    242 // There is no Mkl layer so no conversion op should be inserted.
    243 TEST_F(MklToTfConversionPass, Negative_NoMklLayer) {
    244   InitGraph(
    245       "node { name: 'A' op: 'Input'}"
    246       "node { name: 'B' op: 'Input'}"
    247       "node { name: 'C' op: 'Conv2D'"
    248       " attr { key: 'T'                value { type: DT_FLOAT } }"
    249       " attr { key: 'data_format'      value { s: 'NCHW' } }"
    250       " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
    251       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
    252       " attr { key: 'padding'          value { s: 'SAME' } }"
    253       " input: ['A', 'B']}"
    254       "node { name: 'D' op: 'Input'}"
    255       "node { name: 'E' op: 'BiasAdd'"
    256       " attr { key: 'T'                value { type: DT_FLOAT } }"
    257       " attr { key: 'data_format'      value { s: 'NCHW' } }"
    258       " input: ['C', 'D'] }"
    259       "node { name: 'Y' op: 'Input'}"
    260       "node { name: 'Z' op: 'Sub'"
    261       " attr {key: 'T'                 value { type: DT_FLOAT } }"
    262       " input: ['E', 'Y']}");
    263   EXPECT_EQ(DoRunMklToTfConversionPass(),
    264             "A(Input);B(Input);C(Conv2D);D(Input);E(BiasAdd);Y(Input);Z(Sub)|"
    265             "A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1");
    266 }
    267 
    268 static void BM_RunMklToTfConversionPass(int iters, int op_nodes) {
    269   testing::StopTiming();
    270   string s;
    271   for (int in = 0; in < 10; in++) {
    272     s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
    273   }
    274   random::PhiloxRandom philox(301, 17);
    275   random::SimplePhilox rnd(&philox);
    276   for (int op = 0; op < op_nodes; op++) {
    277     s += strings::Printf(
    278         "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
    279         "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
    280         op, rnd.Uniform(10), rnd.Uniform(10));
    281   }
    282 
    283   bool first = true;
    284   while (iters > 0) {
    285     Graph* graph = new Graph(OpRegistry::Global());
    286     MklToTfConversionPass::InitGraph(s, graph);
    287     int N = graph->num_node_ids();
    288     if (first) {
    289       testing::SetLabel(strings::StrCat("Per graph node.  Nodes: ", N));
    290       first = false;
    291     }
    292     {
    293       testing::StartTiming();
    294       std::unique_ptr<Graph> ug(graph);
    295       InsertMklToTfConversionNodes(&ug);
    296       testing::StopTiming();
    297     }
    298     iters -= N;  // Our benchmark units are individual graph nodes,
    299                  // not whole graphs
    300     // delete graph;
    301   }
    302 }
    303 BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000);
    304 
    305 }  // namespace
    306 }  // namespace tensorflow
    307 
    308 #endif /* INTEL_MKL */
    309