1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/graph/subgraph.h" 17 18 #include <string> 19 #include <vector> 20 21 #include "tensorflow/core/framework/graph.pb.h" 22 #include "tensorflow/core/framework/partial_tensor_shape.h" 23 #include "tensorflow/core/graph/graph.h" 24 #include "tensorflow/core/graph/graph_constructor.h" 25 #include "tensorflow/core/graph/graph_def_builder.h" 26 #include "tensorflow/core/graph/graph_def_builder_util.h" 27 #include "tensorflow/core/kernels/ops_util.h" 28 #include "tensorflow/core/lib/core/status.h" 29 #include "tensorflow/core/lib/core/status_test_util.h" 30 #include "tensorflow/core/lib/strings/str_util.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/protobuf.h" 33 #include "tensorflow/core/platform/test.h" 34 #include "tensorflow/core/platform/test_benchmark.h" 35 36 // TODO(josh11b): Test setting the "device" field of a NodeDef. 37 // TODO(josh11b): Test that feeding won't prune targets. 38 39 namespace tensorflow { 40 namespace { 41 42 class SubgraphTest : public ::testing::Test { 43 protected: 44 SubgraphTest() : g_(new Graph(OpRegistry::Global())) { 45 device_info_.set_name("/job:a/replica:0/task:0/cpu:0"); 46 device_info_.set_device_type(DeviceType(DEVICE_CPU).type()); 47 device_info_.set_incarnation(0); 48 } 49 50 ~SubgraphTest() override {} 51 52 void ExpectOK(const string& gdef_ascii) { 53 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_)); 54 GraphConstructorOptions opts; 55 TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get())); 56 } 57 58 Node* FindNode(const string& name) { 59 for (Node* n : g_->nodes()) { 60 if (n->name() == name) return n; 61 } 62 return nullptr; 63 } 64 65 bool HasNode(const string& name) { return FindNode(name) != nullptr; } 66 67 void ExpectNodes(const string& nodes) { 68 int count = 0; 69 std::vector<string> actual_nodes; 70 for (Node* n : g_->nodes()) { 71 if (n->IsOp()) { 72 count++; 73 actual_nodes.push_back(n->name()); 74 } 75 } 76 std::sort(actual_nodes.begin(), actual_nodes.end()); 77 78 LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " "); 79 80 std::vector<string> expected_nodes = str_util::Split(nodes, ','); 81 std::sort(expected_nodes.begin(), expected_nodes.end()); 82 for (const string& s : expected_nodes) { 83 Node* n = FindNode(s); 84 EXPECT_TRUE(n != nullptr) << s; 85 if (n->type_string() == "_Send" || n->type_string() == "_Recv") { 86 EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s; 87 } 88 } 89 90 EXPECT_TRUE(actual_nodes.size() == expected_nodes.size()) 91 << "\nActual: " << str_util::Join(actual_nodes, ",") 92 << "\nExpected: " << str_util::Join(expected_nodes, ","); 93 } 94 95 bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) { 96 for (const Edge* e : g_->edges()) { 97 if (e->src()->name() == src && e->src_output() == src_out && 98 e->dst()->name() == dst && e->dst_input() == dst_in) 99 return true; 100 } 101 return false; 102 } 103 bool HasControlEdge(const string& src, const string& dst) { 104 return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot); 105 } 106 107 string Subgraph(const string& fed_str, const string& fetch_str, 108 const string& targets_str, 109 bool use_function_convention = false) { 110 Graph* subgraph = new Graph(OpRegistry::Global()); 111 CopyGraph(*g_, subgraph); 112 std::vector<string> fed = 113 str_util::Split(fed_str, ',', str_util::SkipEmpty()); 114 std::vector<string> fetch = 115 str_util::Split(fetch_str, ',', str_util::SkipEmpty()); 116 std::vector<string> targets = 117 str_util::Split(targets_str, ',', str_util::SkipEmpty()); 118 119 subgraph::RewriteGraphMetadata metadata; 120 Status s = subgraph::RewriteGraphForExecution( 121 subgraph, fed, fetch, targets, device_info_, use_function_convention, 122 &metadata); 123 if (!s.ok()) { 124 delete subgraph; 125 return s.ToString(); 126 } 127 128 EXPECT_EQ(fed.size(), metadata.feed_types.size()); 129 EXPECT_EQ(fetch.size(), metadata.fetch_types.size()); 130 131 // Replace the graph with the subgraph for the rest of the display program 132 g_.reset(subgraph); 133 return "OK"; 134 } 135 136 Graph* graph() { return g_.get(); } 137 138 private: 139 GraphDef gdef_; 140 std::unique_ptr<Graph> g_; 141 DeviceAttributes device_info_; 142 }; 143 144 REGISTER_OP("TestParams").Output("o: float"); 145 REGISTER_OP("TestInput").Output("a: float").Output("b: float"); 146 REGISTER_OP("TestRelu").Input("i: float").Output("o: float"); 147 REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); 148 149 TEST_F(SubgraphTest, Targets1) { 150 ExpectOK( 151 "node { name: 'W1' op: 'TestParams' }" 152 "node { name: 'W2' op: 'TestParams' }" 153 "node { name: 'input' op: 'TestInput' }" 154 "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" 155 "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" 156 "node { name: 't3_a' op: 'TestRelu' input: 't2' }" 157 "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); 158 EXPECT_EQ("OK", Subgraph("", "", "t1")); 159 ExpectNodes("W1,input,t1"); 160 } 161 162 TEST_F(SubgraphTest, Targets2) { 163 ExpectOK( 164 "node { name: 'W1' op: 'TestParams' }" 165 "node { name: 'W2' op: 'TestParams' }" 166 "node { name: 'input' op: 'TestInput' }" 167 "node { name: 't1' op: 'TestMul' input: 'W1' input: 'input:1' }" 168 "node { name: 't2' op: 'TestMul' input: 'W2' input: 't1' }" 169 "node { name: 't3_a' op: 'TestRelu' input: 't2' }" 170 "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); 171 EXPECT_EQ("OK", Subgraph("", "", "t2,t3_a")); 172 ExpectNodes("W1,W2,input,t1,t2,t3_a"); 173 } 174 175 TEST_F(SubgraphTest, FedOutputs1) { 176 ExpectOK( 177 "node { name: 'W1' op: 'TestParams' }" 178 "node { name: 'W2' op: 'TestParams' }" 179 "node { name: 'input' op: 'TestInput' }" 180 "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" 181 "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" 182 "node { name: 't3_a' op: 'TestRelu' input: 't2' }" 183 "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); 184 EXPECT_EQ("OK", Subgraph("input:1", "", "t2")); 185 ExpectNodes("W1,W2,_recv_input_1,t1,t2"); 186 } 187 188 TEST_F(SubgraphTest, FedOutputs1_FunctionConvention) { 189 ExpectOK( 190 "node { name: 'W1' op: 'TestParams' }" 191 "node { name: 'W2' op: 'TestParams' }" 192 "node { name: 'input' op: 'TestInput' }" 193 "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" 194 "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" 195 "node { name: 't3_a' op: 'TestRelu' input: 't2' }" 196 "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); 197 EXPECT_EQ("OK", 198 Subgraph("input:1", "", "t2", true /* use_function_convention */)); 199 ExpectNodes("W1,W2,_arg_input_1_0,t1,t2"); 200 } 201 202 TEST_F(SubgraphTest, FedRefNode) { 203 ExpectOK( 204 "node { name: 'W1' op: 'TestParams' }" 205 "node { name: 'W2' op: 'TestParams' }" 206 "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }"); 207 EXPECT_EQ("OK", Subgraph("W1:0", "", "t1")); 208 ExpectNodes("_recv_W1_0,W2,t1"); 209 Node* n = FindNode("_recv_W1_0"); 210 EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0))); 211 } 212 213 TEST_F(SubgraphTest, FedRefNode_FunctionConvention) { 214 ExpectOK( 215 "node { name: 'W1' op: 'TestParams' }" 216 "node { name: 'W2' op: 'TestParams' }" 217 "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }"); 218 EXPECT_EQ("OK", 219 Subgraph("W1:0", "", "t1", true /* use_function_convention */)); 220 ExpectNodes("_arg_W1_0_0,W2,t1"); 221 Node* n = FindNode("_arg_W1_0_0"); 222 EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0))); 223 } 224 225 TEST_F(SubgraphTest, FedOutputs2_FunctionConvention) { 226 ExpectOK( 227 "node { name: 'W1' op: 'TestParams' }" 228 "node { name: 'W2' op: 'TestParams' }" 229 "node { name: 'input' op: 'TestInput' }" 230 "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" 231 "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" 232 "node { name: 't3_a' op: 'TestRelu' input: 't2' }" 233 "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); 234 // We feed input:1, but nothing connects to it, so the _recv(input:1) 235 // node also disappears. 236 EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2", 237 true /* use_function_convention */)); 238 ExpectNodes("_arg_t1_0_1,_arg_W2_0_2,t2"); 239 } 240 241 TEST_F(SubgraphTest, FetchOutputs1) { 242 ExpectOK( 243 "node { name: 'W1' op: 'TestParams' }" 244 "node { name: 'W2' op: 'TestParams' }" 245 "node { name: 'input' op: 'TestInput' }" 246 "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" 247 "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" 248 "node { name: 't3_a' op: 'TestRelu' input: 't2' }" 249 "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); 250 EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2")); 251 ExpectNodes( 252 "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0"); 253 } 254 255 TEST_F(SubgraphTest, FetchOutputs1_FunctionConvention) { 256 ExpectOK( 257 "node { name: 'W1' op: 'TestParams' }" 258 "node { name: 'W2' op: 'TestParams' }" 259 "node { name: 'input' op: 'TestInput' }" 260 "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" 261 "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" 262 "node { name: 't3_a' op: 'TestRelu' input: 't2' }" 263 "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); 264 EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2", 265 true /* use_function_convention */)); 266 ExpectNodes( 267 "W1,W2,input,t1,t2,_retval_W2_0_0,_retval_input_1_1,_retval_t1_0_2,_" 268 "retval_t2_0_3"); 269 } 270 271 TEST_F(SubgraphTest, FetchOutputs2) { 272 ExpectOK( 273 "node { name: 'W1' op: 'TestParams' }" 274 "node { name: 'W2' op: 'TestParams' }" 275 "node { name: 'input' op: 'TestInput' }" 276 "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" 277 "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" 278 "node { name: 't3_a' op: 'TestRelu' input: 't2' }" 279 "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); 280 EXPECT_EQ("OK", Subgraph("", "t3_a", "t2")); 281 ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0"); 282 } 283 284 TEST_F(SubgraphTest, FetchOutputs2_FunctionConvention) { 285 ExpectOK( 286 "node { name: 'W1' op: 'TestParams' }" 287 "node { name: 'W2' op: 'TestParams' }" 288 "node { name: 'input' op: 'TestInput' }" 289 "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" 290 "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" 291 "node { name: 't3_a' op: 'TestRelu' input: 't2' }" 292 "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); 293 EXPECT_EQ("OK", 294 Subgraph("", "t3_a", "t2", true /* use_function_convention */)); 295 ExpectNodes("W1,W2,input,t1,t2,t3_a,_retval_t3_a_0_0"); 296 } 297 298 TEST_F(SubgraphTest, ChainOfFools) { 299 ExpectOK( 300 "node { name: 'a' op: 'TestParams' }" 301 "node { name: 'b' op: 'TestRelu' input: 'a'}" 302 "node { name: 'c' op: 'TestRelu' input: 'b'}" 303 "node { name: 'd' op: 'TestRelu' input: 'c'}" 304 "node { name: 'e' op: 'TestRelu' input: 'd'}" 305 "node { name: 'f' op: 'TestRelu' input: 'e'}"); 306 EXPECT_EQ("OK", Subgraph("c:0", "b:0,e:0", "")); 307 ExpectNodes("a,b,_send_b_0,_recv_c_0,d,e,_send_e_0"); 308 EXPECT_TRUE(HasEdge("a", 0, "b", 0)); 309 EXPECT_TRUE(HasEdge("b", 0, "_send_b_0", 0)); 310 EXPECT_TRUE(HasEdge("_recv_c_0", 0, "d", 0)); 311 EXPECT_TRUE(HasEdge("d", 0, "e", 0)); 312 EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0)); 313 } 314 315 static bool HasSubstr(const string& base, const string& substr) { 316 bool ok = StringPiece(base).contains(substr); 317 EXPECT_TRUE(ok) << base << ", expected substring " << substr; 318 return ok; 319 } 320 321 TEST_F(SubgraphTest, Errors) { 322 ExpectOK( 323 "node { name: 'a' op: 'TestParams' }" 324 "node { name: 'b' op: 'TestRelu' input: 'a'}" 325 "node { name: 'c' op: 'TestRelu' input: 'b'}" 326 "node { name: 'd' op: 'TestRelu' input: 'c'}" 327 "node { name: 'e' op: 'TestRelu' input: 'd'}" 328 "node { name: 'f' op: 'TestRelu' input: 'e'}"); 329 // Duplicated feed and fetch 330 EXPECT_TRUE( 331 HasSubstr(Subgraph("c:0", "b:0,c:0", ""), "both fed and fetched")); 332 // Feed not found. 333 EXPECT_TRUE(HasSubstr(Subgraph("foo:0", "c:0", ""), "unable to find")); 334 // Fetch not found. 335 EXPECT_TRUE(HasSubstr(Subgraph("", "foo:0", ""), "not found")); 336 // Target not found. 337 EXPECT_TRUE(HasSubstr(Subgraph("", "", "foo"), "not found")); 338 // No targets specified. 339 EXPECT_TRUE(HasSubstr(Subgraph("", "", ""), "at least one target")); 340 } 341 342 REGISTER_OP("In").Output("o: float"); 343 REGISTER_OP("Op").Input("i: float").Output("o: float"); 344 345 static void BM_SubgraphHelper(int iters, int num_nodes, 346 bool use_function_convention) { 347 DeviceAttributes device_info; 348 device_info.set_name("/job:a/replica:0/task:0/cpu:0"); 349 device_info.set_device_type(DeviceType(DEVICE_CPU).type()); 350 device_info.set_incarnation(0); 351 352 testing::StopTiming(); 353 Graph g(OpRegistry::Global()); 354 { // Scope for temporary variables used to construct g. 355 GraphDefBuilder b(GraphDefBuilder::kFailImmediately); 356 Node* last_node = nullptr; 357 for (int i = 0; i < num_nodes; i++) { 358 string name = strings::StrCat("N", i); 359 if (i > 0) { 360 last_node = ops::UnaryOp("Op", last_node, b.opts().WithName(name)); 361 } else { 362 last_node = ops::SourceOp("In", b.opts().WithName(name)); 363 } 364 } 365 TF_CHECK_OK(GraphDefBuilderToGraph(b, &g)); 366 } 367 368 std::vector<string> fed; 369 if (num_nodes > 1000) { 370 fed.push_back(strings::StrCat("N", num_nodes - 1000)); 371 } 372 std::vector<string> fetch; 373 std::vector<string> targets = {strings::StrCat("N", num_nodes - 1)}; 374 testing::StartTiming(); 375 while (--iters > 0) { 376 Graph* subgraph = new Graph(OpRegistry::Global()); 377 CopyGraph(g, subgraph); 378 subgraph::RewriteGraphMetadata metadata; 379 TF_CHECK_OK(subgraph::RewriteGraphForExecution( 380 subgraph, fed, fetch, targets, device_info, use_function_convention, 381 &metadata)); 382 delete subgraph; 383 } 384 } 385 386 static void BM_Subgraph(int iters, int num_nodes) { 387 BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */); 388 } 389 static void BM_SubgraphFunctionConvention(int iters, int num_nodes) { 390 BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */); 391 } 392 BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000); 393 BENCHMARK(BM_SubgraphFunctionConvention) 394 ->Arg(100) 395 ->Arg(1000) 396 ->Arg(10000) 397 ->Arg(100000); 398 399 } // namespace 400 } // namespace tensorflow 401