1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/ops/const_op.h" 17 #include "tensorflow/cc/ops/image_ops.h" 18 #include "tensorflow/cc/ops/nn_ops.h" 19 #include "tensorflow/cc/ops/sendrecv_ops.h" 20 #include "tensorflow/cc/ops/standard_ops.h" 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/lib/core/status_test_util.h" 23 #include "tensorflow/core/platform/test.h" 24 #include "tensorflow/core/platform/test_benchmark.h" 25 #include "tensorflow/core/public/session.h" 26 #include "tensorflow/tools/graph_transforms/transform_utils.h" 27 28 namespace tensorflow { 29 namespace graph_transforms { 30 31 // Declare here, so we don't need a public header. 32 Status RemoveNodes(const GraphDef& input_graph_def, 33 const TransformFuncContext& context, 34 GraphDef* output_graph_def); 35 36 class RemoveNodesTest : public ::testing::Test { 37 protected: 38 void TestRemoveNodes() { 39 GraphDef graph_def; 40 41 NodeDef* add_node1 = graph_def.add_node(); 42 add_node1->set_name("add_node1"); 43 add_node1->set_op("Add"); 44 add_node1->add_input("add_node2"); 45 add_node1->add_input("add_node3"); 46 47 NodeDef* add_node2 = graph_def.add_node(); 48 add_node2->set_name("add_node2"); 49 add_node2->set_op("Add"); 50 add_node2->add_input("identity_node1"); 51 add_node2->add_input("identity_node2"); 52 53 NodeDef* add_node3 = graph_def.add_node(); 54 add_node3->set_name("add_node3"); 55 add_node3->set_op("Add"); 56 add_node3->add_input("identity_node1"); 57 add_node3->add_input("const_node3"); 58 59 NodeDef* identity_node1 = graph_def.add_node(); 60 identity_node1->set_name("identity_node1"); 61 identity_node1->set_op("Identity"); 62 identity_node1->add_input("const_node1"); 63 64 NodeDef* identity_node2 = graph_def.add_node(); 65 identity_node2->set_name("identity_node2"); 66 identity_node2->set_op("Identity"); 67 identity_node2->add_input("const_node2"); 68 69 NodeDef* identity_node3 = graph_def.add_node(); 70 identity_node3->set_name("identity_node3"); 71 identity_node3->set_op("Identity"); 72 identity_node3->add_input("const_node3"); 73 74 NodeDef* const_node1 = graph_def.add_node(); 75 const_node1->set_name("const_node1"); 76 const_node1->set_op("Const"); 77 78 NodeDef* const_node2 = graph_def.add_node(); 79 const_node2->set_name("const_node2"); 80 const_node2->set_op("Const"); 81 82 NodeDef* const_node3 = graph_def.add_node(); 83 const_node3->set_name("const_node3"); 84 const_node3->set_op("Const"); 85 86 NodeDef* add_node4 = graph_def.add_node(); 87 add_node4->set_name("add_node4"); 88 add_node4->set_op("Add"); 89 add_node4->add_input("add_node2"); 90 add_node4->add_input("add_node3"); 91 92 GraphDef result; 93 TransformFuncContext context; 94 context.input_names = {}; 95 context.output_names = {"add_node1"}; 96 context.params.insert( 97 std::pair<string, std::vector<string>>({"op", {string("Identity")}})); 98 TF_ASSERT_OK(RemoveNodes(graph_def, context, &result)); 99 100 std::map<string, const NodeDef*> node_lookup; 101 MapNamesToNodes(result, &node_lookup); 102 EXPECT_EQ(1, node_lookup.count("add_node1")); 103 EXPECT_EQ("add_node2", node_lookup.at("add_node1")->input(0)); 104 EXPECT_EQ("add_node3", node_lookup.at("add_node1")->input(1)); 105 EXPECT_EQ(1, node_lookup.count("add_node2")); 106 EXPECT_EQ("const_node1", node_lookup.at("add_node2")->input(0)); 107 EXPECT_EQ("const_node2", node_lookup.at("add_node2")->input(1)); 108 EXPECT_EQ(1, node_lookup.count("add_node3")); 109 EXPECT_EQ("const_node1", node_lookup.at("add_node3")->input(0)); 110 EXPECT_EQ("const_node3", node_lookup.at("add_node3")->input(1)); 111 EXPECT_EQ(1, node_lookup.count("add_node4")); 112 EXPECT_EQ("add_node2", node_lookup.at("add_node4")->input(0)); 113 EXPECT_EQ("add_node3", node_lookup.at("add_node4")->input(1)); 114 EXPECT_EQ(0, node_lookup.count("identity_node1")); 115 EXPECT_EQ(0, node_lookup.count("identity_node2")); 116 EXPECT_EQ(0, node_lookup.count("identity_node3")); 117 EXPECT_EQ(1, node_lookup.count("const_node1")); 118 EXPECT_EQ("Const", node_lookup.at("const_node1")->op()); 119 EXPECT_EQ(1, node_lookup.count("const_node2")); 120 EXPECT_EQ("Const", node_lookup.at("const_node2")->op()); 121 EXPECT_EQ(1, node_lookup.count("const_node3")); 122 EXPECT_EQ("Const", node_lookup.at("const_node3")->op()); 123 } 124 125 void TestRemoveOutputNodes() { 126 GraphDef graph_def; 127 128 NodeDef* const_node1 = graph_def.add_node(); 129 const_node1->set_name("const_node1"); 130 const_node1->set_op("Const"); 131 132 NodeDef* const_node2 = graph_def.add_node(); 133 const_node2->set_name("const_node2"); 134 const_node2->set_op("Const"); 135 136 NodeDef* add_node = graph_def.add_node(); 137 add_node->set_name("add_node"); 138 add_node->set_op("Add"); 139 add_node->add_input("const_node1"); 140 add_node->add_input("const_node2"); 141 142 NodeDef* identity_node = graph_def.add_node(); 143 identity_node->set_name("identity_node"); 144 identity_node->set_op("Identity"); 145 identity_node->add_input("add_node"); 146 147 GraphDef result; 148 TransformFuncContext context; 149 context.input_names = {}; 150 context.output_names = {"identity_node"}; 151 context.params.insert( 152 std::pair<string, std::vector<string>>({"op", {string("Identity")}})); 153 TF_ASSERT_OK(RemoveNodes(graph_def, context, &result)); 154 155 std::map<string, const NodeDef*> node_lookup; 156 MapNamesToNodes(result, &node_lookup); 157 EXPECT_EQ(1, node_lookup.count("add_node")); 158 EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0)); 159 EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1)); 160 EXPECT_EQ(1, node_lookup.count("identity_node")); 161 EXPECT_EQ("add_node", node_lookup.at("identity_node")->input(0)); 162 } 163 164 void TestRemoveChainedNodes() { 165 GraphDef graph_def; 166 167 NodeDef* const_node1 = graph_def.add_node(); 168 const_node1->set_name("const_node1"); 169 const_node1->set_op("Const"); 170 171 NodeDef* identity_node1 = graph_def.add_node(); 172 identity_node1->set_name("identity_node1"); 173 identity_node1->set_op("Identity"); 174 identity_node1->add_input("const_node1"); 175 176 NodeDef* identity_node2 = graph_def.add_node(); 177 identity_node2->set_name("identity_node2"); 178 identity_node2->set_op("Identity"); 179 identity_node2->add_input("identity_node1"); 180 181 NodeDef* identity_node3 = graph_def.add_node(); 182 identity_node3->set_name("identity_node3"); 183 identity_node3->set_op("Identity"); 184 identity_node3->add_input("identity_node2"); 185 186 NodeDef* const_node2 = graph_def.add_node(); 187 const_node2->set_name("const_node2"); 188 const_node2->set_op("Const"); 189 190 NodeDef* add_node = graph_def.add_node(); 191 add_node->set_name("add_node"); 192 add_node->set_op("Add"); 193 add_node->add_input("identity_node3"); 194 add_node->add_input("const_node2"); 195 196 GraphDef result; 197 TransformFuncContext context; 198 context.input_names = {}; 199 context.output_names = {"identity_node"}; 200 context.params.insert( 201 std::pair<string, std::vector<string>>({"op", {string("Identity")}})); 202 TF_ASSERT_OK(RemoveNodes(graph_def, context, &result)); 203 204 std::map<string, const NodeDef*> node_lookup; 205 MapNamesToNodes(result, &node_lookup); 206 EXPECT_EQ(1, node_lookup.count("add_node")); 207 EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0)); 208 EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1)); 209 EXPECT_EQ(0, node_lookup.count("identity_node1")); 210 EXPECT_EQ(0, node_lookup.count("identity_node2")); 211 EXPECT_EQ(0, node_lookup.count("identity_node3")); 212 } 213 214 void TestRemoveMultipleInputs() { 215 GraphDef graph_def; 216 217 NodeDef* const_node1 = graph_def.add_node(); 218 const_node1->set_name("const_node1"); 219 const_node1->set_op("Const"); 220 221 NodeDef* const_node2 = graph_def.add_node(); 222 const_node2->set_name("const_node2"); 223 const_node2->set_op("Const"); 224 225 NodeDef* const_node3 = graph_def.add_node(); 226 const_node3->set_name("const_node3"); 227 const_node3->set_op("Const"); 228 229 NodeDef* const_node4 = graph_def.add_node(); 230 const_node4->set_name("const_node4"); 231 const_node4->set_op("Const"); 232 233 NodeDef* fake_quant_node = graph_def.add_node(); 234 fake_quant_node->set_name("fake_quant_node"); 235 fake_quant_node->set_op("FakeQuantWithMinMaxVars"); 236 fake_quant_node->add_input("const_node1"); 237 fake_quant_node->add_input("const_node2"); 238 fake_quant_node->add_input("const_node3"); 239 240 NodeDef* add_node = graph_def.add_node(); 241 add_node->set_name("add_node"); 242 add_node->set_op("Add"); 243 add_node->add_input("fake_quant_node"); 244 add_node->add_input("const_node4"); 245 246 GraphDef result; 247 TransformFuncContext context; 248 context.input_names = {}; 249 context.output_names = {"add_node"}; 250 context.params.insert(std::pair<string, std::vector<string>>( 251 {"op", {string("FakeQuantWithMinMaxVars")}})); 252 context.params.insert( 253 std::pair<string, std::vector<string>>({"max_inputs", {string("3")}})); 254 TF_ASSERT_OK(RemoveNodes(graph_def, context, &result)); 255 256 std::map<string, const NodeDef*> node_lookup; 257 MapNamesToNodes(result, &node_lookup); 258 ASSERT_EQ(1, node_lookup.count("const_node1")); 259 ASSERT_EQ(1, node_lookup.count("const_node4")); 260 ASSERT_EQ(0, node_lookup.count("fake_quant_node")); 261 ASSERT_EQ(1, node_lookup.count("add_node")); 262 EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0)); 263 EXPECT_EQ("const_node4", node_lookup.at("add_node")->input(1)); 264 } 265 }; 266 267 TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); } 268 269 TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); } 270 271 TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); } 272 273 TEST_F(RemoveNodesTest, TestRemoveMultipleInputs) { 274 TestRemoveMultipleInputs(); 275 } 276 277 } // namespace graph_transforms 278 } // namespace tensorflow 279