1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/graph/graph_partition.h" 17 18 #include <unordered_map> 19 #include <utility> 20 21 #include "tensorflow/cc/ops/array_ops.h" 22 #include "tensorflow/cc/ops/const_op.h" 23 #include "tensorflow/cc/ops/control_flow_ops.h" 24 #include "tensorflow/cc/ops/control_flow_ops_internal.h" 25 #include "tensorflow/cc/ops/math_ops.h" 26 #include "tensorflow/cc/ops/random_ops.h" 27 #include "tensorflow/cc/ops/sendrecv_ops.h" 28 #include "tensorflow/cc/ops/while_loop.h" 29 #include "tensorflow/core/framework/common_shape_fns.h" 30 #include "tensorflow/core/framework/function_testlib.h" 31 #include "tensorflow/core/framework/op.h" 32 #include "tensorflow/core/framework/versions.pb.h" 33 #include "tensorflow/core/graph/graph.h" 34 #include "tensorflow/core/graph/graph_constructor.h" 35 #include "tensorflow/core/graph/graph_def_builder.h" 36 #include "tensorflow/core/kernels/ops_util.h" 37 #include "tensorflow/core/lib/core/status_test_util.h" 38 #include "tensorflow/core/platform/logging.h" 39 #include "tensorflow/core/platform/protobuf.h" 40 #include "tensorflow/core/platform/test.h" 41 #include "tensorflow/core/public/version.h" 42 #include "tensorflow/core/util/equal_graph_def.h" 43 44 namespace tensorflow { 45 46 // from graph_partition.cc 47 extern Status TopologicalSortNodesWithTimePriority( 48 const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes, 49 std::unordered_map<const NodeDef*, int64>* node_to_start_time_out); 50 51 namespace { 52 53 using ops::_Recv; 54 using ops::_Send; 55 using ops::Const; 56 using ops::Identity; 57 using ops::LoopCond; 58 using ops::NextIteration; 59 60 const char gpu_device[] = "/job:a/replica:0/task:0/device:GPU:0"; 61 62 string SplitByDevice(const Node* node) { return node->assigned_device_name(); } 63 64 string DeviceName(const Node* node) { 65 char first = node->name()[0]; 66 if (first == 'G') { 67 return gpu_device; 68 } else { 69 const string cpu_prefix = "/job:a/replica:0/task:0/cpu:"; 70 int index = first - 'A'; 71 return strings::StrCat(cpu_prefix, index); 72 } 73 } 74 75 void Partition(const GraphDef& graph_def, 76 std::unordered_map<string, GraphDef>* partitions) { 77 Graph g(OpRegistry::Global()); 78 GraphConstructorOptions opts; 79 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g)); 80 81 // Assigns devices to each node. Uses 1st letter of the node name as the 82 // device index if no device is specified. 83 for (Node* node : g.nodes()) { 84 string device_name = !node->requested_device().empty() 85 ? node->requested_device() 86 : DeviceName(node); 87 node->set_assigned_device_name(device_name); 88 } 89 90 PartitionOptions popts; 91 popts.node_to_loc = SplitByDevice; 92 popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); }; 93 popts.get_incarnation = [](const string& name) { 94 return (name[0] - 'A') + 100; 95 }; 96 Status s = Partition(popts, &g, partitions); 97 CHECK(s.ok()) << s; 98 99 // Check versions. 100 EXPECT_EQ(graph_def.versions().producer(), TF_GRAPH_DEF_VERSION); 101 // Partitions must inherit the versions of the original graph. 102 for (auto& it : *partitions) { 103 EXPECT_EQ(graph_def.versions().producer(), it.second.versions().producer()); 104 EXPECT_EQ(graph_def.versions().min_consumer(), 105 it.second.versions().min_consumer()); 106 } 107 } 108 109 void CheckLoopConstruction(const GraphDef& graph_def) { 110 std::unordered_map<string, GraphDef> partitions; 111 Partition(graph_def, &partitions); 112 for (const auto& kv : partitions) { 113 const GraphDef& gdef = kv.second; 114 bool has_control_enter = false; 115 bool has_control_merge = false; 116 bool has_control_switch = false; 117 bool has_control_next = false; 118 for (const NodeDef& ndef : gdef.node()) { 119 // _recvs must have a control input 120 if (ndef.op() == "_Recv") { 121 bool has_control = false; 122 for (const string& input_name : ndef.input()) { 123 if (StringPiece(input_name).starts_with("^")) { 124 has_control = true; 125 break; 126 } 127 } 128 EXPECT_TRUE(has_control); 129 } 130 // Must have a control loop 131 if (StringPiece(ndef.name()).starts_with("_cloop")) { 132 if (ndef.op() == "Enter") { 133 has_control_enter = true; 134 } 135 if (ndef.op() == "Merge") { 136 has_control_merge = true; 137 } 138 if (ndef.op() == "Switch") { 139 has_control_switch = true; 140 } 141 if (ndef.op() == "NextIteration") { 142 has_control_next = true; 143 } 144 } 145 } 146 EXPECT_TRUE(has_control_enter); 147 EXPECT_TRUE(has_control_merge); 148 EXPECT_TRUE(has_control_switch); 149 EXPECT_TRUE(has_control_next); 150 } 151 } 152 153 REGISTER_OP("FloatInput") 154 .Output("o: float") 155 .SetShapeFn(shape_inference::UnknownShape); 156 REGISTER_OP("BoolInput") 157 .Output("o: bool") 158 .SetShapeFn(shape_inference::UnknownShape); 159 REGISTER_OP("Combine") 160 .Input("a: float") 161 .Input("b: float") 162 .Output("o: float") 163 .SetShapeFn(shape_inference::UnknownShape); 164 165 Output ConstructOp(const Scope& scope, const string& op_type, 166 const gtl::ArraySlice<Input>& inputs) { 167 if (!scope.ok()) return Output(); 168 const string unique_name = scope.GetUniqueNameForOp(op_type); 169 auto builder = 170 NodeBuilder(unique_name, op_type, scope.graph()->op_registry()); 171 for (auto const& input : inputs) { 172 builder.Input(ops::NodeOut(input.node(), input.index())); 173 } 174 scope.UpdateBuilder(&builder); 175 Node* ret; 176 scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); 177 if (!scope.ok()) return Output(); 178 scope.UpdateStatus(scope.DoShapeInference(ret)); 179 if (!scope.ok()) return Output(); 180 return Output(ret); 181 } 182 183 Output FloatInput(const Scope& scope) { 184 return ConstructOp(scope, "FloatInput", {}); 185 } 186 187 Output BoolInput(const Scope& scope) { 188 return ConstructOp(scope, "BoolInput", {}); 189 } 190 191 Output Combine(const Scope& scope, Input a, Input b) { 192 return ConstructOp(scope, "Combine", {std::move(a), std::move(b)}); 193 } 194 195 class GraphPartitionTest : public ::testing::Test { 196 protected: 197 GraphPartitionTest() 198 : in_(Scope::NewRootScope().ExitOnError()), 199 scope_a_(Scope::NewRootScope().ExitOnError().WithDevice( 200 "/job:a/replica:0/task:0/cpu:0")), 201 scope_b_(Scope::NewRootScope().ExitOnError().WithDevice( 202 "/job:a/replica:0/task:0/cpu:1")) {} 203 204 const GraphDef& ToGraphDef() { 205 TF_EXPECT_OK(in_.ToGraphDef(&in_graph_def_)); 206 return in_graph_def_; 207 } 208 209 void ExpectMatchA() { 210 GraphDef graph_def; 211 TF_EXPECT_OK(scope_a_.ToGraphDef(&graph_def)); 212 string a = "/job:a/replica:0/task:0/cpu:0"; 213 TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]); 214 } 215 216 void ExpectMatchB() { 217 GraphDef graph_def; 218 TF_EXPECT_OK(scope_b_.ToGraphDef(&graph_def)); 219 string b = "/job:a/replica:0/task:0/cpu:1"; 220 TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]); 221 } 222 223 void ExpectFunctions(const FunctionDefLibrary& library, 224 const std::set<string>& expected_names) { 225 std::set<string> actual_names; 226 for (const FunctionDef& fdef : library.function()) { 227 actual_names.insert(fdef.signature().name()); 228 } 229 EXPECT_EQ(actual_names, expected_names); 230 } 231 232 Scope in_; 233 GraphDef in_graph_def_; 234 Scope scope_a_; 235 Scope scope_b_; 236 std::unordered_map<string, GraphDef> partitions_; 237 }; 238 239 TEST_F(GraphPartitionTest, SingleDevice) { 240 auto a1 = FloatInput(in_.WithOpName("A1")); 241 Combine(in_.WithOpName("A2"), a1, a1); 242 243 Partition(ToGraphDef(), &partitions_); 244 EXPECT_EQ(1, partitions_.size()); 245 246 a1 = FloatInput(scope_a_.WithOpName("A1")); 247 Combine(scope_a_.WithOpName("A2"), a1, a1); 248 ExpectMatchA(); 249 } 250 251 TEST_F(GraphPartitionTest, CrossDeviceData) { 252 auto a1 = FloatInput(in_.WithOpName("A1")); 253 auto b1 = FloatInput(in_.WithOpName("B1")); 254 Combine(in_.WithOpName("B2"), a1, b1); 255 256 Partition(ToGraphDef(), &partitions_); 257 EXPECT_EQ(2, partitions_.size()); 258 259 string a = "/job:a/replica:0/task:0/cpu:0"; 260 string b = "/job:a/replica:0/task:0/cpu:1"; 261 a1 = FloatInput(scope_a_.WithOpName("A1")); 262 _Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b); 263 ExpectMatchA(); 264 265 b1 = FloatInput(scope_b_.WithOpName("B1")); 266 auto recv = 267 _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b); 268 Combine(scope_b_.WithOpName("B2"), recv, b1); 269 ExpectMatchB(); 270 } 271 272 TEST_F(GraphPartitionTest, CrossDeviceControl) { 273 auto a1 = FloatInput(in_.WithOpName("A1")); 274 auto b1 = FloatInput(in_.WithOpName("B1")); 275 Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1); 276 277 Partition(ToGraphDef(), &partitions_); 278 EXPECT_EQ(2, partitions_.size()); 279 280 string a = "/job:a/replica:0/task:0/cpu:0"; 281 string b = "/job:a/replica:0/task:0/cpu:1"; 282 a1 = FloatInput(scope_a_.WithOpName("A1")); 283 auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {}); 284 _Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b); 285 ExpectMatchA(); 286 287 auto recv = 288 _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b); 289 auto id = Identity(scope_b_.WithOpName("A1/_3"), recv); 290 b1 = FloatInput(scope_b_.WithOpName("B1")); 291 Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1); 292 ExpectMatchB(); 293 } 294 295 TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) { 296 auto a1 = FloatInput(in_.WithOpName("A1")); 297 auto b1 = FloatInput(in_.WithOpName("B1")); 298 Combine(in_.WithOpName("B2"), a1, b1); 299 Combine(in_.WithOpName("B3"), a1, a1); 300 301 Partition(ToGraphDef(), &partitions_); 302 EXPECT_EQ(2, partitions_.size()); 303 304 string a = "/job:a/replica:0/task:0/cpu:0"; 305 string b = "/job:a/replica:0/task:0/cpu:1"; 306 a1 = FloatInput(scope_a_.WithOpName("A1")); 307 _Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b); 308 ExpectMatchA(); 309 310 auto recv = 311 _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b); 312 b1 = FloatInput(scope_b_.WithOpName("B1")); 313 Combine(scope_b_.WithOpName("B2"), recv, b1); 314 Combine(scope_b_.WithOpName("B3"), recv, recv); 315 ExpectMatchB(); 316 } 317 318 TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) { 319 auto a1 = FloatInput(in_.WithOpName("A1")); 320 auto b1 = FloatInput(in_.WithOpName("B1")); 321 Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1); 322 FloatInput(in_.WithOpName("B3").WithControlDependencies(a1)); 323 324 Partition(ToGraphDef(), &partitions_); 325 EXPECT_EQ(2, partitions_.size()); 326 327 string a = "/job:a/replica:0/task:0/cpu:0"; 328 string b = "/job:a/replica:0/task:0/cpu:1"; 329 a1 = FloatInput(scope_a_.WithOpName("A1")); 330 auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {}); 331 _Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b); 332 ExpectMatchA(); 333 334 auto recv = 335 _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b); 336 auto id = Identity(scope_b_.WithOpName("A1/_3"), recv); 337 b1 = FloatInput(scope_b_.WithOpName("B1")); 338 Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1); 339 FloatInput(scope_b_.WithOpName("B3").WithControlDependencies(id)); 340 ExpectMatchB(); 341 } 342 343 TEST_F(GraphPartitionTest, CrossDevice_DataControl) { 344 auto a1 = FloatInput(in_.WithOpName("A1")); 345 auto b1 = FloatInput(in_.WithOpName("B1")); 346 Combine(in_.WithOpName("B2"), a1, b1); 347 FloatInput(in_.WithOpName("B3").WithControlDependencies(a1)); 348 349 Partition(ToGraphDef(), &partitions_); 350 EXPECT_EQ(2, partitions_.size()); 351 352 string a = "/job:a/replica:0/task:0/cpu:0"; 353 string b = "/job:a/replica:0/task:0/cpu:1"; 354 a1 = FloatInput(scope_a_.WithOpName("A1")); 355 auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {}); 356 // NOTE: Send 0 A1/_1 -> A1/_2 is not necessarily needed. We could 357 // use A1/_0 -> A1/_4 as the control as a minor optimization. 358 _Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b); 359 _Send(scope_a_.WithOpName("A1/_4"), a1, "edge_2_A1", a, 82, b); 360 ExpectMatchA(); 361 362 auto recv1 = 363 _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b); 364 auto id1 = Identity(scope_b_.WithOpName("A1/_3"), recv1); 365 auto recv2 = 366 _Recv(scope_b_.WithOpName("A1/_5"), DT_FLOAT, "edge_2_A1", a, 82, b); 367 b1 = FloatInput(scope_b_.WithOpName("B1")); 368 Combine(scope_b_.WithOpName("B2"), recv2, b1); 369 FloatInput(scope_b_.WithOpName("B3").WithControlDependencies(id1)); 370 ExpectMatchB(); 371 } 372 373 TEST_F(GraphPartitionTest, CrossDeviceLoopSimple) { 374 auto a1 = BoolInput(in_.WithOpName("A1")); 375 auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("A2"), a1, "foo"); 376 auto a3 = ::tensorflow::ops::Merge(in_.WithOpName("A3"), 377 {a2, Input("A5", 0, DT_BOOL)}) 378 .output; 379 LoopCond(in_.WithOpName("A4"), a3); 380 auto b1 = Identity(in_.WithOpName("B1"), a3); 381 NextIteration(in_.WithOpName("A5"), b1); 382 383 CheckLoopConstruction(ToGraphDef()); 384 } 385 386 TEST_F(GraphPartitionTest, CrossDeviceLoopSimple1) { 387 auto a1 = BoolInput(in_.WithOpName("A1")); 388 auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("B2"), a1, "foo"); 389 auto a3 = ::tensorflow::ops::Merge(in_.WithOpName("A3"), 390 {a2, Input("B5", 0, DT_BOOL)}) 391 .output; 392 LoopCond(in_.WithOpName("A4"), a3); 393 auto b1 = Identity(in_.WithOpName("B1"), a3); 394 NextIteration(in_.WithOpName("B5"), b1); 395 396 std::unordered_map<string, GraphDef> partitions; 397 Partition(ToGraphDef(), &partitions); 398 for (const auto& kv : partitions) { 399 const GraphDef& gdef = kv.second; 400 for (const NodeDef& ndef : gdef.node()) { 401 if (ndef.name() == "A3") { 402 // A3, B2, and B5 are on the same device. 403 EXPECT_EQ(ndef.input(0), "B2"); 404 EXPECT_EQ(ndef.input(1), "B5"); 405 } 406 } 407 } 408 } 409 410 TEST_F(GraphPartitionTest, CrossDeviceLoopFull) { 411 Scope cpu0 = in_.WithDevice("/job:a/replica:0/task:0/cpu:0"); 412 auto p1 = ops::Placeholder(cpu0, DT_INT32); 413 auto p2 = ops::Placeholder(cpu0, DT_INT32); 414 OutputList outputs; 415 // while i1 < 10: i1 += i2 416 TF_ASSERT_OK(ops::BuildWhileLoop( 417 cpu0, {p1, p2}, 418 [](const Scope& s, const std::vector<Output>& inputs, Output* output) { 419 *output = ops::Less(s, inputs[0], 10); 420 return s.status(); 421 }, 422 [](const Scope& s, const std::vector<Output>& inputs, 423 std::vector<Output>* outputs) { 424 Scope cpu1 = s.WithDevice("/job:a/replica:0/task:0/cpu:1"); 425 outputs->push_back(ops::AddN(cpu1, {inputs[0], inputs[1]})); 426 outputs->push_back(inputs[1]); 427 return s.status(); 428 }, 429 "test_loop", &outputs)); 430 CheckLoopConstruction(ToGraphDef()); 431 } 432 433 TEST_F(GraphPartitionTest, PartitionIncompleteGraph) { 434 NodeDef ndef; 435 Graph g(OpRegistry::Global()); 436 // Invalid graph since the Combine node requires an input. 437 bool parsed = protobuf::TextFormat::ParseFromString( 438 R"EOF( 439 name: "N" 440 op: "Combine" 441 )EOF", 442 &ndef); 443 ASSERT_TRUE(parsed); 444 Status status; 445 g.AddNode(ndef, &status); 446 TF_ASSERT_OK(status); 447 448 PartitionOptions popts; 449 popts.node_to_loc = SplitByDevice; 450 popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); }; 451 popts.get_incarnation = [](const string&) { return 1; }; 452 453 std::unordered_map<string, GraphDef> partitions; 454 status = Partition(popts, &g, &partitions); 455 // Partitioning should fail, but not crash like it did before the 456 // changes that accompanied the addition of this test. 457 EXPECT_EQ(error::INVALID_ARGUMENT, status.code()) << status; 458 } 459 460 TEST_F(GraphPartitionTest, Functions) { 461 FunctionDefLibrary fdef_lib; 462 *fdef_lib.add_function() = test::function::XTimesTwo(); 463 *fdef_lib.add_function() = test::function::XTimesFour(); 464 TF_ASSERT_OK(in_.graph()->AddFunctionLibrary(fdef_lib)); 465 466 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 467 auto a1 = FloatInput(in_.WithOpName("A1")); 468 auto b1 = FloatInput(in_.WithOpName("B1")); 469 ConstructOp(in_.WithOpName("A2"), "XTimesTwo", {a1}); 470 ConstructOp(in_.WithOpName("B2"), "XTimesFour", {b1}); 471 472 Partition(ToGraphDef(), &partitions_); 473 EXPECT_EQ(2, partitions_.size()); 474 475 // Test that partition graphs inherit function library from original graph 476 string a = "/job:a/replica:0/task:0/cpu:0"; 477 string b = "/job:a/replica:0/task:0/cpu:1"; 478 ExpectFunctions(partitions_[a].library(), {"XTimesTwo", "XTimesFour"}); 479 ExpectFunctions(partitions_[b].library(), {"XTimesTwo", "XTimesFour"}); 480 } 481 482 TEST_F(GraphPartitionTest, SetIncarnation) { 483 GraphDef gdef; 484 const char* const kSendRecvAttrs = R"proto( 485 attr { key: 'T' value { type: DT_FLOAT } } 486 attr { key: 'client_terminated' value { b: false } } 487 attr { key: 'recv_device' value { s: 'B' } } 488 attr { key: 'send_device' value { s: 'A' } } 489 attr { key: 'send_device_incarnation' value { i: 0 } } 490 attr { key: 'tensor_name' value { s: 'test' } } 491 )proto"; 492 CHECK(protobuf::TextFormat::ParseFromString( 493 strings::StrCat( 494 "node { name: 'A/Pi' op: 'Const' ", 495 " attr { key: 'dtype' value { type: DT_FLOAT } } ", 496 " attr { key: 'value' value { tensor { ", 497 " dtype: DT_FLOAT tensor_shape {} float_val: 3.14 } } } }", 498 "node { name: 'A' op: '_Send' input: 'A/Pi' ", kSendRecvAttrs, "}", 499 "node { name: 'B' op: '_Recv' ", kSendRecvAttrs, 500 " attr { key: 'tensor_type' value { type:DT_FLOAT}}}"), 501 &gdef)); 502 gdef.mutable_versions()->set_producer(TF_GRAPH_DEF_VERSION); 503 Partition(gdef, &partitions_); 504 EXPECT_EQ(2, partitions_.size()); 505 506 for (const auto& kv : partitions_) { 507 const GraphDef& gdef = kv.second; 508 for (const NodeDef& ndef : gdef.node()) { 509 if (ndef.name() == "A" || ndef.name() == "B") { 510 int64 val; 511 TF_CHECK_OK(GetNodeAttr(ndef, "send_device_incarnation", &val)); 512 EXPECT_EQ(val, 100); // Send device is "A". 513 } 514 } 515 } 516 } 517 518 TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) { 519 // Create placeholders, shuffle them so the order in the graph is not strictly 520 // increasing. 521 Scope root = Scope::NewRootScope().ExitOnError(); 522 std::vector<int> indexes; 523 for (int i = 0; i < 20; ++i) { 524 indexes.push_back((i + 2001) % 20); 525 } 526 std::vector<ops::Placeholder> placeholders; 527 for (int i : indexes) { 528 placeholders.emplace_back(root.WithOpName(strings::StrCat("p", i)), 529 DT_FLOAT); 530 placeholders.back().node()->AddAttr("_start_time", i + 1); 531 } 532 533 GraphDef gdef; 534 TF_EXPECT_OK(root.ToGraphDef(&gdef)); 535 536 std::vector<std::pair<const NodeDef*, int64>> nodes; 537 std::unordered_map<const NodeDef*, int64> node_to_start_time; 538 TF_CHECK_OK( 539 TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time)); 540 ASSERT_EQ(nodes.size(), 20); 541 for (int i = 0; i < nodes.size(); ++i) { 542 EXPECT_EQ(strings::StrCat("p", i), nodes[i].first->name()); 543 EXPECT_EQ(i + 1, nodes[i].second); 544 } 545 } 546 547 TEST(TopologicalSortNodesWithTimePriority, Dependencies) { 548 // Create placeholders, shuffle them so the order in the graph is not strictly 549 // increasing. 550 Scope root = Scope::NewRootScope().ExitOnError(); 551 std::vector<int> indexes; 552 std::vector<ops::Placeholder> placeholders_in_order; 553 const int num_leaves = 20; 554 for (int i = 0; i < num_leaves; ++i) { 555 indexes.push_back((i + 2001) % num_leaves); 556 placeholders_in_order.emplace_back(root.WithOpName(strings::StrCat("p", i)), 557 DT_FLOAT); 558 placeholders_in_order.back().node()->AddAttr("_start_time", i + 1); 559 } 560 std::vector<ops::Placeholder> placeholders; 561 for (int i : indexes) { 562 placeholders.push_back(placeholders_in_order[i]); 563 } 564 565 // Create ops that depend on the placeholders. We give start times to these 566 // that are in descending order (e.g., the op that depends on the first 567 // placeholder runs last). 568 std::vector<ops::Square> squares; 569 for (int i : indexes) { 570 squares.emplace_back(root.WithOpName(strings::StrCat("s", i)), 571 placeholders[i]); 572 squares.back().node()->AddAttr("_start_time", 50 - (i + 1)); 573 } 574 575 // Create addn to sum all squares. 576 std::vector<Input> inputs; 577 for (const auto& s : squares) inputs.push_back(s); 578 ops::AddN addn = ops::AddN(root.WithOpName("addn"), 579 tensorflow::gtl::ArraySlice<Input>(inputs)); 580 // Start times is actually listed earlier than the nodes it depends on. 581 // But because of dependency ordering, it is last in the list. 582 addn.node()->AddAttr("_start_time", 1); 583 584 GraphDef gdef; 585 TF_EXPECT_OK(root.ToGraphDef(&gdef)); 586 587 std::vector<std::pair<const NodeDef*, int64>> nodes; 588 std::unordered_map<const NodeDef*, int64> node_to_start_time; 589 TF_CHECK_OK( 590 TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time)); 591 ASSERT_EQ(1 + squares.size() + placeholders.size(), nodes.size()); 592 for (int i = 0; i < placeholders.size(); ++i) { 593 const NodeDef* node = nodes[i].first; 594 EXPECT_EQ(strings::StrCat("p", i), node->name()); 595 EXPECT_EQ(i + 1, nodes[i].second); 596 EXPECT_EQ(i + 1, node_to_start_time[node]); 597 } 598 for (int i = 0; i < squares.size(); ++i) { 599 int node_index = placeholders.size() + i; 600 int square_index = num_leaves - 1 - i; 601 const NodeDef* node = nodes[node_index].first; 602 EXPECT_EQ(strings::StrCat("s", square_index), node->name()); 603 EXPECT_EQ(50 - (square_index + 1), nodes[node_index].second); 604 EXPECT_EQ(50 - (square_index + 1), node_to_start_time[node]); 605 } 606 EXPECT_EQ("addn", nodes.back().first->name()); 607 EXPECT_EQ(50, nodes.back().second); 608 EXPECT_EQ(50, node_to_start_time[nodes.back().first]); 609 } 610 611 TEST(TopologicalSortNodesWithTimePriority, WhileLoop) { 612 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 613 using namespace ::tensorflow::ops::internal; // NOLINT(build/namespaces) 614 615 // Create placeholders. 616 Scope root = Scope::NewRootScope().ExitOnError(); 617 std::vector<int> indexes; 618 std::vector<Placeholder> placeholders_in_order; 619 const int num_leaves = 20; 620 for (int i = 0; i < num_leaves; ++i) { 621 indexes.push_back((i + 2001) % num_leaves); 622 placeholders_in_order.emplace_back(root.WithOpName(strings::StrCat("p", i)), 623 DT_FLOAT); 624 placeholders_in_order.back().node()->AddAttr("_start_time", i + 1); 625 } 626 std::vector<Placeholder> placeholders; 627 placeholders.reserve(indexes.size()); 628 for (int i : indexes) { 629 placeholders.push_back(placeholders_in_order[i]); 630 } 631 632 // Add a while loop above each placeholder. 633 std::vector<Exit> while_exits; 634 const int nodes_per_loop = 8; 635 for (int i : indexes) { 636 Scope scope = root.NewSubScope(strings::StrCat("while", i)); 637 auto dummy = Placeholder(scope, DT_FLOAT); 638 639 Enter enter(scope, placeholders[i], strings::StrCat("frame", i)); 640 Merge merge(scope, std::initializer_list<Input>{enter, dummy}); 641 auto cv = Const(scope.WithControlDependencies({merge.output}), false); 642 LoopCond loop_cond(scope, cv); 643 Switch switch_node(scope, merge.output, loop_cond); 644 Identity identity(scope, switch_node.output_true); 645 NextIteration next_iteration(scope, identity); 646 while_exits.emplace_back(scope.WithOpName("exit"), 647 switch_node.output_false); 648 649 // Complete loop by removing dummy node and attaching NextIteration to 650 // that input of the merge node. 651 scope.graph()->RemoveNode(dummy.node()); 652 scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); 653 654 int base_start_time = i * 10 + 100; 655 for (const auto& op : std::initializer_list<Output>{ 656 enter, merge.output, cv, loop_cond, switch_node.output_false, 657 identity, next_iteration, while_exits.back()}) { 658 op.node()->AddAttr("_start_time", base_start_time++); 659 } 660 } 661 662 // Create ops that depend on the loop exits. 663 std::vector<Square> squares; 664 squares.reserve(indexes.size()); 665 for (int i : indexes) { 666 squares.emplace_back(root.WithOpName(strings::StrCat("s", i)), 667 while_exits[i]); 668 squares.back().node()->AddAttr("_start_time", 500 - (i + 1)); 669 } 670 671 GraphDef gdef; 672 TF_EXPECT_OK(root.ToGraphDef(&gdef)); 673 674 // Run the sort. The while loop nodes do not appear in the output <nodes>. 675 std::vector<std::pair<const NodeDef*, int64>> nodes; 676 std::unordered_map<const NodeDef*, int64> node_to_start_time; 677 TF_CHECK_OK( 678 TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time)); 679 ASSERT_LT(while_exits.size() + squares.size() + placeholders.size(), 680 nodes.size()); 681 int node_index = 0; 682 for (int i = 0; i < placeholders.size(); ++i, ++node_index) { 683 const NodeDef* node = nodes[i].first; 684 EXPECT_EQ(strings::StrCat("p", i), node->name()); 685 EXPECT_EQ(i + 1, nodes[i].second); 686 EXPECT_EQ(i + 1, node_to_start_time[node]); 687 } 688 for (int i = 0; i < while_exits.size(); ++i, node_index += nodes_per_loop) { 689 const NodeDef* node = nodes[node_index].first; 690 EXPECT_EQ(strings::StrCat("while", i, "/Enter"), node->name()); 691 EXPECT_EQ(100 + i * 10, nodes[node_index].second); 692 EXPECT_EQ(100 + i * 10, node_to_start_time[node]); 693 } 694 for (int i = 0; i < squares.size(); ++i, ++node_index) { 695 int square_index = num_leaves - 1 - i; 696 const NodeDef* node = nodes[node_index].first; 697 EXPECT_EQ(strings::StrCat("s", square_index), node->name()); 698 EXPECT_EQ(500 - (square_index + 1), nodes[node_index].second); 699 EXPECT_EQ(500 - (square_index + 1), node_to_start_time[node]); 700 } 701 } 702 703 } // namespace 704 } // namespace tensorflow 705