1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/ops/const_op.h" 17 #include "tensorflow/cc/ops/image_ops.h" 18 #include "tensorflow/cc/ops/nn_ops.h" 19 #include "tensorflow/cc/ops/sendrecv_ops.h" 20 #include "tensorflow/cc/ops/standard_ops.h" 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/lib/core/status_test_util.h" 23 #include "tensorflow/core/platform/test.h" 24 #include "tensorflow/core/platform/test_benchmark.h" 25 #include "tensorflow/core/public/session.h" 26 #include "tensorflow/tools/graph_transforms/transform_utils.h" 27 28 namespace tensorflow { 29 namespace graph_transforms { 30 31 class SortByExecutionOrderTest : public ::testing::Test { 32 protected: 33 void GetOrder(const GraphDef& graph_def, std::map<string, int>* order) { 34 for (int i = 0; i < graph_def.node_size(); ++i) { 35 const NodeDef& node = graph_def.node(i); 36 (*order)[node.name()] = i; 37 } 38 } 39 40 void TestSimpleAdd() { 41 GraphDef graph_def; 42 NodeDef* add_node = graph_def.add_node(); 43 add_node->set_name("add_node"); 44 add_node->set_op("Add"); 45 add_node->add_input("a_node"); 46 add_node->add_input("b_node"); 47 48 NodeDef* b_node = graph_def.add_node(); 49 b_node->set_name("b_node"); 50 b_node->set_op("Const"); 51 52 NodeDef* a_node = graph_def.add_node(); 53 a_node->set_name("a_node"); 54 a_node->set_op("Const"); 55 56 GraphDef result; 57 TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result)); 58 59 std::map<string, int> order; 60 GetOrder(result, &order); 61 EXPECT_EQ(2, order["add_node"]); 62 EXPECT_GT(2, order["a_node"]); 63 EXPECT_GT(2, order["b_node"]); 64 } 65 66 void TestSimpleLinear() { 67 GraphDef graph_def; 68 69 NodeDef* negative_node = graph_def.add_node(); 70 negative_node->set_name("negative_node"); 71 negative_node->set_op("Negative"); 72 negative_node->add_input("sqrt_node"); 73 74 NodeDef* relu_node = graph_def.add_node(); 75 relu_node->set_name("relu_node"); 76 relu_node->set_op("Relu"); 77 relu_node->add_input("const_node"); 78 79 NodeDef* sqrt_node = graph_def.add_node(); 80 sqrt_node->set_name("sqrt_node"); 81 sqrt_node->set_op("Sqrt"); 82 sqrt_node->add_input("relu_node"); 83 84 NodeDef* const_node = graph_def.add_node(); 85 const_node->set_name("const_node"); 86 const_node->set_op("Const"); 87 88 GraphDef result; 89 TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result)); 90 91 std::map<string, int> order; 92 GetOrder(result, &order); 93 EXPECT_EQ(3, order["negative_node"]); 94 EXPECT_EQ(2, order["sqrt_node"]); 95 EXPECT_EQ(1, order["relu_node"]); 96 EXPECT_EQ(0, order["const_node"]); 97 } 98 99 void TestSimpleTree() { 100 GraphDef graph_def; 101 102 NodeDef* add_node1 = graph_def.add_node(); 103 add_node1->set_name("add_node1"); 104 add_node1->set_op("Add"); 105 add_node1->add_input("add_node2"); 106 add_node1->add_input("add_node3"); 107 108 NodeDef* add_node2 = graph_def.add_node(); 109 add_node2->set_name("add_node2"); 110 add_node2->set_op("Add"); 111 add_node2->add_input("const_node1"); 112 add_node2->add_input("const_node2"); 113 114 NodeDef* add_node3 = graph_def.add_node(); 115 add_node3->set_name("add_node3"); 116 add_node3->set_op("Add"); 117 add_node3->add_input("const_node3"); 118 add_node3->add_input("const_node4"); 119 120 NodeDef* const_node1 = graph_def.add_node(); 121 const_node1->set_name("const_node1"); 122 const_node1->set_op("Const"); 123 124 NodeDef* const_node2 = graph_def.add_node(); 125 const_node2->set_name("const_node2"); 126 const_node2->set_op("Const"); 127 128 NodeDef* const_node3 = graph_def.add_node(); 129 const_node3->set_name("const_node3"); 130 const_node3->set_op("Const"); 131 132 NodeDef* const_node4 = graph_def.add_node(); 133 const_node4->set_name("const_node4"); 134 const_node4->set_op("Const"); 135 136 GraphDef result; 137 TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result)); 138 139 std::map<string, int> order; 140 GetOrder(result, &order); 141 EXPECT_EQ(6, order["add_node1"]); 142 EXPECT_GT(6, order["add_node2"]); 143 EXPECT_GT(6, order["add_node3"]); 144 EXPECT_GT(5, order["const_node1"]); 145 EXPECT_GT(5, order["const_node2"]); 146 EXPECT_GT(5, order["const_node3"]); 147 EXPECT_GT(5, order["const_node4"]); 148 } 149 150 void TestCommonAncestor() { 151 GraphDef graph_def; 152 153 NodeDef* add_node1 = graph_def.add_node(); 154 add_node1->set_name("add_node1"); 155 add_node1->set_op("Add"); 156 add_node1->add_input("add_node2"); 157 add_node1->add_input("add_node3"); 158 159 NodeDef* add_node2 = graph_def.add_node(); 160 add_node2->set_name("add_node2"); 161 add_node2->set_op("Add"); 162 add_node2->add_input("const_node1"); 163 add_node2->add_input("const_node2"); 164 165 NodeDef* add_node3 = graph_def.add_node(); 166 add_node3->set_name("add_node3"); 167 add_node3->set_op("Add"); 168 add_node3->add_input("const_node1"); 169 add_node3->add_input("const_node3"); 170 171 NodeDef* const_node1 = graph_def.add_node(); 172 const_node1->set_name("const_node1"); 173 const_node1->set_op("Const"); 174 175 NodeDef* const_node2 = graph_def.add_node(); 176 const_node2->set_name("const_node2"); 177 const_node2->set_op("Const"); 178 179 NodeDef* const_node3 = graph_def.add_node(); 180 const_node3->set_name("const_node3"); 181 const_node3->set_op("Const"); 182 183 GraphDef result; 184 TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result)); 185 186 std::map<string, int> order; 187 GetOrder(result, &order); 188 EXPECT_EQ(5, order["add_node1"]); 189 EXPECT_GT(5, order["add_node2"]); 190 EXPECT_GT(5, order["add_node3"]); 191 EXPECT_GT(4, order["const_node2"]); 192 EXPECT_GT(4, order["const_node3"]); 193 EXPECT_GT(3, order["const_node1"]); 194 } 195 }; 196 197 TEST_F(SortByExecutionOrderTest, TestSimpleAdd) { TestSimpleAdd(); } 198 199 TEST_F(SortByExecutionOrderTest, TestSimpleLinear) { TestSimpleLinear(); } 200 201 TEST_F(SortByExecutionOrderTest, TestSimpleTree) { TestSimpleTree(); } 202 203 TEST_F(SortByExecutionOrderTest, TestCommonAncestor) { TestCommonAncestor(); } 204 205 } // namespace graph_transforms 206 } // namespace tensorflow 207