1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h" 17 #include "tensorflow/cc/ops/standard_ops.h" 18 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT 19 #include "tensorflow/core/framework/tensor_description.pb.h" 20 #include "tensorflow/core/framework/tensor_shape.pb.h" 21 #include "tensorflow/core/grappler/clusters/virtual_cluster.h" 22 #include "tensorflow/core/grappler/costs/utils.h" 23 #include "tensorflow/core/grappler/costs/virtual_placer.h" 24 #include "tensorflow/core/lib/core/status_test_util.h" 25 #include "tensorflow/core/platform/test.h" 26 27 namespace tensorflow { 28 namespace grappler { 29 30 // Class for testing virtual scheduler. 
// Subclass of VirtualScheduler used for testing: fixes the ready-node
// manager to FirstReadyManager (so scheduling order is deterministic in the
// tests) and exposes internal state to the FRIEND_TEST cases below.
class TestVirtualScheduler : public VirtualScheduler {
 public:
  TestVirtualScheduler(const bool use_static_shapes,
                       const bool use_aggressive_shape_inference,
                       Cluster* cluster)
      : VirtualScheduler(use_static_shapes, use_aggressive_shape_inference,
                         cluster, &ready_node_manager_) {
    // Turn on memory-usage tracking so tests can inspect per-device memory.
    enable_mem_usage_tracking();
  }

  // Tests that need access to the scheduler's protected/private state.
  FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
  FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
  FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
  FRIEND_TEST(VirtualSchedulerTest, Variable);
  FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);

 protected:
  // Passed by pointer to the VirtualScheduler base constructor above.
  FirstReadyManager ready_node_manager_;
};

// Test fixture: builds a two-CPU virtual cluster, a TestVirtualScheduler on
// top of it, and a set of pre-initialized NodeDefs/NodeStates used by the
// node-manager tests.
class VirtualSchedulerTest : public ::testing::Test {
 protected:
  VirtualSchedulerTest() {
    // node1_ to node6_ on kCPU0, with time_ready in reverse_order.
    NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
    NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
    NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
    NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
    NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
    NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);

    // Initializes cluster_ and scheduler_.
    std::unordered_map<string, DeviceProperties> devices;

    // Set some dummy CPU properties
    DeviceProperties cpu_device = GetDummyCPUDevice();

    // IMPORTANT: Device is not actually ever used in the test case since
    // force_cpu_type is defaulted to "Haswell"
    devices[kCPU0] = cpu_device;
    devices[kCPU1] = cpu_device;
    cluster_ = absl::make_unique<VirtualCluster>(devices);
    scheduler_ = absl::make_unique<TestVirtualScheduler>(
        /*use_static_shapes=*/true,
        /*use_aggressive_shape_inference=*/true, cluster_.get());
  }

  NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
  std::unordered_map<const NodeDef*, NodeState> node_states_;

  // Device names:
  const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
  const string kCPU1 = "/job:localhost/replica:0/task:0/cpu:1";
  const string kChannelFrom0To1 = "Channel from CPU0 to CPU1";
  const string kChannelFrom1To0 = "Channel from CPU1 to CPU0";
  // Op names:
  const string kSend = "_Send";
  const string kRecv = "_Recv";
  const string kConv2D = "Conv2D";

  // Returns a small synthetic CPU device description used for both CPUs.
  DeviceProperties GetDummyCPUDevice() {
    // Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth.
    // - 8 Gflops
    // - 2 GB/s
    DeviceProperties cpu_device;
    cpu_device.set_type("CPU");
    cpu_device.set_frequency(4000);
    cpu_device.set_num_cores(2);
    cpu_device.set_bandwidth(2000000);
    return cpu_device;
  }

  // Fills `node` with name/op/device and records a fresh NodeState for it in
  // node_states_ with the given time_ready.
  void NodeSetUp(const string& name, const string& op_name,
                 const string& device_name, const uint64 time_ready,
                 NodeDef* node) {
    node->set_name(name);
    node->set_op(op_name);
    node->set_device(device_name);

    node_states_[node] = NodeState();
    node_states_[node].time_ready = time_ready;
    node_states_[node].device_name = device_name;
  }

  // Three Conv2Ds with only two in fetch nodes.
116 void CreateGrapplerItemWithConv2Ds() { 117 Scope s = Scope::NewRootScope().WithDevice(kCPU0); 118 auto x = ops::RandomUniform( 119 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); 120 auto y = ops::RandomUniform( 121 s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); 122 auto z = ops::RandomUniform( 123 s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); 124 auto f = ops::RandomUniform( 125 s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT); 126 std::vector<int> strides = {1, 1, 1, 1}; 127 auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME"); 128 auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME"); 129 auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME"); 130 GraphDef def; 131 TF_CHECK_OK(s.ToGraphDef(&def)); 132 133 grappler_item_.reset(new GrapplerItem); 134 grappler_item_->id = "test_conv2d_graph"; 135 grappler_item_->graph = def; 136 grappler_item_->fetch = {"c0", "c1"}; 137 138 dependency_["c0"] = {"x", "f"}; 139 dependency_["c1"] = {"y", "f"}; 140 } 141 142 // A Conv2D with a variable. 
143 void CreateGrapplerItemWithConv2DAndVariable() { 144 Scope s = Scope::NewRootScope().WithDevice(kCPU0); 145 auto x = ops::RandomUniform( 146 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); 147 auto f = ops::Variable(s.WithOpName("f"), 148 {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT); 149 std::vector<int> strides = {1, 1, 1, 1}; 150 auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME"); 151 GraphDef def; 152 TF_CHECK_OK(s.ToGraphDef(&def)); 153 154 grappler_item_.reset(new GrapplerItem); 155 grappler_item_->id = "test_conv2d_var_graph"; 156 grappler_item_->graph = def; 157 grappler_item_->fetch = {"y"}; 158 159 dependency_["y"] = {"x", "f"}; 160 } 161 162 void CreateGrapplerItemWithMatmulChain() { 163 Scope s = Scope::NewRootScope().WithDevice(kCPU0); 164 // Add control dependencies to ensure tests do not rely on specific 165 // manager and the order remains consistent for the test. 166 auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT); 167 auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a), 168 {3200, 3200}, DT_FLOAT); 169 auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b), 170 {3200, 3200}, DT_FLOAT); 171 auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c), 172 {3200, 3200}, DT_FLOAT); 173 auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d), 174 {3200, 3200}, DT_FLOAT); 175 176 auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b); 177 auto abc = ops::MatMul(s.WithOpName("abc"), ab, c); 178 auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d); 179 auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e); 180 181 GraphDef def; 182 TF_CHECK_OK(s.ToGraphDef(&def)); 183 184 grappler_item_.reset(new GrapplerItem); 185 grappler_item_->id = "test_matmul_sequence_graph"; 186 grappler_item_->graph = def; 187 grappler_item_->fetch = {"abcde"}; 188 189 dependency_["ab"] = {"a", "b"}; 190 
dependency_["abc"] = {"ab", "c"}; 191 dependency_["abcd"] = {"abc", "d"}; 192 dependency_["abcde"] = {"abcd", "e"}; 193 } 194 195 // AddN that takes 4 tensors with 10x10x10x10. 196 void CreateGrapplerItemWithAddN() { 197 Scope s = Scope::NewRootScope().WithDevice(kCPU0); 198 auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT); 199 auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT); 200 auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT); 201 auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT); 202 OutputList input_tensors = {x, y, z, w}; 203 auto out = ops::AddN(s.WithOpName("out"), input_tensors); 204 GraphDef def; 205 TF_CHECK_OK(s.ToGraphDef(&def)); 206 207 grappler_item_.reset(new GrapplerItem); 208 grappler_item_->id = "test_addn_graph"; 209 grappler_item_->graph = def; 210 grappler_item_->fetch = {"out"}; 211 212 dependency_["out"] = {"x", "y", "z", "w"}; 213 } 214 215 // Graph with some placeholder feed nodes that are not in the fetch fan-in. 216 void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() { 217 Scope s = Scope::NewRootScope().WithDevice(kCPU0); 218 auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT); 219 auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT); 220 221 GraphDef def; 222 TF_CHECK_OK(s.ToGraphDef(&def)); 223 224 grappler_item_.reset(new GrapplerItem); 225 grappler_item_->id = "test_extra_placeholders"; 226 grappler_item_->graph = def; 227 grappler_item_->fetch = {"x"}; 228 229 // Grappler Item Builder puts all placeholder nodes into the feed 230 // list by default. 231 grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}}; 232 } 233 234 // NoOp that takes 7 NoOps as control dependency. 
235 void CreateGrapplerItemWithControlDependency() { 236 Scope s = Scope::NewRootScope().WithDevice(kCPU0); 237 std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"}; 238 std::vector<Operation> input_tensors; 239 for (const auto& input : input_noop_names) { 240 auto x = ops::NoOp(s.WithOpName(input)); 241 input_tensors.push_back(x.operation); 242 } 243 auto out = 244 ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out")); 245 GraphDef def; 246 TF_CHECK_OK(s.ToGraphDef(&def)); 247 248 grappler_item_.reset(new GrapplerItem); 249 grappler_item_->id = "test_control_dependency_graph"; 250 grappler_item_->graph = def; 251 grappler_item_->fetch = {"out"}; 252 253 dependency_["out"] = input_noop_names; 254 } 255 256 // FusedBN [an op with multiple outputs] with multiple consumers (including 257 // control dependency). 258 void CreateGrapplerItemWithBatchNorm() { 259 Scope s = Scope::NewRootScope().WithDevice(kCPU0); 260 auto x = ops::RandomUniform( 261 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); 262 auto scale = 263 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT); 264 auto offset = 265 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT); 266 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT); 267 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT); 268 269 auto batch_norm = ops::FusedBatchNorm( 270 s.WithOpName("bn"), x, scale, offset, mean, var, 271 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f)); 272 auto y = batch_norm.y; 273 auto batch_mean = batch_norm.batch_mean; 274 auto batch_var = batch_norm.batch_variance; 275 276 auto z1 = ops::Add(s.WithOpName("z1"), x, y); 277 auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var); 278 auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var); 279 std::vector<Operation> input_tensors = { 280 batch_mean.op(), 281 z1.z.op(), 282 z2.z.op(), 283 z3.z.op(), 284 }; 285 auto z4 = 
ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4")); 286 287 GraphDef def; 288 TF_CHECK_OK(s.ToGraphDef(&def)); 289 290 grappler_item_.reset(new GrapplerItem); 291 grappler_item_->id = "test_complex_dependency_graph"; 292 grappler_item_->graph = def; 293 grappler_item_->fetch = {"z1", "z2", "z3", "z4"}; 294 295 dependency_["bn"] = {"x", "scale", "offset", "mean", "var"}; 296 dependency_["z1"] = {"x", "bn"}; 297 dependency_["z2"] = {"bn"}; 298 dependency_["z3"] = {"bn"}; 299 dependency_["z4"] = {"bn"}; 300 } 301 302 void CreateGrapplerItemWithSendRecv() { 303 const string gdef_ascii = R"EOF( 304 node { 305 name: "Const" 306 op: "Const" 307 device: "/job:localhost/replica:0/task:0/device:CPU:0" 308 attr { 309 key: "dtype" 310 value { 311 type: DT_FLOAT 312 } 313 } 314 attr { 315 key: "value" 316 value { 317 tensor { 318 dtype: DT_FLOAT 319 tensor_shape { 320 } 321 float_val: 3.1415 322 } 323 } 324 } 325 } 326 node { 327 name: "Send" 328 op: "_Send" 329 input: "Const" 330 device: "/job:localhost/replica:0/task:0/device:CPU:0" 331 attr { 332 key: "T" 333 value { 334 type: DT_FLOAT 335 } 336 } 337 attr { 338 key: "client_terminated" 339 value { 340 b: false 341 } 342 } 343 attr { 344 key: "recv_device" 345 value { 346 s: "/job:localhost/replica:0/task:0/device:CPU:0" 347 } 348 } 349 attr { 350 key: "send_device" 351 value { 352 s: "/job:localhost/replica:0/task:0/device:CPU:0" 353 } 354 } 355 attr { 356 key: "send_device_incarnation" 357 value { 358 i: 0 359 } 360 } 361 attr { 362 key: "tensor_name" 363 value { 364 s: "test" 365 } 366 } 367 } 368 node { 369 name: "Recv" 370 op: "_Recv" 371 device: "/job:localhost/replica:0/task:0/device:CPU:0" 372 attr { 373 key: "client_terminated" 374 value { 375 b: false 376 } 377 } 378 attr { 379 key: "recv_device" 380 value { 381 s: "/job:localhost/replica:0/task:0/device:CPU:0" 382 } 383 } 384 attr { 385 key: "send_device" 386 value { 387 s: "/job:localhost/replica:0/task:0/device:CPU:0" 388 } 389 } 390 attr { 391 
key: "send_device_incarnation" 392 value { 393 i: 0 394 } 395 } 396 attr { 397 key: "tensor_name" 398 value { 399 s: "test" 400 } 401 } 402 attr { 403 key: "tensor_type" 404 value { 405 type: DT_FLOAT 406 } 407 } 408 } 409 library { 410 } 411 versions { 412 producer: 24 413 } 414 )EOF"; 415 416 grappler_item_.reset(new GrapplerItem); 417 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, 418 &grappler_item_->graph)); 419 grappler_item_->id = "test_graph"; 420 grappler_item_->fetch = {"Recv"}; 421 } 422 423 void CreateGrapplerItemWithRecvWithoutSend() { 424 const string gdef_ascii = R"EOF( 425 node { 426 name: "Recv" 427 op: "_Recv" 428 device: "/job:localhost/replica:0/task:0/device:CPU:0" 429 attr { 430 key: "client_terminated" 431 value { 432 b: false 433 } 434 } 435 attr { 436 key: "recv_device" 437 value { 438 s: "/job:localhost/replica:0/task:0/device:CPU:0" 439 } 440 } 441 attr { 442 key: "send_device" 443 value { 444 s: "/job:localhost/replica:0/task:0/device:CPU:0" 445 } 446 } 447 attr { 448 key: "send_device_incarnation" 449 value { 450 i: 0 451 } 452 } 453 attr { 454 key: "tensor_name" 455 value { 456 s: "test" 457 } 458 } 459 attr { 460 key: "tensor_type" 461 value { 462 type: DT_FLOAT 463 } 464 } 465 } 466 library { 467 } 468 versions { 469 producer: 24 470 } 471 )EOF"; 472 473 grappler_item_.reset(new GrapplerItem); 474 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, 475 &grappler_item_->graph)); 476 grappler_item_->id = "test_graph"; 477 grappler_item_->fetch = {"Recv"}; 478 } 479 480 // A simple while loop 481 void CreateGrapplerItemWithLoop() { 482 // Test graph produced in python using: 483 /* 484 with tf.Graph().as_default(): 485 i0 = tf.constant(0) 486 m0 = tf.ones([2, 2]) 487 c = lambda i, m: i < 10 488 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 489 r = tf.while_loop( 490 c, b, loop_vars=[i0, m0], 491 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 492 with open('/tmp/graph.pbtxt', 'w') as f: 493 
f.write(str(tf.get_default_graph().as_graph_def())) 494 */ 495 const string gdef_ascii = R"EOF( 496 node { 497 name: "Const" 498 op: "Const" 499 attr { 500 key: "dtype" 501 value { 502 type: DT_INT32 503 } 504 } 505 attr { 506 key: "value" 507 value { 508 tensor { 509 dtype: DT_INT32 510 tensor_shape { 511 } 512 int_val: 0 513 } 514 } 515 } 516 } 517 node { 518 name: "ones" 519 op: "Const" 520 attr { 521 key: "dtype" 522 value { 523 type: DT_FLOAT 524 } 525 } 526 attr { 527 key: "value" 528 value { 529 tensor { 530 dtype: DT_FLOAT 531 tensor_shape { 532 dim { 533 size: 2 534 } 535 dim { 536 size: 2 537 } 538 } 539 float_val: 1.0 540 } 541 } 542 } 543 } 544 node { 545 name: "while/Enter" 546 op: "Enter" 547 input: "Const" 548 attr { 549 key: "T" 550 value { 551 type: DT_INT32 552 } 553 } 554 attr { 555 key: "frame_name" 556 value { 557 s: "while/while/" 558 } 559 } 560 attr { 561 key: "is_constant" 562 value { 563 b: false 564 } 565 } 566 attr { 567 key: "parallel_iterations" 568 value { 569 i: 10 570 } 571 } 572 } 573 node { 574 name: "while/Enter_1" 575 op: "Enter" 576 input: "ones" 577 attr { 578 key: "T" 579 value { 580 type: DT_FLOAT 581 } 582 } 583 attr { 584 key: "frame_name" 585 value { 586 s: "while/while/" 587 } 588 } 589 attr { 590 key: "is_constant" 591 value { 592 b: false 593 } 594 } 595 attr { 596 key: "parallel_iterations" 597 value { 598 i: 10 599 } 600 } 601 } 602 node { 603 name: "while/Merge" 604 op: "Merge" 605 input: "while/Enter" 606 input: "while/NextIteration" 607 attr { 608 key: "N" 609 value { 610 i: 2 611 } 612 } 613 attr { 614 key: "T" 615 value { 616 type: DT_INT32 617 } 618 } 619 } 620 node { 621 name: "while/Merge_1" 622 op: "Merge" 623 input: "while/Enter_1" 624 input: "while/NextIteration_1" 625 attr { 626 key: "N" 627 value { 628 i: 2 629 } 630 } 631 attr { 632 key: "T" 633 value { 634 type: DT_FLOAT 635 } 636 } 637 } 638 node { 639 name: "while/Less/y" 640 op: "Const" 641 input: "^while/Merge" 642 attr { 643 key: "dtype" 644 value 
{ 645 type: DT_INT32 646 } 647 } 648 attr { 649 key: "value" 650 value { 651 tensor { 652 dtype: DT_INT32 653 tensor_shape { 654 } 655 int_val: 10 656 } 657 } 658 } 659 } 660 node { 661 name: "while/Less" 662 op: "Less" 663 input: "while/Merge" 664 input: "while/Less/y" 665 attr { 666 key: "T" 667 value { 668 type: DT_INT32 669 } 670 } 671 } 672 node { 673 name: "while/LoopCond" 674 op: "LoopCond" 675 input: "while/Less" 676 } 677 node { 678 name: "while/Switch" 679 op: "Switch" 680 input: "while/Merge" 681 input: "while/LoopCond" 682 attr { 683 key: "T" 684 value { 685 type: DT_INT32 686 } 687 } 688 attr { 689 key: "_class" 690 value { 691 list { 692 s: "loc:@while/Merge" 693 } 694 } 695 } 696 } 697 node { 698 name: "while/Switch_1" 699 op: "Switch" 700 input: "while/Merge_1" 701 input: "while/LoopCond" 702 attr { 703 key: "T" 704 value { 705 type: DT_FLOAT 706 } 707 } 708 attr { 709 key: "_class" 710 value { 711 list { 712 s: "loc:@while/Merge_1" 713 } 714 } 715 } 716 } 717 node { 718 name: "while/Identity" 719 op: "Identity" 720 input: "while/Switch:1" 721 attr { 722 key: "T" 723 value { 724 type: DT_INT32 725 } 726 } 727 } 728 node { 729 name: "while/Identity_1" 730 op: "Identity" 731 input: "while/Switch_1:1" 732 attr { 733 key: "T" 734 value { 735 type: DT_FLOAT 736 } 737 } 738 } 739 node { 740 name: "while/add/y" 741 op: "Const" 742 input: "^while/Identity" 743 attr { 744 key: "dtype" 745 value { 746 type: DT_INT32 747 } 748 } 749 attr { 750 key: "value" 751 value { 752 tensor { 753 dtype: DT_INT32 754 tensor_shape { 755 } 756 int_val: 1 757 } 758 } 759 } 760 } 761 node { 762 name: "while/add" 763 op: "Add" 764 input: "while/Identity" 765 input: "while/add/y" 766 attr { 767 key: "T" 768 value { 769 type: DT_INT32 770 } 771 } 772 } 773 node { 774 name: "while/concat/axis" 775 op: "Const" 776 input: "^while/Identity" 777 attr { 778 key: "dtype" 779 value { 780 type: DT_INT32 781 } 782 } 783 attr { 784 key: "value" 785 value { 786 tensor { 787 dtype: DT_INT32 
788 tensor_shape { 789 } 790 int_val: 0 791 } 792 } 793 } 794 } 795 node { 796 name: "while/concat" 797 op: "ConcatV2" 798 input: "while/Identity_1" 799 input: "while/Identity_1" 800 input: "while/concat/axis" 801 attr { 802 key: "N" 803 value { 804 i: 2 805 } 806 } 807 attr { 808 key: "T" 809 value { 810 type: DT_FLOAT 811 } 812 } 813 attr { 814 key: "Tidx" 815 value { 816 type: DT_INT32 817 } 818 } 819 } 820 node { 821 name: "while/NextIteration" 822 op: "NextIteration" 823 input: "while/add" 824 attr { 825 key: "T" 826 value { 827 type: DT_INT32 828 } 829 } 830 } 831 node { 832 name: "while/NextIteration_1" 833 op: "NextIteration" 834 input: "while/concat" 835 attr { 836 key: "T" 837 value { 838 type: DT_FLOAT 839 } 840 } 841 } 842 node { 843 name: "while/Exit" 844 op: "Exit" 845 input: "while/Switch" 846 attr { 847 key: "T" 848 value { 849 type: DT_INT32 850 } 851 } 852 } 853 node { 854 name: "while/Exit_1" 855 op: "Exit" 856 input: "while/Switch_1" 857 attr { 858 key: "T" 859 value { 860 type: DT_FLOAT 861 } 862 } 863 } 864 versions { 865 producer: 21 866 } 867 )EOF"; 868 869 grappler_item_.reset(new GrapplerItem); 870 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, 871 &grappler_item_->graph)); 872 grappler_item_->id = "test_graph"; 873 grappler_item_->fetch = {"while/Exit", "while/Exit_1"}; 874 } 875 876 // A simple while loop strengthened with Switch outputs xxx. 
877 void CreateGrapplerItemWithLoopAnnotated() { 878 // Test graph produced in python using: 879 /* 880 with tf.Graph().as_default(): 881 i0 = tf.constant(0) 882 m0 = tf.ones([2, 2]) 883 c = lambda i, m: i < 10 884 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 885 r = tf.while_loop( 886 c, b, loop_vars=[i0, m0], 887 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 888 with open('/tmp/graph.pbtxt', 'w') as f: 889 f.write(str(tf.get_default_graph().as_graph_def())) 890 */ 891 const string gdef_ascii = R"EOF( 892 node { 893 name: "Const" 894 op: "Const" 895 attr { 896 key: "dtype" 897 value { 898 type: DT_INT32 899 } 900 } 901 attr { 902 key: "value" 903 value { 904 tensor { 905 dtype: DT_INT32 906 tensor_shape { 907 } 908 int_val: 0 909 } 910 } 911 } 912 attr { 913 key: "_execution_count" 914 value { 915 i: 1 916 } 917 } 918 } 919 node { 920 name: "ones" 921 op: "Const" 922 attr { 923 key: "dtype" 924 value { 925 type: DT_FLOAT 926 } 927 } 928 attr { 929 key: "value" 930 value { 931 tensor { 932 dtype: DT_FLOAT 933 tensor_shape { 934 dim { 935 size: 2 936 } 937 dim { 938 size: 2 939 } 940 } 941 float_val: 1.0 942 } 943 } 944 } 945 attr { 946 key: "_execution_count" 947 value { 948 i: 1 949 } 950 } 951 } 952 node { 953 name: "while/Enter" 954 op: "Enter" 955 input: "Const" 956 attr { 957 key: "T" 958 value { 959 type: DT_INT32 960 } 961 } 962 attr { 963 key: "frame_name" 964 value { 965 s: "while/while/" 966 } 967 } 968 attr { 969 key: "is_constant" 970 value { 971 b: false 972 } 973 } 974 attr { 975 key: "parallel_iterations" 976 value { 977 i: 10 978 } 979 } 980 attr { 981 key: "_execution_count" 982 value { 983 i: 1 984 } 985 } 986 } 987 node { 988 name: "while/Enter_1" 989 op: "Enter" 990 input: "ones" 991 attr { 992 key: "T" 993 value { 994 type: DT_FLOAT 995 } 996 } 997 attr { 998 key: "frame_name" 999 value { 1000 s: "while/while/" 1001 } 1002 } 1003 attr { 1004 key: "is_constant" 1005 value { 1006 b: false 1007 } 1008 } 1009 attr { 1010 key: 
"parallel_iterations" 1011 value { 1012 i: 10 1013 } 1014 } 1015 attr { 1016 key: "_execution_count" 1017 value { 1018 i: 1 1019 } 1020 } 1021 } 1022 node { 1023 name: "while/Merge" 1024 op: "Merge" 1025 input: "while/Enter" 1026 input: "while/NextIteration" 1027 attr { 1028 key: "N" 1029 value { 1030 i: 2 1031 } 1032 } 1033 attr { 1034 key: "T" 1035 value { 1036 type: DT_INT32 1037 } 1038 } 1039 attr { 1040 key: "_execution_count" 1041 value { 1042 i: 10 1043 } 1044 } 1045 } 1046 node { 1047 name: "while/Merge_1" 1048 op: "Merge" 1049 input: "while/Enter_1" 1050 input: "while/NextIteration_1" 1051 attr { 1052 key: "N" 1053 value { 1054 i: 2 1055 } 1056 } 1057 attr { 1058 key: "T" 1059 value { 1060 type: DT_FLOAT 1061 } 1062 } 1063 attr { 1064 key: "_execution_count" 1065 value { 1066 i: 10 1067 } 1068 } 1069 } 1070 node { 1071 name: "while/Less/y" 1072 op: "Const" 1073 input: "^while/Merge" 1074 attr { 1075 key: "dtype" 1076 value { 1077 type: DT_INT32 1078 } 1079 } 1080 attr { 1081 key: "value" 1082 value { 1083 tensor { 1084 dtype: DT_INT32 1085 tensor_shape { 1086 } 1087 int_val: 10 1088 } 1089 } 1090 } 1091 attr { 1092 key: "_execution_count" 1093 value { 1094 i: 10 1095 } 1096 } 1097 } 1098 node { 1099 name: "while/Less" 1100 op: "Less" 1101 input: "while/Merge" 1102 input: "while/Less/y" 1103 attr { 1104 key: "T" 1105 value { 1106 type: DT_INT32 1107 } 1108 } 1109 attr { 1110 key: "_execution_count" 1111 value { 1112 i: 10 1113 } 1114 } 1115 } 1116 node { 1117 name: "while/LoopCond" 1118 op: "LoopCond" 1119 input: "while/Less" 1120 attr { 1121 key: "_execution_count" 1122 value { 1123 i: 10 1124 } 1125 } 1126 } 1127 node { 1128 name: "while/Switch" 1129 op: "Switch" 1130 input: "while/Merge" 1131 input: "while/LoopCond" 1132 attr { 1133 key: "T" 1134 value { 1135 type: DT_INT32 1136 } 1137 } 1138 attr { 1139 key: "_class" 1140 value { 1141 list { 1142 s: "loc:@while/Merge" 1143 } 1144 } 1145 } 1146 attr { 1147 key: "_execution_count" 1148 value { 1149 i: 11 
1150 } 1151 } 1152 attr { 1153 key: "_output_slot_vector" 1154 value { 1155 list { 1156 i: 1 1157 i: 1 1158 i: 1 1159 i: 1 1160 i: 1 1161 i: 1 1162 i: 1 1163 i: 1 1164 i: 1 1165 i: 1 1166 i: 0 1167 } 1168 } 1169 } 1170 } 1171 node { 1172 name: "while/Switch_1" 1173 op: "Switch" 1174 input: "while/Merge_1" 1175 input: "while/LoopCond" 1176 attr { 1177 key: "T" 1178 value { 1179 type: DT_FLOAT 1180 } 1181 } 1182 attr { 1183 key: "_class" 1184 value { 1185 list { 1186 s: "loc:@while/Merge_1" 1187 } 1188 } 1189 } 1190 attr { 1191 key: "_execution_count" 1192 value { 1193 i: 11 1194 } 1195 } 1196 attr { 1197 key: "_output_slot_vector" 1198 value { 1199 list { 1200 i: 1 1201 i: 1 1202 i: 1 1203 i: 1 1204 i: 1 1205 i: 1 1206 i: 1 1207 i: 1 1208 i: 1 1209 i: 1 1210 i: 0 1211 } 1212 } 1213 } 1214 } 1215 node { 1216 name: "while/Identity" 1217 op: "Identity" 1218 input: "while/Switch:1" 1219 attr { 1220 key: "T" 1221 value { 1222 type: DT_INT32 1223 } 1224 } 1225 attr { 1226 key: "_execution_count" 1227 value { 1228 i: 10 1229 } 1230 } 1231 } 1232 node { 1233 name: "while/Identity_1" 1234 op: "Identity" 1235 input: "while/Switch_1:1" 1236 attr { 1237 key: "T" 1238 value { 1239 type: DT_FLOAT 1240 } 1241 } 1242 attr { 1243 key: "_execution_count" 1244 value { 1245 i: 10 1246 } 1247 } 1248 } 1249 node { 1250 name: "while/add/y" 1251 op: "Const" 1252 input: "^while/Identity" 1253 attr { 1254 key: "dtype" 1255 value { 1256 type: DT_INT32 1257 } 1258 } 1259 attr { 1260 key: "value" 1261 value { 1262 tensor { 1263 dtype: DT_INT32 1264 tensor_shape { 1265 } 1266 int_val: 1 1267 } 1268 } 1269 } 1270 attr { 1271 key: "_execution_count" 1272 value { 1273 i: 10 1274 } 1275 } 1276 } 1277 node { 1278 name: "while/add" 1279 op: "Add" 1280 input: "while/Identity" 1281 input: "while/add/y" 1282 attr { 1283 key: "T" 1284 value { 1285 type: DT_INT32 1286 } 1287 } 1288 attr { 1289 key: "_execution_count" 1290 value { 1291 i: 10 1292 } 1293 } 1294 } 1295 node { 1296 name: "while/concat/axis" 
1297 op: "Const" 1298 input: "^while/Identity" 1299 attr { 1300 key: "dtype" 1301 value { 1302 type: DT_INT32 1303 } 1304 } 1305 attr { 1306 key: "value" 1307 value { 1308 tensor { 1309 dtype: DT_INT32 1310 tensor_shape { 1311 } 1312 int_val: 0 1313 } 1314 } 1315 } 1316 attr { 1317 key: "_execution_count" 1318 value { 1319 i: 10 1320 } 1321 } 1322 } 1323 node { 1324 name: "while/concat" 1325 op: "ConcatV2" 1326 input: "while/Identity_1" 1327 input: "while/Identity_1" 1328 input: "while/concat/axis" 1329 attr { 1330 key: "N" 1331 value { 1332 i: 2 1333 } 1334 } 1335 attr { 1336 key: "T" 1337 value { 1338 type: DT_FLOAT 1339 } 1340 } 1341 attr { 1342 key: "Tidx" 1343 value { 1344 type: DT_INT32 1345 } 1346 } 1347 attr { 1348 key: "_execution_count" 1349 value { 1350 i: 10 1351 } 1352 } 1353 } 1354 node { 1355 name: "while/NextIteration" 1356 op: "NextIteration" 1357 input: "while/add" 1358 attr { 1359 key: "T" 1360 value { 1361 type: DT_INT32 1362 } 1363 } 1364 attr { 1365 key: "_execution_count" 1366 value { 1367 i: 10 1368 } 1369 } 1370 } 1371 node { 1372 name: "while/NextIteration_1" 1373 op: "NextIteration" 1374 input: "while/concat" 1375 attr { 1376 key: "T" 1377 value { 1378 type: DT_FLOAT 1379 } 1380 } 1381 attr { 1382 key: "_execution_count" 1383 value { 1384 i: 10 1385 } 1386 } 1387 } 1388 node { 1389 name: "while/Exit" 1390 op: "Exit" 1391 input: "while/Switch" 1392 attr { 1393 key: "T" 1394 value { 1395 type: DT_INT32 1396 } 1397 } 1398 attr { 1399 key: "_execution_count" 1400 value { 1401 i: 1 1402 } 1403 } 1404 } 1405 node { 1406 name: "while/Exit_1" 1407 op: "Exit" 1408 input: "while/Switch_1" 1409 attr { 1410 key: "T" 1411 value { 1412 type: DT_FLOAT 1413 } 1414 } 1415 attr { 1416 key: "_execution_count" 1417 value { 1418 i: 1 1419 } 1420 } 1421 } 1422 versions { 1423 producer: 21 1424 } 1425 )EOF"; 1426 1427 grappler_item_.reset(new GrapplerItem); 1428 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, 1429 &grappler_item_->graph)); 1430 
grappler_item_->id = "test_graph"; 1431 grappler_item_->fetch = {"while/Exit", "while/Exit_1"}; 1432 } 1433 1434 // A simple condition graph. 1435 void CreateGrapplerItemWithCondition() { 1436 // Handcrafted test graph: a/Less -> Switch -> First/Second -> Merge. 1437 const string gdef_ascii = R"EOF( 1438 node { 1439 name: "a" 1440 op: "Const" 1441 attr { 1442 key: "dtype" 1443 value { 1444 type: DT_FLOAT 1445 } 1446 } 1447 attr { 1448 key: "value" 1449 value { 1450 tensor { 1451 dtype: DT_FLOAT 1452 tensor_shape { 1453 } 1454 float_val: 2.0 1455 } 1456 } 1457 } 1458 } 1459 node { 1460 name: "Less" 1461 op: "Const" 1462 attr { 1463 key: "dtype" 1464 value { 1465 type: DT_BOOL 1466 } 1467 } 1468 attr { 1469 key: "value" 1470 value { 1471 tensor { 1472 dtype: DT_BOOL 1473 tensor_shape { 1474 } 1475 tensor_content: "\001" 1476 } 1477 } 1478 } 1479 } 1480 node { 1481 name: "Switch" 1482 op: "Switch" 1483 input: "a" 1484 input: "Less" 1485 attr { 1486 key: "T" 1487 value { 1488 type: DT_FLOAT 1489 } 1490 } 1491 } 1492 node { 1493 name: "First" 1494 op: "Identity" 1495 input: "Switch" 1496 attr { 1497 key: "T" 1498 value { 1499 type: DT_FLOAT 1500 } 1501 } 1502 } 1503 node { 1504 name: "Second" 1505 op: "Identity" 1506 input: "Switch:1" 1507 attr { 1508 key: "T" 1509 value { 1510 type: DT_FLOAT 1511 } 1512 } 1513 } 1514 node { 1515 name: "Merge" 1516 op: "Merge" 1517 input: "First" 1518 input: "Second" 1519 attr { 1520 key: "N" 1521 value { 1522 i: 2 1523 } 1524 } 1525 attr { 1526 key: "T" 1527 value { 1528 type: DT_FLOAT 1529 } 1530 } 1531 } 1532 versions { 1533 producer: 27 1534 })EOF"; 1535 1536 grappler_item_.reset(new GrapplerItem); 1537 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, 1538 &grappler_item_->graph)); 1539 grappler_item_->id = "test_graph"; 1540 grappler_item_->fetch = {"Merge"}; 1541 } 1542 1543 // Create a FusedBatchNorm op that has multiple output ports. 
1544 void CreateGrapplerItemWithInterDeviceTransfers() { 1545 tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); 1546 1547 // Create a FusedBatchNorm op that has multiple output ports. 1548 auto x = ops::RandomUniform( 1549 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); 1550 auto scale = 1551 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT); 1552 auto offset = 1553 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT); 1554 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT); 1555 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT); 1556 1557 auto batch_norm = ops::FusedBatchNorm( 1558 s.WithOpName("bn"), x, scale, offset, mean, var, 1559 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f)); 1560 auto y = batch_norm.y; 1561 auto batch_mean = batch_norm.batch_mean; 1562 auto batch_var = batch_norm.batch_variance; 1563 // y1 and y2 take the same tensor, so there should be only 1 Send and Recv. 1564 auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y); 1565 auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y); 1566 // batch_mean1 and batch_var1 take different output ports, so each will 1567 // initiate Send/Recv. 1568 auto batch_mean1 = ops::Identity( 1569 s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean); 1570 auto batch_var1 = 1571 ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var); 1572 // This is control dependency. 
1573 auto control_dep = ops::NoOp(s.WithOpName("control_dep") 1574 .WithControlDependencies(y) 1575 .WithDevice(kCPU1)); 1576 1577 GraphDef def; 1578 TF_CHECK_OK(s.ToGraphDef(&def)); 1579 1580 grappler_item_.reset(new GrapplerItem); 1581 grappler_item_->id = "test_conv2d_graph"; 1582 grappler_item_->graph = def; 1583 grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1", 1584 "control_dep"}; 1585 1586 dependency_["bn"] = {"x", "mean", "var"}; 1587 dependency_["y1"] = {"bn"}; 1588 dependency_["y2"] = {"bn"}; 1589 dependency_["batch_mean1"] = {"bn"}; 1590 dependency_["batch_var1"] = {"bn"}; 1591 dependency_["control_dep"] = {"bn"}; 1592 } 1593 1594 // Call this after creating grappler_item_ and setting up dependency_. 1595 void InitScheduler() { TF_ASSERT_OK(scheduler_->Init(grappler_item_.get())); } 1596 1597 // Returns cost based on op. 1598 Costs SimplePredictCosts(const OpContext& op_context) const { 1599 Costs c; 1600 int64 exec_cost = 0; 1601 if (op_context.op_info.op() == "MatMul") { 1602 exec_cost = 2000000000; 1603 } else if (op_context.op_info.op() == "RandomUniform") { 1604 exec_cost = 1000000000; 1605 } else { 1606 exec_cost = 1000; 1607 } 1608 c.execution_time = Costs::NanoSeconds(exec_cost); 1609 return c; 1610 } 1611 1612 // Call this after init scheduler_. Scheduler stops after executing 1613 // target_node. 1614 std::unordered_map<string, OpContext> RunScheduler( 1615 const string& target_node) { 1616 std::unordered_map<string, OpContext> ops_executed; 1617 bool more_nodes = true; 1618 do { 1619 OpContext op_context = scheduler_->GetCurrNode(); 1620 ops_executed[op_context.name] = op_context; 1621 std::cout << op_context.name << std::endl; 1622 1623 Costs node_costs = SimplePredictCosts(op_context); 1624 1625 // Check scheduling order. 
1626 auto it = dependency_.find(op_context.name); 1627 if (it != dependency_.end()) { 1628 for (const auto& preceding_node : it->second) { 1629 EXPECT_GT(ops_executed.count(preceding_node), 0); 1630 } 1631 } 1632 more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs); 1633 1634 if (op_context.name == target_node) { 1635 // Scheduler has the state after executing the target node. 1636 break; 1637 } 1638 } while (more_nodes); 1639 return ops_executed; 1640 } 1641 1642 // Helper method for validating a vector. 1643 template <typename T> 1644 void ExpectVectorEq(const std::vector<T>& expected, 1645 const std::vector<T>& test_elements) { 1646 // Set of expected elements for an easy comparison. 1647 std::set<T> expected_set(expected.begin(), expected.end()); 1648 for (const auto& element : test_elements) { 1649 EXPECT_GT(expected_set.count(element), 0); 1650 } 1651 EXPECT_EQ(expected.size(), test_elements.size()); 1652 } 1653 1654 // Helper method that checks the name of nodes. 1655 void ValidateNodeDefs(const std::vector<string>& expected, 1656 const std::vector<const NodeDef*>& node_defs) { 1657 std::vector<string> node_names; 1658 std::transform(node_defs.begin(), node_defs.end(), 1659 std::back_inserter(node_names), 1660 [](const NodeDef* node) { return node->name(); }); 1661 ExpectVectorEq(expected, node_names); 1662 } 1663 1664 // Helper method for validating a set. 1665 template <typename T> 1666 void ExpectSetEq(const std::set<T>& expected, 1667 const std::set<T>& test_elements) { 1668 for (const auto& element : test_elements) { 1669 EXPECT_GT(expected.count(element), 0); 1670 } 1671 EXPECT_EQ(expected.size(), test_elements.size()); 1672 } 1673 1674 // Helper method tthat checks name - port pairs. 
1675 void ValidateMemoryUsageSnapshot( 1676 const std::vector<string>& expected_names, const int port_num_expected, 1677 const std::unordered_set<std::pair<const NodeDef*, int>, 1678 DeviceState::NodePairHash>& mem_usage_snapshot) { 1679 std::set<std::pair<string, int>> nodes_at_peak_mem_usage; 1680 std::transform( 1681 mem_usage_snapshot.begin(), mem_usage_snapshot.end(), 1682 std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()), 1683 [](const std::pair<const NodeDef*, int>& node_port) { 1684 return std::make_pair(node_port.first->name(), node_port.second); 1685 }); 1686 std::set<std::pair<string, int>> expected; 1687 std::transform(expected_names.begin(), expected_names.end(), 1688 std::inserter(expected, expected.begin()), 1689 [port_num_expected](const string& name) { 1690 return std::make_pair(name, port_num_expected); 1691 }); 1692 ExpectSetEq(expected, nodes_at_peak_mem_usage); 1693 } 1694 1695 // Helper method for checking nodes dependency. 1696 void ValidateDependencyChain( 1697 const std::unordered_map<string, int64>& start_times, 1698 const std::vector<string>& nodes_in_dependency_order) { 1699 int64 prev_node_time = -1; 1700 for (const auto& node : nodes_in_dependency_order) { 1701 int64 curr_node_time = start_times.at(node); 1702 EXPECT_GE(curr_node_time, prev_node_time); 1703 prev_node_time = curr_node_time; 1704 } 1705 } 1706 1707 // cluster_ and scheduler_ are initialized in the c'tor. 1708 std::unique_ptr<VirtualCluster> cluster_; 1709 std::unique_ptr<TestVirtualScheduler> scheduler_; 1710 1711 // grappler_item_ will be initialized differently for each test case. 1712 std::unique_ptr<GrapplerItem> grappler_item_; 1713 // Node name -> its preceding nodes map for testing scheduling order. 
1714 std::unordered_map<string, std::vector<string>> dependency_; 1715 1716 // Shared params for Conv2D related graphs: 1717 const int batch_size_ = 4; 1718 const int width_ = 10; 1719 const int height_ = 10; 1720 const int depth_in_ = 8; 1721 const int kernel_ = 3; 1722 const int depth_out_ = 16; 1723 }; 1724 1725 // Test that FIFOManager correctly returns the current node with only 1 node. 1726 TEST_F(VirtualSchedulerTest, GetSingleNodeFIFOManager) { 1727 // Init. 1728 FIFOManager manager = FIFOManager(); 1729 1730 // Add the node to FIFOManager. 1731 manager.AddNode(&node1_); 1732 EXPECT_EQ("Node1", manager.GetCurrNode()->name()); 1733 } 1734 1735 // Test that FIFOManager removes the only node contained within. 1736 TEST_F(VirtualSchedulerTest, RemoveSingleNodeFIFOManager) { 1737 // Init. 1738 FIFOManager manager = FIFOManager(); 1739 1740 // Add the node to FIFOManager. 1741 manager.AddNode(&node1_); 1742 1743 // Remove the only node in FIFOManager. 1744 manager.RemoveCurrNode(); 1745 EXPECT_TRUE(manager.Empty()); 1746 } 1747 1748 // Test that FIFOManager can remove multiple nodes and returns the current node 1749 // in the right order 1750 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFIFOManager) { 1751 // Init. 1752 FIFOManager manager = FIFOManager(); 1753 1754 // Add the nodes to FIFOManager. 1755 manager.AddNode(&node1_); 1756 manager.AddNode(&node2_); 1757 manager.AddNode(&node3_); 1758 manager.AddNode(&node4_); 1759 1760 // Keep checking current node while removing nodes from manager. 
1761 EXPECT_EQ("Node1", manager.GetCurrNode()->name()); 1762 manager.RemoveCurrNode(); 1763 EXPECT_EQ("Node2", manager.GetCurrNode()->name()); 1764 manager.RemoveCurrNode(); 1765 EXPECT_EQ("Node3", manager.GetCurrNode()->name()); 1766 manager.RemoveCurrNode(); 1767 EXPECT_EQ("Node4", manager.GetCurrNode()->name()); 1768 manager.RemoveCurrNode(); 1769 EXPECT_TRUE(manager.Empty()); 1770 } 1771 1772 // Test that FIFOManager can remove multiple nodes and add more nodes, still 1773 // returning the current node in the right order 1774 TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleFIFOManager) { 1775 // Init. 1776 FIFOManager manager = FIFOManager(); 1777 1778 // Add the nodes to FIFOManager. 1779 manager.AddNode(&node1_); 1780 manager.AddNode(&node2_); 1781 manager.AddNode(&node3_); 1782 manager.AddNode(&node4_); 1783 1784 // Keep checking current node as nodes are removed and added. 1785 EXPECT_EQ("Node1", manager.GetCurrNode()->name()); 1786 manager.RemoveCurrNode(); 1787 EXPECT_EQ("Node2", manager.GetCurrNode()->name()); 1788 manager.AddNode(&node5_); 1789 // GetCurrNode() should return the same node even if some nodes are added, 1790 // until RemoveCurrNode() is called. 1791 EXPECT_EQ("Node2", manager.GetCurrNode()->name()); 1792 manager.RemoveCurrNode(); 1793 EXPECT_EQ("Node3", manager.GetCurrNode()->name()); 1794 manager.RemoveCurrNode(); 1795 EXPECT_EQ("Node4", manager.GetCurrNode()->name()); 1796 manager.AddNode(&node6_); 1797 EXPECT_EQ("Node4", manager.GetCurrNode()->name()); 1798 manager.RemoveCurrNode(); 1799 EXPECT_EQ("Node5", manager.GetCurrNode()->name()); 1800 manager.RemoveCurrNode(); 1801 EXPECT_EQ("Node6", manager.GetCurrNode()->name()); 1802 manager.RemoveCurrNode(); 1803 EXPECT_TRUE(manager.Empty()); 1804 } 1805 1806 // Test that LIFOManager correctly returns the current node with only 1 node. 1807 TEST_F(VirtualSchedulerTest, GetSingleNodeLIFOManager) { 1808 // Init. 
1809 LIFOManager manager = LIFOManager(); 1810 1811 // Add the node to LIFOManager. 1812 manager.AddNode(&node1_); 1813 EXPECT_EQ("Node1", manager.GetCurrNode()->name()); 1814 } 1815 1816 // Test that LIFOManager removes the only node contained within. 1817 TEST_F(VirtualSchedulerTest, RemoveSingleNodeLIFOManager) { 1818 // Init. 1819 LIFOManager manager = LIFOManager(); 1820 1821 // Add the node to LIFOManager. 1822 manager.AddNode(&node1_); 1823 1824 // Remove the only node in LIFOManager. 1825 manager.RemoveCurrNode(); 1826 EXPECT_TRUE(manager.Empty()); 1827 } 1828 1829 // Test that LIFOManager can remove multiple nodes and returns the current node 1830 // in the right order 1831 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleLIFOManager) { 1832 // Init. 1833 LIFOManager manager = LIFOManager(); 1834 1835 // Add the nodes to LIFOManager. 1836 manager.AddNode(&node1_); 1837 manager.AddNode(&node2_); 1838 manager.AddNode(&node3_); 1839 manager.AddNode(&node4_); 1840 1841 // Keep checking current node while removing nodes from manager. 1842 EXPECT_EQ("Node4", manager.GetCurrNode()->name()); 1843 manager.RemoveCurrNode(); 1844 EXPECT_EQ("Node3", manager.GetCurrNode()->name()); 1845 manager.RemoveCurrNode(); 1846 EXPECT_EQ("Node2", manager.GetCurrNode()->name()); 1847 manager.RemoveCurrNode(); 1848 EXPECT_EQ("Node1", manager.GetCurrNode()->name()); 1849 manager.RemoveCurrNode(); 1850 EXPECT_TRUE(manager.Empty()); 1851 } 1852 1853 // Test that LIFOManager can remove multiple nodes (must be removing the current 1854 // node) and add more nodes, still returning the current node in the right order 1855 TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) { 1856 // Init. 1857 LIFOManager manager = LIFOManager(); 1858 1859 // Add the nodes to LIFOManager. 1860 manager.AddNode(&node1_); 1861 manager.AddNode(&node2_); 1862 manager.AddNode(&node3_); 1863 manager.AddNode(&node4_); 1864 1865 // Keep checking current node as nodes are removed and added. 
1866 EXPECT_EQ("Node4", manager.GetCurrNode()->name()); 1867 manager.RemoveCurrNode(); 1868 EXPECT_EQ("Node3", manager.GetCurrNode()->name()); 1869 manager.AddNode(&node5_); 1870 // GetCurrNode() should return the same node even if some nodes are added, 1871 // until RemoveCurrNode() is called. 1872 EXPECT_EQ("Node3", manager.GetCurrNode()->name()); 1873 manager.RemoveCurrNode(); 1874 EXPECT_EQ("Node5", manager.GetCurrNode()->name()); 1875 manager.RemoveCurrNode(); 1876 EXPECT_EQ("Node2", manager.GetCurrNode()->name()); 1877 manager.AddNode(&node6_); 1878 EXPECT_EQ("Node2", manager.GetCurrNode()->name()); 1879 manager.RemoveCurrNode(); 1880 EXPECT_EQ("Node6", manager.GetCurrNode()->name()); 1881 manager.RemoveCurrNode(); 1882 EXPECT_EQ("Node1", manager.GetCurrNode()->name()); 1883 manager.RemoveCurrNode(); 1884 EXPECT_TRUE(manager.Empty()); 1885 } 1886 1887 TEST_F(VirtualSchedulerTest, GetSingleNodeFirstReadyManager) { 1888 FirstReadyManager manager; 1889 manager.Init(&node_states_); 1890 1891 manager.AddNode(&node1_); 1892 EXPECT_EQ("Node1", manager.GetCurrNode()->name()); 1893 } 1894 1895 TEST_F(VirtualSchedulerTest, RemoveSingleNodeFirstReadyManager) { 1896 FirstReadyManager manager; 1897 manager.Init(&node_states_); 1898 manager.AddNode(&node1_); 1899 manager.RemoveCurrNode(); 1900 EXPECT_TRUE(manager.Empty()); 1901 } 1902 1903 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) { 1904 FirstReadyManager manager; 1905 manager.Init(&node_states_); 1906 // Insert nodes in some random order. 1907 manager.AddNode(&node2_); 1908 manager.AddNode(&node1_); 1909 manager.AddNode(&node4_); 1910 manager.AddNode(&node5_); 1911 manager.AddNode(&node3_); 1912 manager.AddNode(&node6_); 1913 1914 // In whatever order we insert nodes, we get the same order based on nodes' 1915 // time_ready. 
1916 EXPECT_EQ("Node6", manager.GetCurrNode()->name()); 1917 manager.RemoveCurrNode(); 1918 EXPECT_EQ("Node5", manager.GetCurrNode()->name()); 1919 manager.RemoveCurrNode(); 1920 EXPECT_EQ("Node4", manager.GetCurrNode()->name()); 1921 manager.RemoveCurrNode(); 1922 EXPECT_EQ("Node3", manager.GetCurrNode()->name()); 1923 manager.RemoveCurrNode(); 1924 EXPECT_EQ("Node2", manager.GetCurrNode()->name()); 1925 manager.RemoveCurrNode(); 1926 EXPECT_EQ("Node1", manager.GetCurrNode()->name()); 1927 manager.RemoveCurrNode(); 1928 EXPECT_TRUE(manager.Empty()); 1929 } 1930 1931 TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) { 1932 FirstReadyManager manager; 1933 manager.Init(&node_states_); 1934 // Insert nodes in some random order. 1935 manager.AddNode(&node2_); 1936 manager.AddNode(&node1_); 1937 manager.AddNode(&node4_); 1938 manager.AddNode(&node5_); 1939 manager.AddNode(&node3_); 1940 manager.AddNode(&node6_); 1941 1942 // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode() 1943 // should return it. 1944 EXPECT_EQ("Node6", manager.GetCurrNode()->name()); 1945 // Now insret a few other nodes, but their time_ready's are even smaller than 1946 // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return 1947 // the same node, Node6, in this case. 1948 1949 NodeDef node7; 1950 NodeDef node8; 1951 NodeDef node9; 1952 NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7); 1953 NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8); 1954 NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9); 1955 1956 manager.AddNode(&node7); 1957 EXPECT_EQ("Node6", manager.GetCurrNode()->name()); 1958 1959 manager.AddNode(&node8); 1960 EXPECT_EQ("Node6", manager.GetCurrNode()->name()); 1961 1962 manager.RemoveCurrNode(); 1963 // Now Node6 is removed, and GetCurrNode() will return Node8. 1964 EXPECT_EQ("Node8", manager.GetCurrNode()->name()); 1965 1966 // Again, AddNode shouldn't change GetCurrNode(). 
1967 manager.AddNode(&node9); 1968 EXPECT_EQ("Node8", manager.GetCurrNode()->name()); 1969 1970 manager.RemoveCurrNode(); 1971 EXPECT_EQ("Node9", manager.GetCurrNode()->name()); 1972 manager.RemoveCurrNode(); 1973 EXPECT_EQ("Node7", manager.GetCurrNode()->name()); 1974 manager.RemoveCurrNode(); 1975 EXPECT_EQ("Node5", manager.GetCurrNode()->name()); 1976 manager.RemoveCurrNode(); 1977 EXPECT_EQ("Node4", manager.GetCurrNode()->name()); 1978 manager.RemoveCurrNode(); 1979 EXPECT_EQ("Node3", manager.GetCurrNode()->name()); 1980 manager.RemoveCurrNode(); 1981 EXPECT_EQ("Node2", manager.GetCurrNode()->name()); 1982 manager.RemoveCurrNode(); 1983 EXPECT_EQ("Node1", manager.GetCurrNode()->name()); 1984 manager.RemoveCurrNode(); 1985 EXPECT_TRUE(manager.Empty()); 1986 } 1987 1988 TEST_F(VirtualSchedulerTest, DeterminismInFirstReadyManager) { 1989 FirstReadyManager manager1; 1990 manager1.Init(&node_states_); 1991 FirstReadyManager manager2; 1992 manager2.Init(&node_states_); 1993 1994 // 6 nodes with same time_ready. 1995 NodeDef node7; 1996 NodeDef node8; 1997 NodeDef node9; 1998 NodeDef node10; 1999 NodeDef node11; 2000 NodeDef node12; 2001 NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7); 2002 NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8); 2003 NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9); 2004 NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10); 2005 NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11); 2006 NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12); 2007 2008 // Add the above 6 nodes to manager1. 2009 manager1.AddNode(&node7); 2010 manager1.AddNode(&node8); 2011 manager1.AddNode(&node9); 2012 manager1.AddNode(&node10); 2013 manager1.AddNode(&node11); 2014 manager1.AddNode(&node12); 2015 2016 // Add the above 6 nodes to manager2, but in a different order. 
2017 manager2.AddNode(&node8); 2018 manager2.AddNode(&node11); 2019 manager2.AddNode(&node9); 2020 manager2.AddNode(&node10); 2021 manager2.AddNode(&node7); 2022 manager2.AddNode(&node12); 2023 2024 // Expect both managers return the same nodes for deterministic node 2025 // scheduling. 2026 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); 2027 manager1.RemoveCurrNode(); 2028 manager2.RemoveCurrNode(); 2029 2030 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); 2031 manager1.RemoveCurrNode(); 2032 manager2.RemoveCurrNode(); 2033 2034 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); 2035 manager1.RemoveCurrNode(); 2036 manager2.RemoveCurrNode(); 2037 2038 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); 2039 manager1.RemoveCurrNode(); 2040 manager2.RemoveCurrNode(); 2041 2042 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); 2043 manager1.RemoveCurrNode(); 2044 manager2.RemoveCurrNode(); 2045 2046 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); 2047 manager1.RemoveCurrNode(); 2048 manager2.RemoveCurrNode(); 2049 2050 EXPECT_TRUE(manager1.Empty()); 2051 EXPECT_TRUE(manager2.Empty()); 2052 } 2053 2054 TEST_F(VirtualSchedulerTest, RemoveSingleNodeCompositeNodeManager) { 2055 CompositeNodeManager manager; 2056 manager.Init(&node_states_); 2057 manager.AddNode(&node1_); 2058 manager.RemoveCurrNode(); 2059 EXPECT_TRUE(manager.Empty()); 2060 } 2061 2062 TEST_F(VirtualSchedulerTest, RemoveSingleNodeComopsiteNodeManager) { 2063 CompositeNodeManager manager; 2064 manager.Init(&node_states_); 2065 2066 manager.AddNode(&node1_); 2067 manager.RemoveCurrNode(); 2068 EXPECT_TRUE(manager.Empty()); 2069 } 2070 2071 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleComopsiteNodeManager) { 2072 CompositeNodeManager manager; 2073 manager.Init(&node_states_); 2074 2075 // Add the nodes to LIFOManager. 
2076 manager.AddNode(&node1_); 2077 manager.AddNode(&node2_); 2078 manager.AddNode(&node3_); 2079 manager.AddNode(&node4_); 2080 2081 // Keep checking current node as nodes are removed and added. 2082 EXPECT_EQ("Node4", manager.GetCurrNode()->name()); 2083 manager.RemoveCurrNode(); 2084 EXPECT_EQ("Node3", manager.GetCurrNode()->name()); 2085 manager.AddNode(&node5_); 2086 // GetCurrNode() should return the same node even if some nodes are added, 2087 // until RemoveCurrNode() is called. 2088 EXPECT_EQ("Node3", manager.GetCurrNode()->name()); 2089 manager.RemoveCurrNode(); 2090 EXPECT_EQ("Node5", manager.GetCurrNode()->name()); 2091 manager.RemoveCurrNode(); 2092 EXPECT_EQ("Node2", manager.GetCurrNode()->name()); 2093 manager.AddNode(&node6_); 2094 EXPECT_EQ("Node2", manager.GetCurrNode()->name()); 2095 manager.RemoveCurrNode(); 2096 EXPECT_EQ("Node6", manager.GetCurrNode()->name()); 2097 manager.RemoveCurrNode(); 2098 EXPECT_EQ("Node1", manager.GetCurrNode()->name()); 2099 manager.RemoveCurrNode(); 2100 EXPECT_TRUE(manager.Empty()); 2101 } 2102 2103 TEST_F(VirtualSchedulerTest, MultiDeviceSendRecvComopsiteNodeManager) { 2104 CompositeNodeManager manager; 2105 manager.Init(&node_states_); 2106 // Additional nodes on kCPU1 2107 NodeDef node7; 2108 NodeDef node8; 2109 NodeDef node9; 2110 NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7); 2111 NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8); 2112 NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9); 2113 2114 // Send and Recv nodes. 2115 NodeDef send1; 2116 NodeDef send2; 2117 NodeDef recv1; 2118 NodeDef recv2; 2119 NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1); 2120 NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2); 2121 NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1); 2122 NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2); 2123 2124 // Insert nodes. 
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);

  // Keep checking current node as nodes are removed and added.
  EXPECT_EQ("Node4", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node3", manager.GetCurrNode()->name());
  manager.AddNode(&node5_);
  // GetCurrNode() should return the same node even if some nodes are added,
  // until RemoveCurrNode() is called.
  EXPECT_EQ("Node3", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node5", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node2", manager.GetCurrNode()->name());
  manager.AddNode(&node6_);
  EXPECT_EQ("Node2", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node6", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node1", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}

TEST_F(VirtualSchedulerTest, MultiDeviceSendRecvComopsiteNodeManager) {
  CompositeNodeManager manager;
  manager.Init(&node_states_);
  // Additional nodes on kCPU1 (node1_ to node6_ are on kCPU0; see fixture).
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
  NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
  NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);

  // Send and Recv nodes.
  NodeDef send1;
  NodeDef send2;
  NodeDef recv1;
  NodeDef recv2;
  NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
  NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
  NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
  NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);

  // Insert nodes.
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node6_);
  manager.AddNode(&node7);
  manager.AddNode(&node8);
  manager.AddNode(&node9);
  manager.AddNode(&send1);
  manager.AddNode(&send2);
  manager.AddNode(&recv1);
  manager.AddNode(&recv2);

  // On kCPU0 the last-added (LIFO) node is node6_, on kCPU1 it is node9;
  // so choose the one that has earliest time_ready among node6_, node9,
  // Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ("Node6", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
  // among node5_, node9, Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ("Node5", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ("Send1", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, Recv1, and Recv2.
  EXPECT_EQ("Recv1", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, and Recv2.
  EXPECT_EQ("Recv2", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, and Send2.
  EXPECT_EQ("Send2", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose between node4_, node9.
  EXPECT_EQ("Node4", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose between node3_, node9.
  EXPECT_EQ("Node9", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose between node3_, node8.
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose between node3_, node7.
  EXPECT_EQ("Node7", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Then, just the remaining nodes on kCPU0 -- LIFO order.
  // (Node1-3 were set up on kCPU0 in the fixture ctor.)
  EXPECT_EQ("Node3", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node2", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node1", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}

TEST_F(VirtualSchedulerTest, DeterminismInCompositeNodeManager) {
  CompositeNodeManager manager;
  manager.Init(&node_states_);
  CompositeNodeManager manager2;
  manager2.Init(&node_states_);

  // Nodes 7-9 and 12 share time_ready = 1000; nodes 10-11 are ready earlier
  // (999). Node12 is placed on kCPU1, the rest on kCPU0.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeDef node10;
  NodeDef node11;
  NodeDef node12;
  NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
  NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
  NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
  NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
  NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
  NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);

  // Add Nodes 7 to 9 to manager.
  manager.AddNode(&node7);
  manager.AddNode(&node8);
  manager.AddNode(&node9);

  // It should return _Send, then _Recv, then the other op, in that order,
  // when the candidate nodes have the same time_ready.
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());
  EXPECT_EQ(kSend, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node9", manager.GetCurrNode()->name());
  EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node7", manager.GetCurrNode()->name());
  EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Add Nodes 7 to 9 to manager, but in a different order.
  manager.AddNode(&node9);
  manager.AddNode(&node8);
  manager.AddNode(&node7);

  // Expect same order (_Send, _Recv, and the other op), regardless of Add
  // order.
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());
  EXPECT_EQ(kSend, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node9", manager.GetCurrNode()->name());
  EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node7", manager.GetCurrNode()->name());
  EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Conv2D's time_ready < Send's time_ready; expect Conv2D first.
  manager.AddNode(&node8);
  manager.AddNode(&node10);
  EXPECT_EQ("Node10", manager.GetCurrNode()->name());
  EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());
  EXPECT_EQ(kSend, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Recv's time_ready < Send's time_ready; expect Recv first.
  manager.AddNode(&node11);
  manager.AddNode(&node8);
  EXPECT_EQ("Node11", manager.GetCurrNode()->name());
  EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());
  EXPECT_EQ(kSend, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Node7 and 12 are normal ops with the same time_ready, placed on different
  // devices. These two nodes are added to manager and manager2, but in
  // different orders; expect GetCurrNode() to return the nodes in the same
  // order.
  manager.AddNode(&node7);
  manager.AddNode(&node12);

  manager2.AddNode(&node12);
  manager2.AddNode(&node7);

  EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager.RemoveCurrNode();
  manager2.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager.RemoveCurrNode();
  manager2.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}

// Create small graph, run predict costs on it, make sure the costs from the
// summary match the hand-calculated costs.
TEST_F(VirtualSchedulerTest, SummaryCostTest) {
  // Run matmul test.
  CreateGrapplerItemWithMatmulChain();
  InitScheduler();
  auto ops_executed = RunScheduler("");
  Costs c = scheduler_->Summary();

  // RandomUniform - 5 * 1s = 5s
  // Matmuls - 4 * 2s = 8s
  // Misc - 5 * 1us
  // Total: 13000005 us
  EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
  EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
  EXPECT_FALSE(c.inaccurate);
  EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
}

// Like the above SummaryCostTest, but makes sure the stepstats timeline is
// correct.
TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
  // Run matmul test.
  CreateGrapplerItemWithMatmulChain();
  InitScheduler();
  auto ops_executed = RunScheduler("");
  RunMetadata metadata;
  Costs c = scheduler_->Summary(&metadata);
  StepStats stepstats = metadata.step_stats();
  EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
  EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
  EXPECT_FALSE(c.inaccurate);
  EXPECT_EQ(0, c.num_ops_with_unknown_shapes);

  // Should only be 1 device!
  EXPECT_EQ(1, stepstats.dev_stats().size());

  // Create a map of op name -> start and end times (micros).
2316 std::map<string, std::pair<int64, int64>> start_end_times; 2317 for (const auto& device_step_stats : stepstats.dev_stats()) { 2318 for (const auto& stats : device_step_stats.node_stats()) { 2319 int64 start = stats.all_start_micros(); 2320 int64 end = start + stats.all_end_rel_micros(); 2321 start_end_times[stats.node_name()] = std::pair<int64, int64>(start, end); 2322 2323 // Make sure that the output properties are correct for 2324 // MatMul and RandomUniform operations. 2325 // We only check for dtype, and shape (excluding alloc) 2326 // since alloc is not set by the virtual scheduler. 2327 if (stats.timeline_label() == "MatMul" || 2328 stats.timeline_label() == "RandomUniform") { 2329 EXPECT_EQ(1, stats.output().size()); 2330 for (const auto& output : stats.output()) { 2331 EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype()); 2332 EXPECT_EQ(2, output.tensor_description().shape().dim().size()); 2333 for (const auto& dim : output.tensor_description().shape().dim()) { 2334 EXPECT_EQ(3200, dim.size()); 2335 } 2336 } 2337 } 2338 } 2339 } 2340 2341 // The base start_time is the time to compute RandomUniforms 2342 int64 cur_time = static_cast<int64>(5000005); 2343 // The increment is the execution time of one matmul. See 2344 // CreateGrapplerItemWithMatmulChain for details. 2345 int64 increment = static_cast<int64>(2000000); 2346 auto op_names = {"ab", "abc", "abcd", "abcde"}; 2347 for (const auto& op_name : op_names) { 2348 int64 actual_start = start_end_times[op_name].first; 2349 int64 actual_end = start_end_times[op_name].second; 2350 int64 expected_start = cur_time; 2351 int64 expected_end = cur_time + increment; 2352 EXPECT_EQ(expected_start, actual_start); 2353 EXPECT_EQ(expected_end, actual_end); 2354 cur_time += increment; 2355 } 2356 } 2357 2358 TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) { 2359 // Init. 2360 CreateGrapplerItemWithConv2Ds(); 2361 InitScheduler(); 2362 2363 // Run the scheduler. 
2364 auto ops_executed = RunScheduler(""); // Run all the nodes. 2365 2366 // [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be 2367 // executed. 2368 EXPECT_EQ(8, ops_executed.size()); 2369 2370 // x, y, f, c0, and c1 should be in the ops executed. 2371 EXPECT_GT(ops_executed.count("x"), 0); 2372 EXPECT_GT(ops_executed.count("y"), 0); 2373 EXPECT_GT(ops_executed.count("f"), 0); 2374 EXPECT_GT(ops_executed.count("c0"), 0); 2375 EXPECT_GT(ops_executed.count("c1"), 0); 2376 2377 // z and c2 shouldn't be part of it. 2378 EXPECT_EQ(ops_executed.count("z"), 0); 2379 EXPECT_EQ(ops_executed.count("c2"), 0); 2380 2381 // Check input / output properties. 2382 EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size()); 2383 EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size()); 2384 EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size()); 2385 EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size()); 2386 EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size()); 2387 } 2388 2389 TEST_F(VirtualSchedulerTest, MemoryUsage) { 2390 // Init. 2391 CreateGrapplerItemWithAddN(); 2392 InitScheduler(); 2393 2394 // Run the scheduler. 2395 RunScheduler(""); 2396 2397 const auto* device_states = scheduler_->GetDeviceStates(); 2398 const auto& cpu_state = device_states->at(kCPU0); 2399 2400 // out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage 2401 // is 4 x the input tensor size while executing the out node. 
  // 4 bytes per float element; each input is a 10x10x10x10 tensor.
  int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
  const std::vector<string> expected_names = {"x", "y", "z", "w"};
  EXPECT_EQ(expected_names.size() * one_input_node_size,
            cpu_state.max_memory_usage);
  ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
                              cpu_state.mem_usage_snapshot_at_peak);
}

TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) {
  CreateGrapplerItemWithUnnecessaryPlaceholderNodes();
  InitScheduler();

  // Test that scheduler can run graphs with extra unnecessary feed nodes.
  auto ops_executed = RunScheduler("");
  ASSERT_EQ(1, ops_executed.size());
  ASSERT_EQ(ops_executed.count("x"), 1);
}

TEST_F(VirtualSchedulerTest, ControlDependency) {
  // Init.
  CreateGrapplerItemWithControlDependency();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // The graph has a NoOp that takes control dependency from 7 NoOps. The peak
  // memory usage is when executing the final NoOp.
  int64 one_input_node_size = 4;  // control dependency
  const std::vector<string> expected_names = {"x", "y", "z", "w",
                                              "u", "v", "t"};
  EXPECT_EQ(expected_names.size() * one_input_node_size,
            cpu_state.max_memory_usage);
  // Control-dependency edges are reported on port -1.
  ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
                              cpu_state.mem_usage_snapshot_at_peak);
}

TEST_F(VirtualSchedulerTest, ComplexDependency) {
  // Init.
  CreateGrapplerItemWithBatchNorm();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("bn");

  // NOTE(review): other tests in this file bind GetDeviceStates() with
  // `const auto*`; here `const auto&` binds a reference to the returned
  // pointer. Works, but consider matching the surrounding style.
  const auto& device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // The graph is
  // bn = FusedBatchNorm(x, scale, offset, mean, var)
  // z1 = bn.y + x
  // z2 = bn.var + bn.var
  // z3 = bn.var + bn.var
  // z4 = control dependency from bn.
  // Note that bn.mean doesn't have any consumer.
  const int x_size = batch_size_ * width_ * height_ * depth_in_;
  int64 expected_size =
      4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
           1 /* control dependency */);
  EXPECT_EQ(expected_size, cpu_state.memory_usage);

  // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
  // Project the (NodeDef*, port) pairs onto (name, port) for comparison.
  std::set<std::pair<string, int>> nodes_in_memory;
  std::transform(
      cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
      std::inserter(nodes_in_memory, nodes_in_memory.begin()),
      [](const std::pair<const NodeDef*, int>& node_port) {
        return std::make_pair(node_port.first->name(), node_port.second);
      });
  std::set<std::pair<string, int>> expected = {
      std::make_pair("bn", -1),
      std::make_pair("bn", 0),
      std::make_pair("bn", 2),
      std::make_pair("x", 0),
  };
  ExpectSetEq(expected, nodes_in_memory);

  const auto* node_states = scheduler_->GetNodeStates();
  const NodeState* bn_node = nullptr;
  const NodeState* x_node = nullptr;
  for (const auto& nodedef_node_state : *node_states) {
    const NodeDef* node = nodedef_node_state.first;
    const NodeState& node_state = nodedef_node_state.second;
    if (node->name() == "bn") {
      bn_node = &node_state;
    }
    if (node->name() == "x") {
      x_node = &node_state;
    }
  }
  CHECK_NOTNULL(bn_node);
  CHECK_NOTNULL(x_node);

  ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
  ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
  ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
  // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
  ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
}

TEST_F(VirtualSchedulerTest, Variable) {
  // Init.
  CreateGrapplerItemWithConv2DAndVariable();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // There is one Conv2D that takes x and f, but f is variable, so it should be
  // in persistent nodes.
  // f is variable.
  ValidateMemoryUsageSnapshot({"f"}, 0 /* port_num_expected */,
                              cpu_state.persistent_nodes);
  // Only x in peak memory usage snapshot.
  ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */,
                              cpu_state.mem_usage_snapshot_at_peak);
}

TEST_F(VirtualSchedulerTest, WhileLoop) {
  // Init.
  CreateGrapplerItemWithLoop();
  InitScheduler();

  // Run the scheduler.
2532 RunScheduler(""); 2533 2534 // Check the timeline 2535 RunMetadata metadata; 2536 scheduler_->Summary(&metadata); 2537 2538 // Nodes in topological order: 2539 // * const, ones 2540 // * while/Enter, while/Enter_1 2541 // * while/Merge, while/Merge_1 2542 // * while/Less/y 2543 // * while/Less 2544 // * while/LoopCond 2545 // * while/Switch, while/Switch_1 2546 // * while/Identity, while/Identity_1, while/Exit, while/Exit_1 2547 // * while/add/y, while/concat/axis 2548 // * while/add, while/concat 2549 // * while/NextIteration, while/NextIteration_1 2550 2551 int num_next_iteration = 0; 2552 int num_next_iteration_1 = 0; 2553 int num_exit = 0; 2554 int num_exit_1 = 0; 2555 int64 next_iter_start_micro; 2556 int64 next_iter_1_start_micro; 2557 int64 exit_start_micro; 2558 int64 exit_1_start_micro; 2559 2560 std::unordered_map<string, int64> start_times; 2561 for (const auto& device_step_stats : metadata.step_stats().dev_stats()) { 2562 for (const auto& stats : device_step_stats.node_stats()) { 2563 start_times[stats.node_name()] = stats.all_start_micros(); 2564 if (stats.node_name() == "while/NextIteration") { 2565 ++num_next_iteration; 2566 next_iter_start_micro = stats.all_start_micros(); 2567 } else if (stats.node_name() == "while/NextIteration_1") { 2568 ++num_next_iteration_1; 2569 next_iter_1_start_micro = stats.all_start_micros(); 2570 } else if (stats.node_name() == "while/Exit") { 2571 ++num_exit; 2572 exit_start_micro = stats.all_start_micros(); 2573 } else if (stats.node_name() == "while/Exit_1") { 2574 ++num_exit_1; 2575 exit_1_start_micro = stats.all_start_micros(); 2576 } 2577 } 2578 } 2579 2580 // Make sure we went though the body of the loop once, and that the output of 2581 // the loop was scheduled as well. 
  EXPECT_EQ(1, num_next_iteration);
  EXPECT_EQ(1, num_next_iteration_1);
  EXPECT_EQ(1, num_exit);
  EXPECT_EQ(1, num_exit_1);

  // Start times of while/NextIteration and while/NextIteration_1 should be
  // different, so should be those of while/Exit and while/Exit_1.
  EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
  EXPECT_NE(exit_start_micro, exit_1_start_micro);

  // Check dependency among the nodes; no matter what scheduling mechanism we
  // use, the scheduled ops should follow these dependency chains.
  // Note that currently, VirtualScheduler executes while/Merge twice; hence,
  // we're not testing dependency chains related to while/Merge.
  // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
  // order of Enter, Merge, ...loop condition ..., ... loop body ...,
  // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
  // chaining test w/ Merge nodes.
  ValidateDependencyChain(
      start_times,
      {"Const", "while/Enter",  // "while/Merge",
       "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
       "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
  // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
  ValidateDependencyChain(start_times,
                          {"ones", "while/Enter_1",  // "while/Merge_1",
                           "while/Switch_1", "while/Identity_1", "while/concat",
                           "while/NextIteration_1"});
  ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
  ValidateDependencyChain(
      start_times, {"while/Identity", "while/concat/axis", "while/concat"});
  ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
  ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
}

// Schedules the same while loop twice: first the plain graph, then a version
// with annotations (CreateGrapplerItemWithLoopAnnotated); the annotated run
// is expected to report a larger total execution time.
TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) {
  {
    // Init.
    CreateGrapplerItemWithLoop();
    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    Costs c = scheduler_->Summary();

    EXPECT_EQ(23, c.execution_time.asMicroSeconds().count());
    // Both while/Merge and while/Merge_1 are scheduled twice.
    EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
    EXPECT_FALSE(c.inaccurate);
    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
  }

  {
    // Init.
    CreateGrapplerItemWithLoopAnnotated();
    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    Costs c = scheduler_->Summary();

    // The costs for Merge is accumulated twice for execution_count times, but
    // since Merge's cost is minimal, we keep this behavior here.
    EXPECT_EQ(178, c.execution_time.asMicroSeconds().count());
    // Both while/Merge and while/Merge_1 are scheduled twice.
    EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
    EXPECT_FALSE(c.inaccurate);
    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
  }
}

// Schedules a conditional (Switch/Merge) graph, first without and then with
// an output-slot annotation on the Switch node; the annotation prunes the
// untaken branch.
TEST_F(VirtualSchedulerTest, Condition) {
  // Without annotation.
  {
    // Inits.
    CreateGrapplerItemWithCondition();
    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    RunMetadata metadata;
    Costs c = scheduler_->Summary(&metadata);

    // Nodes in topological order: a/Less, Switch, First/Second, Merge.
    // Count how many times each node of interest appears in the timeline.
    int num_a = 0;
    int num_less = 0;
    int num_switch = 0;
    int num_first = 0;
    int num_second = 0;
    int num_merge = 0;

    for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
      for (const auto& stats : device_step_stats.node_stats()) {
        if (stats.node_name() == "a") {
          ++num_a;
        } else if (stats.node_name() == "Less") {
          ++num_less;
        } else if (stats.node_name() == "Switch") {
          ++num_switch;
        } else if (stats.node_name() == "First") {
          ++num_first;
        } else if (stats.node_name() == "Second") {
          ++num_second;
        } else if (stats.node_name() == "Merge") {
          ++num_merge;
        }
      }
    }

    EXPECT_EQ(1, num_a);
    EXPECT_EQ(1, num_less);
    EXPECT_EQ(1, num_switch);
    EXPECT_EQ(1, num_first);
    EXPECT_EQ(1, num_second);
    // Without annotation, both branches run and Merge is scheduled twice.
    EXPECT_EQ(2, num_merge);

    EXPECT_EQ(7, c.execution_time.asMicroSeconds().count());
    // Merge is executed twice.
    EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total);
    EXPECT_FALSE(c.inaccurate);
    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
  }

  // With annotation.
  {
    // Inits.
    CreateGrapplerItemWithCondition();

    // Annotates the Switch node.
    for (auto& node : *grappler_item_->graph.mutable_node()) {
      if (node.name() == "Switch") {
        AttrValue attr_output_info;
        // Adds one output slot 0 so that Second shouldn't be executed.
        (*attr_output_info.mutable_list()).add_i(0);
        AddNodeAttr(kOutputSlots, attr_output_info, &node);
      }
    }

    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    RunMetadata metadata;
    Costs c = scheduler_->Summary(&metadata);

    // Nodes in topological order: a/Less, Switch, Merge
    // Count node occurrences in the annotated run; Second should not appear.
    int num_a = 0;
    int num_less = 0;
    int num_switch = 0;
    int num_first = 0;
    int num_second = 0;
    int num_merge = 0;

    for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
      for (const auto& stats : device_step_stats.node_stats()) {
        if (stats.node_name() == "a") {
          ++num_a;
        } else if (stats.node_name() == "Less") {
          ++num_less;
        } else if (stats.node_name() == "Switch") {
          ++num_switch;
        } else if (stats.node_name() == "First") {
          ++num_first;
        } else if (stats.node_name() == "Second") {
          ++num_second;
        } else if (stats.node_name() == "Merge") {
          ++num_merge;
        }
      }
    }

    EXPECT_EQ(1, num_a);
    EXPECT_EQ(1, num_less);
    EXPECT_EQ(1, num_switch);
    EXPECT_EQ(1, num_first);
    // The annotation pruned the Second branch, and Merge now runs only once.
    EXPECT_EQ(0, num_second);
    EXPECT_EQ(1, num_merge);

    EXPECT_EQ(5, c.execution_time.asMicroSeconds().count());
    // Second is not executed.
    EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total);
    EXPECT_FALSE(c.inaccurate);
    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
  }
}

TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
  // Init.
  CreateGrapplerItemWithInterDeviceTransfers();
  InitScheduler();

  // Run the scheduler.
  auto ops_executed = RunScheduler("");

  // Helper lambda to extract port num from _Send and _Recv op name.
2777 auto get_port_num = [](const string& name) -> int { 2778 if (name.find("bn_0") != string::npos) { 2779 return 0; 2780 } else if (name.find("bn_1") != string::npos) { 2781 return 1; 2782 } else if (name.find("bn_2") != string::npos) { 2783 return 2; 2784 } else if (name.find("bn_minus1") != string::npos) { 2785 return -1; 2786 } 2787 return -999; 2788 }; 2789 2790 // Reorganize ops_executed for further testing. 2791 std::unordered_map<string, int> op_count; 2792 std::unordered_map<int, string> recv_op_names; 2793 std::unordered_map<int, string> send_op_names; 2794 for (const auto& x : ops_executed) { 2795 const auto& name = x.first; 2796 const auto& node_info = x.second; 2797 const auto& op = node_info.op_info.op(); 2798 if (op == kRecv) { 2799 recv_op_names[get_port_num(name)] = name; 2800 } else if (op == kSend) { 2801 send_op_names[get_port_num(name)] = name; 2802 } 2803 op_count[op]++; 2804 } 2805 2806 // Same number of _Send and _Recv. 2807 EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv)); 2808 2809 // Expect 4 Send and Recvs each: port 0, 1, and, 2, and control dependency. 2810 EXPECT_EQ(op_count.at(kRecv), 4); 2811 EXPECT_EQ(op_count.at(kSend), 4); 2812 2813 // Helper lambda for extracting output Tensor size. 2814 auto get_output_size = [this, ops_executed](const string& name) -> int64 { 2815 const auto& output_properties_ = ops_executed.at(name).op_info.outputs(); 2816 std::vector<OpInfo::TensorProperties> output_properties; 2817 for (const auto& output_property : output_properties_) { 2818 output_properties.push_back(output_property); 2819 } 2820 return CalculateOutputSize(output_properties, 0); 2821 }; 2822 2823 // Validate transfer size. 2824 // Batchnorm output y is 4D vector: batch x width x width x depth. 
  // 4 bytes per float element.
  int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
  EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
  EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
  // Mean and vars are 1-D vector with size depth_in_.
  EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
  // Control dependency size is 4B.
  EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
  EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
}

TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
  // Init.
  CreateGrapplerItemWithSendRecv();
  InitScheduler();

  // Run the scheduler.
  auto ops_executed = RunScheduler("");

  EXPECT_GT(ops_executed.count("Const"), 0);
  EXPECT_GT(ops_executed.count("Send"), 0);
  EXPECT_GT(ops_executed.count("Recv"), 0);
}

TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
  // Init.
  CreateGrapplerItemWithSendRecv();
  // Change Recv node's device so that Send and Recv are placed on different
  // devices.
  auto& graph = grappler_item_->graph;
  const string recv_device = kCPU1;
  for (int i = 0; i < graph.node_size(); i++) {
    auto* node = graph.mutable_node(i);
    if (node->name() == "Recv") {
      node->set_device(recv_device);
      auto* attr = node->mutable_attr();
      (*attr)["recv_device"].set_s(recv_device);
    } else if (node->name() == "Send") {
      // Keep the Send on its original device but point its recv_device attr
      // at the relocated Recv.
      auto* attr = node->mutable_attr();
      (*attr)["recv_device"].set_s(recv_device);
    }
  }
  InitScheduler();

  // Run the scheduler.
  auto ops_executed = RunScheduler("");

  // Expect Const, Send, Recv, and VirtualScheduler created Send and Recv ops.
  EXPECT_GT(ops_executed.count("Const"), 0);
  EXPECT_GT(ops_executed.count("Send"), 0);
  // The scheduler synthesizes its own _Send/_Recv pair for the cross-device
  // edge; their names encode the source and destination devices.
  EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
                               "task_0/cpu_0_to_/job_localhost"
                               "/replica_0/task_0/cpu_1"),
            0);
  EXPECT_GT(ops_executed.count(
                "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
            0);
  EXPECT_GT(ops_executed.count("Recv"), 0);
}

// NOTE(review): "Wiht" is a typo for "With". Renaming the test would change
// its gtest id (and break any --gtest_filter references), so the name is
// left as-is here.
TEST_F(VirtualSchedulerTest, GraphWihtOnlyRecv) {
  // Init.
  CreateGrapplerItemWithRecvWithoutSend();
  InitScheduler();

  // Run the scheduler.
  auto ops_executed = RunScheduler("");

  // Recv without Send will be treated as initially ready node.
  EXPECT_GT(ops_executed.count("Recv"), 0);
}

}  // end namespace grappler
}  // end namespace tensorflow