Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/cc/ops/const_op.h"
     17 #include "tensorflow/cc/ops/image_ops.h"
     18 #include "tensorflow/cc/ops/nn_ops.h"
     19 #include "tensorflow/cc/ops/sendrecv_ops.h"
     20 #include "tensorflow/cc/ops/standard_ops.h"
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/lib/core/status_test_util.h"
     23 #include "tensorflow/core/platform/test.h"
     24 #include "tensorflow/core/platform/test_benchmark.h"
     25 #include "tensorflow/core/public/session.h"
     26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     27 
     28 namespace tensorflow {
     29 namespace graph_transforms {
     30 
     31 class SortByExecutionOrderTest : public ::testing::Test {
     32  protected:
     33   void GetOrder(const GraphDef& graph_def, std::map<string, int>* order) {
     34     for (int i = 0; i < graph_def.node_size(); ++i) {
     35       const NodeDef& node = graph_def.node(i);
     36       (*order)[node.name()] = i;
     37     }
     38   }
     39 
     40   void TestSimpleAdd() {
     41     GraphDef graph_def;
     42     NodeDef* add_node = graph_def.add_node();
     43     add_node->set_name("add_node");
     44     add_node->set_op("Add");
     45     add_node->add_input("a_node");
     46     add_node->add_input("b_node");
     47 
     48     NodeDef* b_node = graph_def.add_node();
     49     b_node->set_name("b_node");
     50     b_node->set_op("Const");
     51 
     52     NodeDef* a_node = graph_def.add_node();
     53     a_node->set_name("a_node");
     54     a_node->set_op("Const");
     55 
     56     GraphDef result;
     57     TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
     58 
     59     std::map<string, int> order;
     60     GetOrder(result, &order);
     61     EXPECT_EQ(2, order["add_node"]);
     62     EXPECT_GT(2, order["a_node"]);
     63     EXPECT_GT(2, order["b_node"]);
     64   }
     65 
     66   void TestSimpleLinear() {
     67     GraphDef graph_def;
     68 
     69     NodeDef* negative_node = graph_def.add_node();
     70     negative_node->set_name("negative_node");
     71     negative_node->set_op("Negative");
     72     negative_node->add_input("sqrt_node");
     73 
     74     NodeDef* relu_node = graph_def.add_node();
     75     relu_node->set_name("relu_node");
     76     relu_node->set_op("Relu");
     77     relu_node->add_input("const_node");
     78 
     79     NodeDef* sqrt_node = graph_def.add_node();
     80     sqrt_node->set_name("sqrt_node");
     81     sqrt_node->set_op("Sqrt");
     82     sqrt_node->add_input("relu_node");
     83 
     84     NodeDef* const_node = graph_def.add_node();
     85     const_node->set_name("const_node");
     86     const_node->set_op("Const");
     87 
     88     GraphDef result;
     89     TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
     90 
     91     std::map<string, int> order;
     92     GetOrder(result, &order);
     93     EXPECT_EQ(3, order["negative_node"]);
     94     EXPECT_EQ(2, order["sqrt_node"]);
     95     EXPECT_EQ(1, order["relu_node"]);
     96     EXPECT_EQ(0, order["const_node"]);
     97   }
     98 
     99   void TestSimpleTree() {
    100     GraphDef graph_def;
    101 
    102     NodeDef* add_node1 = graph_def.add_node();
    103     add_node1->set_name("add_node1");
    104     add_node1->set_op("Add");
    105     add_node1->add_input("add_node2");
    106     add_node1->add_input("add_node3");
    107 
    108     NodeDef* add_node2 = graph_def.add_node();
    109     add_node2->set_name("add_node2");
    110     add_node2->set_op("Add");
    111     add_node2->add_input("const_node1");
    112     add_node2->add_input("const_node2");
    113 
    114     NodeDef* add_node3 = graph_def.add_node();
    115     add_node3->set_name("add_node3");
    116     add_node3->set_op("Add");
    117     add_node3->add_input("const_node3");
    118     add_node3->add_input("const_node4");
    119 
    120     NodeDef* const_node1 = graph_def.add_node();
    121     const_node1->set_name("const_node1");
    122     const_node1->set_op("Const");
    123 
    124     NodeDef* const_node2 = graph_def.add_node();
    125     const_node2->set_name("const_node2");
    126     const_node2->set_op("Const");
    127 
    128     NodeDef* const_node3 = graph_def.add_node();
    129     const_node3->set_name("const_node3");
    130     const_node3->set_op("Const");
    131 
    132     NodeDef* const_node4 = graph_def.add_node();
    133     const_node4->set_name("const_node4");
    134     const_node4->set_op("Const");
    135 
    136     GraphDef result;
    137     TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
    138 
    139     std::map<string, int> order;
    140     GetOrder(result, &order);
    141     EXPECT_EQ(6, order["add_node1"]);
    142     EXPECT_GT(6, order["add_node2"]);
    143     EXPECT_GT(6, order["add_node3"]);
    144     EXPECT_GT(5, order["const_node1"]);
    145     EXPECT_GT(5, order["const_node2"]);
    146     EXPECT_GT(5, order["const_node3"]);
    147     EXPECT_GT(5, order["const_node4"]);
    148   }
    149 
    150   void TestCommonAncestor() {
    151     GraphDef graph_def;
    152 
    153     NodeDef* add_node1 = graph_def.add_node();
    154     add_node1->set_name("add_node1");
    155     add_node1->set_op("Add");
    156     add_node1->add_input("add_node2");
    157     add_node1->add_input("add_node3");
    158 
    159     NodeDef* add_node2 = graph_def.add_node();
    160     add_node2->set_name("add_node2");
    161     add_node2->set_op("Add");
    162     add_node2->add_input("const_node1");
    163     add_node2->add_input("const_node2");
    164 
    165     NodeDef* add_node3 = graph_def.add_node();
    166     add_node3->set_name("add_node3");
    167     add_node3->set_op("Add");
    168     add_node3->add_input("const_node1");
    169     add_node3->add_input("const_node3");
    170 
    171     NodeDef* const_node1 = graph_def.add_node();
    172     const_node1->set_name("const_node1");
    173     const_node1->set_op("Const");
    174 
    175     NodeDef* const_node2 = graph_def.add_node();
    176     const_node2->set_name("const_node2");
    177     const_node2->set_op("Const");
    178 
    179     NodeDef* const_node3 = graph_def.add_node();
    180     const_node3->set_name("const_node3");
    181     const_node3->set_op("Const");
    182 
    183     GraphDef result;
    184     TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
    185 
    186     std::map<string, int> order;
    187     GetOrder(result, &order);
    188     EXPECT_EQ(5, order["add_node1"]);
    189     EXPECT_GT(5, order["add_node2"]);
    190     EXPECT_GT(5, order["add_node3"]);
    191     EXPECT_GT(4, order["const_node2"]);
    192     EXPECT_GT(4, order["const_node3"]);
    193     EXPECT_GT(3, order["const_node1"]);
    194   }
    195 };
    196 
    197 TEST_F(SortByExecutionOrderTest, TestSimpleAdd) { TestSimpleAdd(); }
    198 
    199 TEST_F(SortByExecutionOrderTest, TestSimpleLinear) { TestSimpleLinear(); }
    200 
    201 TEST_F(SortByExecutionOrderTest, TestSimpleTree) { TestSimpleTree(); }
    202 
    203 TEST_F(SortByExecutionOrderTest, TestCommonAncestor) { TestCommonAncestor(); }
    204 
    205 }  // namespace graph_transforms
    206 }  // namespace tensorflow
    207