      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
     17 #include "tensorflow/cc/ops/standard_ops.h"
     18 #include "tensorflow/core/framework/tensor_description.pb.h"
     19 #include "tensorflow/core/framework/tensor_shape.pb.h"
     20 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
     21 #include "tensorflow/core/grappler/costs/virtual_placer.h"
     22 #include "tensorflow/core/lib/core/status_test_util.h"
     23 #include "tensorflow/core/platform/test.h"
     24 
     25 namespace tensorflow {
     26 namespace grappler {
      27 // Subclass of VirtualScheduler for testing that uses a FirstReadyManager.
     28 class TestVirtualScheduler : public VirtualScheduler {
     29  public:
     30   TestVirtualScheduler(const GrapplerItem* grappler_item,
     31                        const bool use_static_shapes, Cluster* cluster)
     32       : VirtualScheduler(grappler_item, use_static_shapes, cluster,
     33                          &ready_node_manager_) {}
     34 
     35   FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
     36   FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
     37   FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
     38   FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
     39   FRIEND_TEST(VirtualSchedulerTest, Variable);
     40   FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
     41 
     42  protected:
     43   FirstReadyManager ready_node_manager_;
     44 };
     45 
     46 class VirtualSchedulerTest : public ::testing::Test {
     47  protected:
     48   NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
     49   std::unordered_map<const NodeDef*, NodeState> node_states_;
     50 
     51   // Device names:
     52   const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
     53   const string kCPU1 = "/job:localhost/replica:0/task:0/cpu:1";
     54   const string kChannelFrom0To1 = "Channel from CPU0 to CPU1";
     55   const string kChannelFrom1To0 = "Channel from CPU1 to CPU0";
     56   // Op names:
     57   const string kSend = "_Send";
     58   const string kRecv = "_Recv";
     59   const string kConv2D = "Conv2D";
     60 
     61   DeviceProperties GetDummyCPUDevice() {
      62     // Create a CPU with 2 cores, 4 GHz frequency, 2 GB/s memory bandwidth.
     63     // - 8 Gflops
     64     // - 2 GB/s
     65     DeviceProperties cpu_device;
     66     cpu_device.set_type("CPU");
     67     cpu_device.set_frequency(4000);
     68     cpu_device.set_num_cores(2);
     69     cpu_device.set_bandwidth(2000000);
     70     return cpu_device;
     71   }
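  // How the figures above follow from these properties (a sketch; it assumes
  // DeviceProperties takes frequency in MHz and bandwidth in KB/s, with one
  // op per core per cycle):
  //   compute:   2 cores * 4000 MHz = 8e9 ops/s  (~8 Gflops)
  //   bandwidth: 2,000,000 KB/s     = 2 GB/s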
     72 
     73   void NodeSetUp(const string& name, const string& op_name,
     74                  const string& device_name, const uint64 time_ready,
     75                  NodeDef* node) {
     76     node->set_name(name);
     77     node->set_op(op_name);
     78     node->set_device(device_name);
     79 
     80     node_states_[node] = NodeState();
     81     node_states_[node].time_ready = time_ready;
     82     node_states_[node].device_name = device_name;
     83   }
     84 
     85   void SetUp() override {
      86     // node1_ to node6_ on kCPU0, with time_ready in reverse order.
     87     NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
     88     NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
     89     NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
     90     NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
     91     NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
     92     NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);
     93 
     94     // Initializes cluster_ and placer_.
     95     std::unordered_map<string, DeviceProperties> devices;
     96 
     97     // Set some dummy CPU properties
     98     DeviceProperties cpu_device = GetDummyCPUDevice();
     99 
     100     // IMPORTANT: The device is never actually used in the test cases, since
     101     // force_cpu_type is defaulted to "Haswell".
    102     devices[kCPU0] = cpu_device;
    103     devices[kCPU1] = cpu_device;
    104     cluster_.reset(new VirtualCluster(devices));
    105     placer_.reset(new VirtualPlacer(cluster_.get()));
    106   }
    107 
     108   // Three Conv2Ds, with only two of them in the fetch nodes.
    109   void CreateGrapplerItemWithConv2Ds() {
    110     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    111     auto x = ops::RandomUniform(
    112         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    113     auto y = ops::RandomUniform(
    114         s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    115     auto z = ops::RandomUniform(
    116         s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    117     auto f = ops::RandomUniform(
    118         s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
    119     std::vector<int> strides = {1, 1, 1, 1};
    120     auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
    121     auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
    122     auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
    123     GraphDef def;
    124     TF_CHECK_OK(s.ToGraphDef(&def));
    125 
    126     grappler_item_.reset(new GrapplerItem);
    127     grappler_item_->id = "test_conv2d_graph";
    128     grappler_item_->graph = def;
    129     grappler_item_->fetch = {"c0", "c1"};
    130 
    131     dependency_["c0"] = {"x", "f"};
    132     dependency_["c1"] = {"y", "f"};
    133   }
    134 
    135   // A Conv2D with a variable.
    136   void CreateGrapplerItemWithConv2DAndVariable() {
    137     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    138     auto x = ops::RandomUniform(
    139         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    140     auto f = ops::Variable(s.WithOpName("f"),
    141                            {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
    142     std::vector<int> strides = {1, 1, 1, 1};
    143     auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
    144     GraphDef def;
    145     TF_CHECK_OK(s.ToGraphDef(&def));
    146 
    147     grappler_item_.reset(new GrapplerItem);
    148     grappler_item_->id = "test_conv2d_var_graph";
    149     grappler_item_->graph = def;
    150     grappler_item_->fetch = {"y"};
    151 
    152     dependency_["y"] = {"x", "f"};
    153   }
    154 
    155   void CreateGrapplerItemWithMatmulChain() {
    156     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
     157     // Add control dependencies so the tests do not depend on the specific
     158     // ready node manager and the scheduling order stays consistent.
    159     auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
    160     auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
    161                                 {3200, 3200}, DT_FLOAT);
    162     auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
    163                                 {3200, 3200}, DT_FLOAT);
    164     auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
    165                                 {3200, 3200}, DT_FLOAT);
    166     auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
    167                                 {3200, 3200}, DT_FLOAT);
    168 
    169     auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
    170     auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
    171     auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
    172     auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);
    173 
    174     GraphDef def;
    175     TF_CHECK_OK(s.ToGraphDef(&def));
    176 
    177     grappler_item_.reset(new GrapplerItem);
    178     grappler_item_->id = "test_matmul_sequence_graph";
    179     grappler_item_->graph = def;
    180     grappler_item_->fetch = {"abcde"};
    181 
    182     dependency_["ab"] = {"a", "b"};
    183     dependency_["abc"] = {"ab", "c"};
    184     dependency_["abcd"] = {"abc", "d"};
    185     dependency_["abcde"] = {"abcd", "e"};
    186   }
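  // Note: with SimplePredictCosts() below (RandomUniform = 1s, MatMul = 2s,
  // everything else = 1us), this chain costs 5 * 1s + 4 * 2s = 13s plus a few
  // microseconds of misc ops; SummaryCostTest checks the 13000005 us total.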
    187 
     188   // AddN that takes 4 tensors of shape 10x10x10x10.
    189   void CreateGrapplerItemWithAddN() {
    190     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    191     auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
    192     auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
    193     auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
    194     auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
    195     OutputList input_tensors = {x, y, z, w};
    196     auto out = ops::AddN(s.WithOpName("out"), input_tensors);
    197     GraphDef def;
    198     TF_CHECK_OK(s.ToGraphDef(&def));
    199 
    200     grappler_item_.reset(new GrapplerItem);
    201     grappler_item_->id = "test_addn_graph";
    202     grappler_item_->graph = def;
    203     grappler_item_->fetch = {"out"};
    204 
    205     dependency_["out"] = {"x", "y", "z", "w"};
    206   }
    207 
     208   // NoOp that takes 7 NoOps as control dependencies.
    209   void CreateGrapplerItemWithControlDependency() {
    210     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    211     std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
    212     std::vector<Operation> input_tensors;
    213     for (const auto& input : input_noop_names) {
    214       auto x = ops::NoOp(s.WithOpName(input));
    215       input_tensors.push_back(x.operation);
    216     }
    217     auto out =
    218         ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
    219     GraphDef def;
    220     TF_CHECK_OK(s.ToGraphDef(&def));
    221 
    222     grappler_item_.reset(new GrapplerItem);
    223     grappler_item_->id = "test_control_dependency_graph";
    224     grappler_item_->graph = def;
    225     grappler_item_->fetch = {"out"};
    226 
    227     dependency_["out"] = input_noop_names;
    228   }
    229 
     230   // FusedBatchNorm (an op with multiple outputs) with multiple consumers
     231   // (including a control dependency).
    232   void CreateGrapplerItemWithBatchNorm() {
    233     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    234     auto x = ops::RandomUniform(
    235         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    236     auto scale =
    237         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
    238     auto offset =
    239         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
    240     auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
    241     auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
    242 
    243     auto batch_norm = ops::FusedBatchNorm(
    244         s.WithOpName("bn"), x, scale, offset, mean, var,
    245         ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
    246     auto y = batch_norm.y;
    247     auto batch_mean = batch_norm.batch_mean;
    248     auto batch_var = batch_norm.batch_variance;
    249 
    250     auto z1 = ops::Add(s.WithOpName("z1"), x, y);
    251     auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
    252     auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
    253     std::vector<Operation> input_tensors = {
    254         batch_mean.op(),
    255         z1.z.op(),
    256         z2.z.op(),
    257         z3.z.op(),
    258     };
    259     auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
    260 
    261     GraphDef def;
    262     TF_CHECK_OK(s.ToGraphDef(&def));
    263 
    264     grappler_item_.reset(new GrapplerItem);
    265     grappler_item_->id = "test_complex_dependency_graph";
    266     grappler_item_->graph = def;
    267     grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
    268 
    269     dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
    270     dependency_["z1"] = {"x", "bn"};
    271     dependency_["z2"] = {"bn"};
    272     dependency_["z3"] = {"bn"};
    273     dependency_["z4"] = {"bn"};
    274   }
    275 
    276   void CreateGrapplerItemWithSendRecv() {
    277     const string gdef_ascii = R"EOF(
    278 node {
    279   name: "Const"
    280   op: "Const"
    281   device: "/job:localhost/replica:0/task:0/device:CPU:0"
    282   attr {
    283     key: "dtype"
    284     value {
    285       type: DT_FLOAT
    286     }
    287   }
    288   attr {
    289     key: "value"
    290     value {
    291       tensor {
    292         dtype: DT_FLOAT
    293         tensor_shape {
    294         }
    295         float_val: 3.1415
    296       }
    297     }
    298   }
    299 }
    300 node {
    301   name: "Send"
    302   op: "_Send"
    303   input: "Const"
    304   device: "/job:localhost/replica:0/task:0/device:CPU:0"
    305   attr {
    306     key: "T"
    307     value {
    308       type: DT_FLOAT
    309     }
    310   }
    311   attr {
    312     key: "client_terminated"
    313     value {
    314       b: false
    315     }
    316   }
    317   attr {
    318     key: "recv_device"
    319     value {
    320       s: "/job:localhost/replica:0/task:0/device:CPU:0"
    321     }
    322   }
    323   attr {
    324     key: "send_device"
    325     value {
    326       s: "/job:localhost/replica:0/task:0/device:CPU:0"
    327     }
    328   }
    329   attr {
    330     key: "send_device_incarnation"
    331     value {
    332       i: 0
    333     }
    334   }
    335   attr {
    336     key: "tensor_name"
    337     value {
    338       s: "test"
    339     }
    340   }
    341 }
    342 node {
    343   name: "Recv"
    344   op: "_Recv"
    345   device: "/job:localhost/replica:0/task:0/device:CPU:0"
    346   attr {
    347     key: "client_terminated"
    348     value {
    349       b: false
    350     }
    351   }
    352   attr {
    353     key: "recv_device"
    354     value {
    355       s: "/job:localhost/replica:0/task:0/device:CPU:0"
    356     }
    357   }
    358   attr {
    359     key: "send_device"
    360     value {
    361       s: "/job:localhost/replica:0/task:0/device:CPU:0"
    362     }
    363   }
    364   attr {
    365     key: "send_device_incarnation"
    366     value {
    367       i: 0
    368     }
    369   }
    370   attr {
    371     key: "tensor_name"
    372     value {
    373       s: "test"
    374     }
    375   }
    376   attr {
    377     key: "tensor_type"
    378     value {
    379       type: DT_FLOAT
    380     }
    381   }
    382 }
    383 library {
    384 }
    385 versions {
    386   producer: 24
    387 }
    388     )EOF";
    389 
    390     grappler_item_.reset(new GrapplerItem);
    391     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
    392                                                 &grappler_item_->graph));
    393     grappler_item_->id = "test_graph";
    394     grappler_item_->fetch = {"Recv"};
    395   }
    396 
    397   // A simple while loop
    398   void CreateGrapplerItemWithLoop() {
    399     // Test graph produced in python using:
    400     /*
    401       with tf.Graph().as_default():
     402         i0 = tf.constant(0)
     403         m0 = tf.ones([2, 2])
     404         c = lambda i, m: i < 10
     405         b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
     406         r = tf.while_loop(
     407             c, b, loop_vars=[i0, m0],
     408             shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
     409         with open('/tmp/graph.pbtxt', 'w') as f:
     410           f.write(str(tf.get_default_graph().as_graph_def()))
    411     */
    412     const string gdef_ascii = R"EOF(
    413 node {
    414   name: "Const"
    415   op: "Const"
    416   attr {
    417     key: "dtype"
    418     value {
    419       type: DT_INT32
    420     }
    421   }
    422   attr {
    423     key: "value"
    424     value {
    425       tensor {
    426         dtype: DT_INT32
    427         tensor_shape {
    428         }
    429         int_val: 0
    430       }
    431     }
    432   }
    433 }
    434 node {
    435   name: "ones"
    436   op: "Const"
    437   attr {
    438     key: "dtype"
    439     value {
    440       type: DT_FLOAT
    441     }
    442   }
    443   attr {
    444     key: "value"
    445     value {
    446       tensor {
    447         dtype: DT_FLOAT
    448         tensor_shape {
    449           dim {
    450             size: 2
    451           }
    452           dim {
    453             size: 2
    454           }
    455         }
    456         float_val: 1.0
    457       }
    458     }
    459   }
    460 }
    461 node {
    462   name: "while/Enter"
    463   op: "Enter"
    464   input: "Const"
    465   attr {
    466     key: "T"
    467     value {
    468       type: DT_INT32
    469     }
    470   }
    471   attr {
    472     key: "frame_name"
    473     value {
    474       s: "while/while/"
    475     }
    476   }
    477   attr {
    478     key: "is_constant"
    479     value {
    480       b: false
    481     }
    482   }
    483   attr {
    484     key: "parallel_iterations"
    485     value {
    486       i: 10
    487     }
    488   }
    489 }
    490 node {
    491   name: "while/Enter_1"
    492   op: "Enter"
    493   input: "ones"
    494   attr {
    495     key: "T"
    496     value {
    497       type: DT_FLOAT
    498     }
    499   }
    500   attr {
    501     key: "frame_name"
    502     value {
    503       s: "while/while/"
    504     }
    505   }
    506   attr {
    507     key: "is_constant"
    508     value {
    509       b: false
    510     }
    511   }
    512   attr {
    513     key: "parallel_iterations"
    514     value {
    515       i: 10
    516     }
    517   }
    518 }
    519 node {
    520   name: "while/Merge"
    521   op: "Merge"
    522   input: "while/Enter"
    523   input: "while/NextIteration"
    524   attr {
    525     key: "N"
    526     value {
    527       i: 2
    528     }
    529   }
    530   attr {
    531     key: "T"
    532     value {
    533       type: DT_INT32
    534     }
    535   }
    536 }
    537 node {
    538   name: "while/Merge_1"
    539   op: "Merge"
    540   input: "while/Enter_1"
    541   input: "while/NextIteration_1"
    542   attr {
    543     key: "N"
    544     value {
    545       i: 2
    546     }
    547   }
    548   attr {
    549     key: "T"
    550     value {
    551       type: DT_FLOAT
    552     }
    553   }
    554 }
    555 node {
    556   name: "while/Less/y"
    557   op: "Const"
    558   input: "^while/Merge"
    559   attr {
    560     key: "dtype"
    561     value {
    562       type: DT_INT32
    563     }
    564   }
    565   attr {
    566     key: "value"
    567     value {
    568       tensor {
    569         dtype: DT_INT32
    570         tensor_shape {
    571         }
    572         int_val: 10
    573       }
    574     }
    575   }
    576 }
    577 node {
    578   name: "while/Less"
    579   op: "Less"
    580   input: "while/Merge"
    581   input: "while/Less/y"
    582   attr {
    583     key: "T"
    584     value {
    585       type: DT_INT32
    586     }
    587   }
    588 }
    589 node {
    590   name: "while/LoopCond"
    591   op: "LoopCond"
    592   input: "while/Less"
    593 }
    594 node {
    595   name: "while/Switch"
    596   op: "Switch"
    597   input: "while/Merge"
    598   input: "while/LoopCond"
    599   attr {
    600     key: "T"
    601     value {
    602       type: DT_INT32
    603     }
    604   }
    605   attr {
    606     key: "_class"
    607     value {
    608       list {
    609         s: "loc:@while/Merge"
    610       }
    611     }
    612   }
    613 }
    614 node {
    615   name: "while/Switch_1"
    616   op: "Switch"
    617   input: "while/Merge_1"
    618   input: "while/LoopCond"
    619   attr {
    620     key: "T"
    621     value {
    622       type: DT_FLOAT
    623     }
    624   }
    625   attr {
    626     key: "_class"
    627     value {
    628       list {
    629         s: "loc:@while/Merge_1"
    630       }
    631     }
    632   }
    633 }
    634 node {
    635   name: "while/Identity"
    636   op: "Identity"
    637   input: "while/Switch:1"
    638   attr {
    639     key: "T"
    640     value {
    641       type: DT_INT32
    642     }
    643   }
    644 }
    645 node {
    646   name: "while/Identity_1"
    647   op: "Identity"
    648   input: "while/Switch_1:1"
    649   attr {
    650     key: "T"
    651     value {
    652       type: DT_FLOAT
    653     }
    654   }
    655 }
    656 node {
    657   name: "while/add/y"
    658   op: "Const"
    659   input: "^while/Identity"
    660   attr {
    661     key: "dtype"
    662     value {
    663       type: DT_INT32
    664     }
    665   }
    666   attr {
    667     key: "value"
    668     value {
    669       tensor {
    670         dtype: DT_INT32
    671         tensor_shape {
    672         }
    673         int_val: 1
    674       }
    675     }
    676   }
    677 }
    678 node {
    679   name: "while/add"
    680   op: "Add"
    681   input: "while/Identity"
    682   input: "while/add/y"
    683   attr {
    684     key: "T"
    685     value {
    686       type: DT_INT32
    687     }
    688   }
    689 }
    690 node {
    691   name: "while/concat/axis"
    692   op: "Const"
    693   input: "^while/Identity"
    694   attr {
    695     key: "dtype"
    696     value {
    697       type: DT_INT32
    698     }
    699   }
    700   attr {
    701     key: "value"
    702     value {
    703       tensor {
    704         dtype: DT_INT32
    705         tensor_shape {
    706         }
    707         int_val: 0
    708       }
    709     }
    710   }
    711 }
    712 node {
    713   name: "while/concat"
    714   op: "ConcatV2"
    715   input: "while/Identity_1"
    716   input: "while/Identity_1"
    717   input: "while/concat/axis"
    718   attr {
    719     key: "N"
    720     value {
    721       i: 2
    722     }
    723   }
    724   attr {
    725     key: "T"
    726     value {
    727       type: DT_FLOAT
    728     }
    729   }
    730   attr {
    731     key: "Tidx"
    732     value {
    733       type: DT_INT32
    734     }
    735   }
    736 }
    737 node {
    738   name: "while/NextIteration"
    739   op: "NextIteration"
    740   input: "while/add"
    741   attr {
    742     key: "T"
    743     value {
    744       type: DT_INT32
    745     }
    746   }
    747 }
    748 node {
    749   name: "while/NextIteration_1"
    750   op: "NextIteration"
    751   input: "while/concat"
    752   attr {
    753     key: "T"
    754     value {
    755       type: DT_FLOAT
    756     }
    757   }
    758 }
    759 node {
    760   name: "while/Exit"
    761   op: "Exit"
    762   input: "while/Switch"
    763   attr {
    764     key: "T"
    765     value {
    766       type: DT_INT32
    767     }
    768   }
    769 }
    770 node {
    771   name: "while/Exit_1"
    772   op: "Exit"
    773   input: "while/Switch_1"
    774   attr {
    775     key: "T"
    776     value {
    777       type: DT_FLOAT
    778     }
    779   }
    780 }
    781 versions {
    782   producer: 21
    783 }
    784   )EOF";
    785 
    786     grappler_item_.reset(new GrapplerItem);
    787     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
    788                                                 &grappler_item_->graph));
    789     grappler_item_->id = "test_graph";
    790     grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
    791   }
    792 
    793   void CreateGrapplerItemWithInterDeviceTransfers() {
    794     tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
    795 
    796     // Create a FusedBatchNorm op that has multiple output ports.
    797     auto x = ops::RandomUniform(
    798         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    799     auto scale =
    800         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
    801     auto offset =
    802         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
    803     auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
    804     auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
    805 
    806     auto batch_norm = ops::FusedBatchNorm(
    807         s.WithOpName("bn"), x, scale, offset, mean, var,
    808         ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
    809     auto y = batch_norm.y;
    810     auto batch_mean = batch_norm.batch_mean;
    811     auto batch_var = batch_norm.batch_variance;
    812     // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
    813     auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
    814     auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
    815     // batch_mean1 and batch_var1 take different output ports, so each will
    816     // initiate Send/Recv.
    817     auto batch_mean1 = ops::Identity(
    818         s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
    819     auto batch_var1 =
    820         ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
     821     // This is a control dependency.
    822     auto control_dep = ops::NoOp(s.WithOpName("control_dep")
    823                                      .WithControlDependencies(y)
    824                                      .WithDevice(kCPU1));
    825 
    826     GraphDef def;
    827     TF_CHECK_OK(s.ToGraphDef(&def));
    828 
    829     grappler_item_.reset(new GrapplerItem);
    830     grappler_item_->id = "test_conv2d_graph";
    831     grappler_item_->graph = def;
    832     grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
    833                              "control_dep"};
    834 
    835     dependency_["bn"] = {"x", "mean", "var"};
    836     dependency_["y1"] = {"bn"};
    837     dependency_["y2"] = {"bn"};
    838     dependency_["batch_mean1"] = {"bn"};
    839     dependency_["batch_var1"] = {"bn"};
    840     dependency_["control_dep"] = {"bn"};
    841   }
    842 
    843   // Call this after creating grappler_item_ and setting up dependency_.
    844   void InitScheduler() {
    845     scheduler_.reset(new TestVirtualScheduler(
    846         grappler_item_.get(), true /* use_static_shapes */, cluster_.get()));
    847     TF_CHECK_OK(scheduler_->Init());
    848   }
    849 
     850   // Returns a cost based on the op type (values below are in nanoseconds).
    851   Costs SimplePredictCosts(const OpContext& op_context) const {
    852     Costs c;
    853     int64 exec_cost = 0;
    854     if (op_context.op_info.op() == "MatMul") {
    855       exec_cost = 2000000000;
    856     } else if (op_context.op_info.op() == "RandomUniform") {
    857       exec_cost = 1000000000;
    858     } else {
    859       exec_cost = 1000;
    860     }
    861     c.execution_time = Costs::NanoSeconds(exec_cost);
    862     return c;
    863   }
    864 
     865   // Call this after InitScheduler(). Scheduler stops right after executing
     866   // target_node.
    867   std::unordered_map<string, OpContext> RunScheduler(
    868       const string& target_node) {
    869     Costs zero_costs = Costs::ZeroCosts();
    870     std::unordered_map<string, OpContext> ops_executed;
    871     bool more_nodes = true;
    872     do {
    873       OpContext op_context = scheduler_->GetCurrNode();
    874       ops_executed[op_context.name] = op_context;
    875       std::cout << op_context.name << std::endl;
    876 
    877       Costs node_costs = SimplePredictCosts(op_context);
    878 
    879       // Check scheduling order.
    880       auto it = dependency_.find(op_context.name);
    881       if (it != dependency_.end()) {
    882         for (const auto& preceding_node : it->second) {
    883           EXPECT_GT(ops_executed.count(preceding_node), 0);
    884         }
    885       }
    886       more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
    887 
    888       if (op_context.name == target_node) {
    889         // Scheduler has the state after executing the target node.
    890         break;
    891       }
    892     } while (more_nodes);
    893     return ops_executed;
    894   }
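  // Typical flow in the tests below (a sketch of how these helpers are used,
  // not a new API):
  //   CreateGrapplerItemWithMatmulChain();    // build grappler_item_
  //   InitScheduler();                        // wrap it in TestVirtualScheduler
  //   auto ops_executed = RunScheduler("");   // "" means run to completion
  //   Costs summary = scheduler_->Summary();  // aggregate predicted costs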
    895 
    896   // Helper method for validating a vector.
    897   template <typename T>
    898   void ExpectVectorEq(const std::vector<T>& expected,
    899                       const std::vector<T>& test_elements) {
    900     // Set of expected elements for an easy comparison.
    901     std::set<T> expected_set(expected.begin(), expected.end());
    902     for (const auto& element : test_elements) {
    903       EXPECT_GT(expected_set.count(element), 0);
    904     }
    905     EXPECT_EQ(expected.size(), test_elements.size());
    906   }
    907 
     908   // Helper method that checks the names of nodes.
    909   void ValidateNodeDefs(const std::vector<string>& expected,
    910                         const std::vector<const NodeDef*>& node_defs) {
    911     std::vector<string> node_names;
    912     std::transform(node_defs.begin(), node_defs.end(),
    913                    std::back_inserter(node_names),
    914                    [](const NodeDef* node) { return node->name(); });
    915     ExpectVectorEq(expected, node_names);
    916   }
    917 
    918   // Helper method for validating a set.
    919   template <typename T>
    920   void ExpectSetEq(const std::set<T>& expected,
    921                    const std::set<T>& test_elements) {
    922     for (const auto& element : test_elements) {
    923       EXPECT_GT(expected.count(element), 0);
    924     }
    925     EXPECT_EQ(expected.size(), test_elements.size());
    926   }
    927 
     928   // Helper method that checks name-port pairs.
    929   void ValidateMemoryUsageSnapshot(
    930       const std::vector<string>& expected_names, const int port_num_expected,
    931       const std::unordered_set<std::pair<const NodeDef*, int>,
    932                                DeviceState::NodePairHash>& mem_usage_snapshot) {
    933     std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
    934     std::transform(
    935         mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
    936         std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
    937         [](const std::pair<const NodeDef*, int>& node_port) {
    938           return std::make_pair(node_port.first->name(), node_port.second);
    939         });
    940     std::set<std::pair<string, int>> expected;
    941     std::transform(expected_names.begin(), expected_names.end(),
    942                    std::inserter(expected, expected.begin()),
    943                    [port_num_expected](const string& name) {
    944                      return std::make_pair(name, port_num_expected);
    945                    });
    946     ExpectSetEq(expected, nodes_at_peak_mem_usage);
    947   }
    948 
     949   // Helper method for checking node dependencies.
    950   void ValidateDependencyChain(
    951       const std::unordered_map<string, int64>& start_times,
    952       const std::vector<string>& nodes_in_dependency_order) {
    953     int64 prev_node_time = -1;
    954     for (const auto& node : nodes_in_dependency_order) {
    955       int64 curr_node_time = start_times.at(node);
    956       EXPECT_GE(curr_node_time, prev_node_time);
    957       prev_node_time = curr_node_time;
    958     }
    959   }
    960 
     961   // Helper method for converting a shape vector to TensorProperties.
    962   OpInfo::TensorProperties ShapeToTensorProperty(
    963       const std::vector<int> shape, const DataType& data_type) const {
    964     OpInfo::TensorProperties tensor_property;
    965     tensor_property.set_dtype(data_type);
    966     for (const auto& x : shape) {
    967       tensor_property.mutable_shape()->add_dim()->set_size(x);
    968     }
    969     return tensor_property;
    970   }
    971 
    972   // SetUp() inits cluster_ and placer_.
    973   std::unique_ptr<VirtualCluster> cluster_;
    974   std::unique_ptr<VirtualPlacer> placer_;
    975 
    976   // grappler_item_ and scheduler_ will be initialized differently for each test
    977   // case.
    978   std::unique_ptr<GrapplerItem> grappler_item_;
    979   std::unique_ptr<TestVirtualScheduler> scheduler_;
    980   // Node name -> its preceding nodes map for testing scheduling order.
    981   std::unordered_map<string, std::vector<string>> dependency_;
    982 
    983   // Shared params for Conv2D related graphs:
    984   const int batch_size_ = 4;
    985   const int width_ = 10;
    986   const int height_ = 10;
    987   const int depth_in_ = 8;
    988   const int kernel_ = 3;
    989   const int depth_out_ = 16;
    990 };
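// The ready node manager tests below all exercise the same small interface
// (sketched here from the calls the tests make; only FirstReadyManager and
// CompositeNodeManager need Init()):
//   manager.Init(&node_states_);  // give the manager access to time_ready
//   manager.AddNode(&node);       // enqueue a ready node
//   manager.GetCurrNode();        // peek; stable until RemoveCurrNode()
//   manager.RemoveCurrNode();     // pop the current node
//   manager.Empty();              // true once all nodes have been consumed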
    991 
    992 // Test that FIFOManager correctly returns the current node with only 1 node.
    993 TEST_F(VirtualSchedulerTest, GetSingleNodeFIFOManager) {
    994   // Init.
    995   FIFOManager manager = FIFOManager();
    996 
    997   // Add the node to FIFOManager.
    998   manager.AddNode(&node1_);
    999   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1000 }
   1001 
   1002 // Test that FIFOManager removes the only node contained within.
   1003 TEST_F(VirtualSchedulerTest, RemoveSingleNodeFIFOManager) {
   1004   // Init.
   1005   FIFOManager manager = FIFOManager();
   1006 
   1007   // Add the node to FIFOManager.
   1008   manager.AddNode(&node1_);
   1009 
   1010   // Remove the only node in FIFOManager.
   1011   manager.RemoveCurrNode();
   1012   EXPECT_TRUE(manager.Empty());
   1013 }
   1014 
    1015 // Test that FIFOManager can remove multiple nodes and return the current node
    1016 // in the right order.
   1017 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFIFOManager) {
   1018   // Init.
   1019   FIFOManager manager = FIFOManager();
   1020 
   1021   // Add the nodes to FIFOManager.
   1022   manager.AddNode(&node1_);
   1023   manager.AddNode(&node2_);
   1024   manager.AddNode(&node3_);
   1025   manager.AddNode(&node4_);
   1026 
   1027   // Keep checking current node while removing nodes from manager.
   1028   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1029   manager.RemoveCurrNode();
   1030   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1031   manager.RemoveCurrNode();
   1032   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
   1033   manager.RemoveCurrNode();
   1034   EXPECT_EQ("Node4", manager.GetCurrNode()->name());
   1035   manager.RemoveCurrNode();
   1036   EXPECT_TRUE(manager.Empty());
   1037 }
   1038 
   1039 // Test that FIFOManager can remove multiple nodes and add more nodes, still
   1040 // returning the current node in the right order
   1041 TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleFIFOManager) {
   1042   // Init.
   1043   FIFOManager manager = FIFOManager();
   1044 
   1045   // Add the nodes to FIFOManager.
   1046   manager.AddNode(&node1_);
   1047   manager.AddNode(&node2_);
   1048   manager.AddNode(&node3_);
   1049   manager.AddNode(&node4_);
   1050 
   1051   // Keep checking current node as nodes are removed and added.
   1052   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1053   manager.RemoveCurrNode();
   1054   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1055   manager.AddNode(&node5_);
   1056   // GetCurrNode()  should return the same node even if some nodes are added,
   1057   // until RemoveCurrNode() is called.
   1058   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1059   manager.RemoveCurrNode();
   1060   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
   1061   manager.RemoveCurrNode();
   1062   EXPECT_EQ("Node4", manager.GetCurrNode()->name());
   1063   manager.AddNode(&node6_);
   1064   EXPECT_EQ("Node4", manager.GetCurrNode()->name());
   1065   manager.RemoveCurrNode();
   1066   EXPECT_EQ("Node5", manager.GetCurrNode()->name());
   1067   manager.RemoveCurrNode();
   1068   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
   1069   manager.RemoveCurrNode();
   1070   EXPECT_TRUE(manager.Empty());
   1071 }
   1072 
   1073 // Test that LIFOManager correctly returns the current node with only 1 node.
   1074 TEST_F(VirtualSchedulerTest, GetSingleNodeLIFOManager) {
   1075   // Init.
   1076   LIFOManager manager = LIFOManager();
   1077 
   1078   // Add the node to LIFOManager.
   1079   manager.AddNode(&node1_);
   1080   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1081 }
   1082 
   1083 // Test that LIFOManager removes the only node contained within.
   1084 TEST_F(VirtualSchedulerTest, RemoveSingleNodeLIFOManager) {
   1085   // Init.
   1086   LIFOManager manager = LIFOManager();
   1087 
   1088   // Add the node to LIFOManager.
   1089   manager.AddNode(&node1_);
   1090 
   1091   // Remove the only node in LIFOManager.
   1092   manager.RemoveCurrNode();
   1093   EXPECT_TRUE(manager.Empty());
   1094 }
   1095 
    1096 // Test that LIFOManager can remove multiple nodes and return the current node
    1097 // in the right order.
   1098 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleLIFOManager) {
   1099   // Init.
   1100   LIFOManager manager = LIFOManager();
   1101 
   1102   // Add the nodes to LIFOManager.
   1103   manager.AddNode(&node1_);
   1104   manager.AddNode(&node2_);
   1105   manager.AddNode(&node3_);
   1106   manager.AddNode(&node4_);
   1107 
   1108   // Keep checking current node while removing nodes from manager.
   1109   EXPECT_EQ("Node4", manager.GetCurrNode()->name());
   1110   manager.RemoveCurrNode();
   1111   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
   1112   manager.RemoveCurrNode();
   1113   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1114   manager.RemoveCurrNode();
   1115   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1116   manager.RemoveCurrNode();
   1117   EXPECT_TRUE(manager.Empty());
   1118 }
   1119 
    1120 // Test that LIFOManager can remove multiple nodes (always the current node)
    1121 // and add more nodes, still returning the current node in the right order.
   1122 TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) {
   1123   // Init.
   1124   LIFOManager manager = LIFOManager();
   1125 
   1126   // Add the nodes to LIFOManager.
   1127   manager.AddNode(&node1_);
   1128   manager.AddNode(&node2_);
   1129   manager.AddNode(&node3_);
   1130   manager.AddNode(&node4_);
   1131 
   1132   // Keep checking current node as nodes are removed and added.
   1133   EXPECT_EQ("Node4", manager.GetCurrNode()->name());
   1134   manager.RemoveCurrNode();
   1135   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
   1136   manager.AddNode(&node5_);
   1137   // GetCurrNode()  should return the same node even if some nodes are added,
   1138   // until RemoveCurrNode() is called.
   1139   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
   1140   manager.RemoveCurrNode();
   1141   EXPECT_EQ("Node5", manager.GetCurrNode()->name());
   1142   manager.RemoveCurrNode();
   1143   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1144   manager.AddNode(&node6_);
   1145   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1146   manager.RemoveCurrNode();
   1147   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
   1148   manager.RemoveCurrNode();
   1149   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1150   manager.RemoveCurrNode();
   1151   EXPECT_TRUE(manager.Empty());
   1152 }
   1153 
   1154 TEST_F(VirtualSchedulerTest, GetSingleNodeFirstReadyManager) {
   1155   FirstReadyManager manager;
   1156   manager.Init(&node_states_);
   1157 
   1158   manager.AddNode(&node1_);
   1159   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1160 }
   1161 
   1162 TEST_F(VirtualSchedulerTest, RemoveSingleNodeFirstReadyManager) {
   1163   FirstReadyManager manager;
   1164   manager.Init(&node_states_);
   1165   manager.AddNode(&node1_);
   1166   manager.RemoveCurrNode();
   1167   EXPECT_TRUE(manager.Empty());
   1168 }
   1169 
   1170 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) {
   1171   FirstReadyManager manager;
   1172   manager.Init(&node_states_);
   1173   // Insert nodes in some random order.
   1174   manager.AddNode(&node2_);
   1175   manager.AddNode(&node1_);
   1176   manager.AddNode(&node4_);
   1177   manager.AddNode(&node5_);
   1178   manager.AddNode(&node3_);
   1179   manager.AddNode(&node6_);
   1180 
   1181   // In whatever order we insert nodes, we get the same order based on nodes'
   1182   // time_ready.
   1183   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
   1184   manager.RemoveCurrNode();
   1185   EXPECT_EQ("Node5", manager.GetCurrNode()->name());
   1186   manager.RemoveCurrNode();
   1187   EXPECT_EQ("Node4", manager.GetCurrNode()->name());
   1188   manager.RemoveCurrNode();
   1189   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
   1190   manager.RemoveCurrNode();
   1191   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1192   manager.RemoveCurrNode();
   1193   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1194   manager.RemoveCurrNode();
   1195   EXPECT_TRUE(manager.Empty());
   1196 }
   1197 
   1198 TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) {
   1199   FirstReadyManager manager;
   1200   manager.Init(&node_states_);
   1201   // Insert nodes in some random order.
   1202   manager.AddNode(&node2_);
   1203   manager.AddNode(&node1_);
   1204   manager.AddNode(&node4_);
   1205   manager.AddNode(&node5_);
   1206   manager.AddNode(&node3_);
   1207   manager.AddNode(&node6_);
   1208 
   1209   // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode()
   1210   // should return it.
   1211   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
    1212   // Now insert a few other nodes, but their time_ready's are even smaller than
   1213   // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return
   1214   // the same node, Node6, in this case.
   1215 
   1216   NodeDef node7;
   1217   NodeDef node8;
   1218   NodeDef node9;
   1219   NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7);
   1220   NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8);
   1221   NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9);
   1222 
   1223   manager.AddNode(&node7);
   1224   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
   1225 
   1226   manager.AddNode(&node8);
   1227   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
   1228 
   1229   manager.RemoveCurrNode();
   1230   // Now Node6 is removed, and GetCurrNode() will return Node8.
   1231   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
   1232 
   1233   // Again, AddNode shouldn't change GetCurrNode().
   1234   manager.AddNode(&node9);
   1235   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
   1236 
   1237   manager.RemoveCurrNode();
   1238   EXPECT_EQ("Node9", manager.GetCurrNode()->name());
   1239   manager.RemoveCurrNode();
   1240   EXPECT_EQ("Node7", manager.GetCurrNode()->name());
   1241   manager.RemoveCurrNode();
   1242   EXPECT_EQ("Node5", manager.GetCurrNode()->name());
   1243   manager.RemoveCurrNode();
   1244   EXPECT_EQ("Node4", manager.GetCurrNode()->name());
   1245   manager.RemoveCurrNode();
   1246   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
   1247   manager.RemoveCurrNode();
   1248   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1249   manager.RemoveCurrNode();
   1250   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1251   manager.RemoveCurrNode();
   1252   EXPECT_TRUE(manager.Empty());
   1253 }
   1254 
   1255 TEST_F(VirtualSchedulerTest, DeterminismInFirstReadyManager) {
   1256   FirstReadyManager manager1;
   1257   manager1.Init(&node_states_);
   1258   FirstReadyManager manager2;
   1259   manager2.Init(&node_states_);
   1260 
    1261   // 6 nodes with the same time_ready.
   1262   NodeDef node7;
   1263   NodeDef node8;
   1264   NodeDef node9;
   1265   NodeDef node10;
   1266   NodeDef node11;
   1267   NodeDef node12;
   1268   NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
   1269   NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8);
   1270   NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9);
   1271   NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10);
   1272   NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11);
   1273   NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12);
   1274 
   1275   // Add the above 6 nodes to manager1.
   1276   manager1.AddNode(&node7);
   1277   manager1.AddNode(&node8);
   1278   manager1.AddNode(&node9);
   1279   manager1.AddNode(&node10);
   1280   manager1.AddNode(&node11);
   1281   manager1.AddNode(&node12);
   1282 
   1283   // Add the above 6 nodes to manager2, but in a different order.
   1284   manager2.AddNode(&node8);
   1285   manager2.AddNode(&node11);
   1286   manager2.AddNode(&node9);
   1287   manager2.AddNode(&node10);
   1288   manager2.AddNode(&node7);
   1289   manager2.AddNode(&node12);
   1290 
    1291   // Expect both managers to return the same nodes for deterministic node
   1292   // scheduling.
   1293   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
   1294   manager1.RemoveCurrNode();
   1295   manager2.RemoveCurrNode();
   1296 
   1297   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
   1298   manager1.RemoveCurrNode();
   1299   manager2.RemoveCurrNode();
   1300 
   1301   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
   1302   manager1.RemoveCurrNode();
   1303   manager2.RemoveCurrNode();
   1304 
   1305   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
   1306   manager1.RemoveCurrNode();
   1307   manager2.RemoveCurrNode();
   1308 
   1309   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
   1310   manager1.RemoveCurrNode();
   1311   manager2.RemoveCurrNode();
   1312 
   1313   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
   1314   manager1.RemoveCurrNode();
   1315   manager2.RemoveCurrNode();
   1316 
   1317   EXPECT_TRUE(manager1.Empty());
   1318   EXPECT_TRUE(manager2.Empty());
   1319 }
   1320 
   1321 TEST_F(VirtualSchedulerTest, RemoveSingleNodeCompositeNodeManager) {
   1322   CompositeNodeManager manager;
   1323   manager.Init(&node_states_);
   1324   manager.AddNode(&node1_);
   1325   manager.RemoveCurrNode();
   1326   EXPECT_TRUE(manager.Empty());
   1327 }
   1328 
   1329 TEST_F(VirtualSchedulerTest, RemoveSingleNodeComopsiteNodeManager) {
   1330   CompositeNodeManager manager;
   1331   manager.Init(&node_states_);
   1332 
   1333   manager.AddNode(&node1_);
   1334   manager.RemoveCurrNode();
   1335   EXPECT_TRUE(manager.Empty());
   1336 }
   1337 
   1338 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleComopsiteNodeManager) {
   1339   CompositeNodeManager manager;
   1340   manager.Init(&node_states_);
   1341 
    1342   // Add the nodes to CompositeNodeManager.
   1343   manager.AddNode(&node1_);
   1344   manager.AddNode(&node2_);
   1345   manager.AddNode(&node3_);
   1346   manager.AddNode(&node4_);
   1347 
   1348   // Keep checking current node as nodes are removed and added.
   1349   EXPECT_EQ("Node4", manager.GetCurrNode()->name());
   1350   manager.RemoveCurrNode();
   1351   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
   1352   manager.AddNode(&node5_);
   1353   // GetCurrNode()  should return the same node even if some nodes are added,
   1354   // until RemoveCurrNode() is called.
   1355   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
   1356   manager.RemoveCurrNode();
   1357   EXPECT_EQ("Node5", manager.GetCurrNode()->name());
   1358   manager.RemoveCurrNode();
   1359   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1360   manager.AddNode(&node6_);
   1361   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1362   manager.RemoveCurrNode();
   1363   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
   1364   manager.RemoveCurrNode();
   1365   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1366   manager.RemoveCurrNode();
   1367   EXPECT_TRUE(manager.Empty());
   1368 }
   1369 
   1370 TEST_F(VirtualSchedulerTest, MultiDeviceSendRecvComopsiteNodeManager) {
   1371   CompositeNodeManager manager;
   1372   manager.Init(&node_states_);
   1373   // Additional nodes on kCPU1
   1374   NodeDef node7;
   1375   NodeDef node8;
   1376   NodeDef node9;
   1377   NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
   1378   NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
   1379   NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);
   1380 
   1381   // Send and Recv nodes.
   1382   NodeDef send1;
   1383   NodeDef send2;
   1384   NodeDef recv1;
   1385   NodeDef recv2;
   1386   NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
   1387   NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
   1388   NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
   1389   NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);
   1390 
   1391   // Insert nodes.
   1392   manager.AddNode(&node1_);
   1393   manager.AddNode(&node2_);
   1394   manager.AddNode(&node3_);
   1395   manager.AddNode(&node4_);
   1396   manager.AddNode(&node5_);
   1397   manager.AddNode(&node6_);
   1398   manager.AddNode(&node7);
   1399   manager.AddNode(&node8);
   1400   manager.AddNode(&node9);
   1401   manager.AddNode(&send1);
   1402   manager.AddNode(&send2);
   1403   manager.AddNode(&recv1);
   1404   manager.AddNode(&recv2);
   1405 
    1406   // On kCPU0 the last node added is node6_; on kCPU1 it is node9. So choose
    1407   // the one with the earliest time_ready among node6_, node9, Send1, Send2,
    1408   // Recv1, and Recv2.
   1409   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
   1410   manager.RemoveCurrNode();
   1411   // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
   1412   // among node5_, node9, Send1, Send2, Recv1, and Recv2.
   1413   EXPECT_EQ("Node5", manager.GetCurrNode()->name());
   1414   manager.RemoveCurrNode();
   1415   // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
   1416   EXPECT_EQ("Send1", manager.GetCurrNode()->name());
   1417   manager.RemoveCurrNode();
    1418   // Next, choose among node4_, node9, Send2, Recv1, and Recv2.
   1419   EXPECT_EQ("Recv1", manager.GetCurrNode()->name());
   1420   manager.RemoveCurrNode();
   1421   // Next, choose among node4_, node9, Send2, and Recv2.
   1422   EXPECT_EQ("Recv2", manager.GetCurrNode()->name());
   1423   manager.RemoveCurrNode();
   1424   // Next, choose among node4_, node9, and Send2.
   1425   EXPECT_EQ("Send2", manager.GetCurrNode()->name());
   1426   manager.RemoveCurrNode();
   1427   // Next, choose between node4_, node9.
   1428   EXPECT_EQ("Node4", manager.GetCurrNode()->name());
   1429   manager.RemoveCurrNode();
   1430   // Next, choose between node3_, node9.
   1431   EXPECT_EQ("Node9", manager.GetCurrNode()->name());
   1432   manager.RemoveCurrNode();
   1433   // Next, choose between node3_, node8.
   1434   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
   1435   manager.RemoveCurrNode();
   1436   // Next, choose between node3_, node7.
   1437   EXPECT_EQ("Node7", manager.GetCurrNode()->name());
   1438   manager.RemoveCurrNode();
    1439   // Then, just the nodes on kCPU0 remain -- LIFO order.
   1440   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
   1441   manager.RemoveCurrNode();
   1442   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
   1443   manager.RemoveCurrNode();
   1444   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
   1445   manager.RemoveCurrNode();
   1446   EXPECT_TRUE(manager.Empty());
   1447 }
   1448 
   1449 TEST_F(VirtualSchedulerTest, DeterminismInCompositeNodeManager) {
   1450   CompositeNodeManager manager;
   1451   manager.Init(&node_states_);
   1452   CompositeNodeManager manager2;
   1453   manager2.Init(&node_states_);
   1454 
    1455   // 6 nodes; most share time_ready 1000, but Node10 and Node11 are at 999.
   1456   NodeDef node7;
   1457   NodeDef node8;
   1458   NodeDef node9;
   1459   NodeDef node10;
   1460   NodeDef node11;
   1461   NodeDef node12;
   1462   NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
   1463   NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
   1464   NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
   1465   NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
   1466   NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
   1467   NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);
   1468 
   1469   // Add Nodes 7 to 9 to manager.
   1470   manager.AddNode(&node7);
   1471   manager.AddNode(&node8);
   1472   manager.AddNode(&node9);
   1473 
    1474   // It should return _Send, then _Recv, then the other op, when the candidate
    1475   // nodes have the same time_ready.
   1476   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
   1477   EXPECT_EQ(kSend, manager.GetCurrNode()->op());
   1478   manager.RemoveCurrNode();
   1479   EXPECT_EQ("Node9", manager.GetCurrNode()->name());
   1480   EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
   1481   manager.RemoveCurrNode();
   1482   EXPECT_EQ("Node7", manager.GetCurrNode()->name());
   1483   EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
   1484   manager.RemoveCurrNode();
   1485   EXPECT_TRUE(manager.Empty());
   1486 
   1487   // Add Nodes 7 to 9 to manager, but in a different order.
   1488   manager.AddNode(&node9);
   1489   manager.AddNode(&node8);
   1490   manager.AddNode(&node7);
   1491 
   1492   // Expect same order (_Send, _Recv, and the other op), regardless of Add
   1493   // order.
   1494   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
   1495   EXPECT_EQ(kSend, manager.GetCurrNode()->op());
   1496   manager.RemoveCurrNode();
   1497   EXPECT_EQ("Node9", manager.GetCurrNode()->name());
   1498   EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
   1499   manager.RemoveCurrNode();
   1500   EXPECT_EQ("Node7", manager.GetCurrNode()->name());
   1501   EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
   1502   manager.RemoveCurrNode();
   1503   EXPECT_TRUE(manager.Empty());
   1504 
   1505   // Conv2D's time_ready < Send's time_ready; Expect Conv2D first.
   1506   manager.AddNode(&node8);
   1507   manager.AddNode(&node10);
   1508   EXPECT_EQ("Node10", manager.GetCurrNode()->name());
   1509   EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
   1510   manager.RemoveCurrNode();
   1511   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
   1512   EXPECT_EQ(kSend, manager.GetCurrNode()->op());
   1513   manager.RemoveCurrNode();
   1514   EXPECT_TRUE(manager.Empty());
   1515 
    1516   // Recv's time_ready < Send's time_ready; Expect Recv first.
   1517   manager.AddNode(&node11);
   1518   manager.AddNode(&node8);
   1519   EXPECT_EQ("Node11", manager.GetCurrNode()->name());
   1520   EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
   1521   manager.RemoveCurrNode();
   1522   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
   1523   EXPECT_EQ(kSend, manager.GetCurrNode()->op());
   1524   manager.RemoveCurrNode();
   1525   EXPECT_TRUE(manager.Empty());
   1526 
   1527   // Node7 and 12 are normal ops with the same time_ready, placed on different
   1528   // devices. These two nodes are added to manager and manager2, but in
    1529   // different orders; expect GetCurrNode() to return nodes in the same order.
   1530   manager.AddNode(&node7);
   1531   manager.AddNode(&node12);
   1532 
   1533   manager2.AddNode(&node12);
   1534   manager2.AddNode(&node7);
   1535 
   1536   EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
   1537   manager.RemoveCurrNode();
   1538   manager2.RemoveCurrNode();
   1539   EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
   1540   manager.RemoveCurrNode();
   1541   manager2.RemoveCurrNode();
   1542   EXPECT_TRUE(manager.Empty());
   1543 }
   1544 
   1545   // Create a small graph, predict its costs, and make sure the costs from the
   1546   // summary match the hand-calculated costs.
   1547 TEST_F(VirtualSchedulerTest, SummaryCostTest) {
   1548   // Run matmul test.
   1549   CreateGrapplerItemWithMatmulChain();
   1550   InitScheduler();
   1551   auto ops_executed = RunScheduler("");
   1552   Costs c = scheduler_->Summary();
   1553 
   1554   // RandomUniform - 5 * 1s
   1555   // Matmuls - 4 * 2s = 8
   1556   // Misc - 5 * 1us
   1557   // Total: 13000005
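  // i.e., 5 * 1,000,000 us + 4 * 2,000,000 us + 5 * 1 us = 13,000,005 us.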
   1558   EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
   1559 }
   1560 
   1561 // Like the above SummaryCostTest, but makes sure the stepstats timeline is
   1562 // correct.
   1563 TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
   1564   // Run matmul test.
   1565   CreateGrapplerItemWithMatmulChain();
   1566   InitScheduler();
   1567   auto ops_executed = RunScheduler("");
   1568   RunMetadata metadata;
   1569   Costs c = scheduler_->Summary(&metadata);
   1570   StepStats stepstats = metadata.step_stats();
   1571   EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
   1572 
   1573   // Should only be 1 device!
   1574   EXPECT_EQ(1, stepstats.dev_stats().size());
   1575 
   1576   // Create a map of op name -> start and end times (micros).
   1577   std::map<string, std::pair<int64, int64>> start_end_times;
   1578   for (const auto& device_step_stats : stepstats.dev_stats()) {
   1579     for (const auto& stats : device_step_stats.node_stats()) {
   1580       int64 start = stats.all_start_micros();
   1581       int64 end = start + stats.all_end_rel_micros();
   1582       start_end_times[stats.node_name()] = std::pair<int64, int64>(start, end);
   1583 
   1584       // Make sure that the output properties are correct for
   1585       // MatMul and RandomUniform operations.
   1586       // We only check dtype and shape (excluding allocation info), since
   1587       // allocation info is not set by the virtual scheduler.
   1588       if (stats.timeline_label() == "MatMul" ||
   1589           stats.timeline_label() == "RandomUniform") {
   1590         EXPECT_EQ(1, stats.output().size());
   1591         for (const auto& output : stats.output()) {
   1592           EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype());
   1593           EXPECT_EQ(2, output.tensor_description().shape().dim().size());
   1594           for (const auto& dim : output.tensor_description().shape().dim()) {
   1595             EXPECT_EQ(3200, dim.size());
   1596           }
   1597         }
   1598       }
   1599     }
   1600   }
   1601 
   1602   // The base start time is the time to compute the RandomUniforms plus misc ops.
   1603   int64 cur_time = static_cast<int64>(5000005);
   1604   // The increment is the execution time of one matmul. See
   1605   // CreateGrapplerItemWithMatmulChain for details.
   1606   int64 increment = static_cast<int64>(2000000);
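  // Expected [start, end) windows in micros, derived from cur_time and
  // increment: ab [5000005, 7000005), abc [7000005, 9000005),
  // abcd [9000005, 11000005), abcde [11000005, 13000005).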
   1607   auto op_names = {"ab", "abc", "abcd", "abcde"};
   1608   for (const auto& op_name : op_names) {
   1609     int64 actual_start = start_end_times[op_name].first;
   1610     int64 actual_end = start_end_times[op_name].second;
   1611     int64 expected_start = cur_time;
   1612     int64 expected_end = cur_time + increment;
   1613     EXPECT_EQ(expected_start, actual_start);
   1614     EXPECT_EQ(expected_end, actual_end);
   1615     cur_time += increment;
   1616   }
   1617 }
   1618 
   1619 TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
   1620   // Init.
   1621   CreateGrapplerItemWithConv2Ds();
   1622   InitScheduler();
   1623 
   1624   // Run the scheduler.
   1625   auto ops_executed = RunScheduler("");  // Run all the nodes.
   1626 
   1627   // A Const and a RandomUniform node for each of x, y, and f, plus c0 and c1
   1628   // (8 ops in total); c2 and z shouldn't be executed.
   1629   EXPECT_EQ(8, ops_executed.size());
   1630 
   1631   // x, y, f, c0, and c1 should be in the ops executed.
   1632   EXPECT_GT(ops_executed.count("x"), 0);
   1633   EXPECT_GT(ops_executed.count("y"), 0);
   1634   EXPECT_GT(ops_executed.count("f"), 0);
   1635   EXPECT_GT(ops_executed.count("c0"), 0);
   1636   EXPECT_GT(ops_executed.count("c1"), 0);
   1637 
   1638   // z and c2 shouldn't be part of it.
   1639   EXPECT_EQ(ops_executed.count("z"), 0);
   1640   EXPECT_EQ(ops_executed.count("c2"), 0);
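  // (z and c2 are presumably outside the fetch set of this grappler item, so
  // the scheduler never needs to visit them.)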
   1641 
   1642   // Check input / output properties.
   1643   EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size());
   1644   EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size());
   1645   EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size());
   1646   EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
   1647   EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
   1648 }
   1649 
   1650 TEST_F(VirtualSchedulerTest, CalculateOutputSize) {
   1651   // Init.
   1652   CreateGrapplerItemWithAddN();
   1653   InitScheduler();
   1654 
   1655   // Create a set of tensor properties.
   1656   std::vector<OpInfo::TensorProperties> output;
   1657   output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT));           // 0
   1658   output.push_back(ShapeToTensorProperty({1}, DT_FLOAT));              // 1
   1659   output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF));      // 2
   1660   output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT));  // 3
   1661   output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT));   // 4
   1662   output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT));  // 5
   1663 
   1664   // port_num -1 is for a control dependency: hard-coded 4 bytes.
   1665   EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1));
   1666 
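  // For valid output ports, the expected size is sizeof(dtype) * number of
  // elements: 4 bytes per DT_FLOAT element and 2 bytes per DT_HALF element,
  // which is what the checks below assume.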
   1667   // Test valid outputs.
   1668   EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0));
   1669   EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1));
   1670   EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2));
   1671   EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3));
   1672 
   1673   // Any unknown shape (-1) shall yield zero output size.
   1674   EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4));
   1675   EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5));
   1676 
   1677   // An invalid port_num (arguably an error) shall also yield zero
   1678   // output size.
   1679   EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6));
   1680 }
   1681 
   1682 TEST_F(VirtualSchedulerTest, MemoryUsage) {
   1683   // Init.
   1684   CreateGrapplerItemWithAddN();
   1685   InitScheduler();
   1686 
   1687   // Run the scheduler.
   1688   RunScheduler("");
   1689 
   1690   const auto* device_states = scheduler_->GetDeviceStates();
   1691   const auto& cpu_state = device_states->at(kCPU0);
   1692 
   1693   // The out node adds 4 input tensors, each of shape 10x10x10x10, so the peak
   1694   // memory usage is 4x the input tensor size while executing the out node.
   1695   int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
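  // That is 4 bytes per float * 10 * 10 * 10 * 10 elements = 40,000 bytes per
  // input, so the expected peak is 4 * 40,000 = 160,000 bytes.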
   1696   const std::vector<string> expected_names = {"x", "y", "z", "w"};
   1697   EXPECT_EQ(expected_names.size() * one_input_node_size,
   1698             cpu_state.max_memory_usage);
   1699   ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
   1700                               cpu_state.mem_usage_snapshot_at_peak);
   1701 }
   1702 
   1703 TEST_F(VirtualSchedulerTest, ControlDependency) {
   1704   // Init.
   1705   CreateGrapplerItemWithControlDependency();
   1706   InitScheduler();
   1707 
   1708   // Run the scheduler.
   1709   RunScheduler("");
   1710 
   1711   const auto* device_states = scheduler_->GetDeviceStates();
   1712   const auto& cpu_state = device_states->at(kCPU0);
   1713 
   1714   // The graph has a NoOp that takes control dependencies from 7 NoOps. The
   1715   // peak memory usage occurs while executing the final NoOp.
   1716   int64 one_input_node_size = 4;  // control dependency
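  // Each control-dependency input is accounted as a hard-coded 4 bytes (see
  // CalculateOutputSize with port -1), so the expected peak is 7 * 4 = 28 bytes.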
   1717   const std::vector<string> expected_names = {"x", "y", "z", "w",
   1718                                               "u", "v", "t"};
   1719   EXPECT_EQ(expected_names.size() * one_input_node_size,
   1720             cpu_state.max_memory_usage);
   1721   ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
   1722                               cpu_state.mem_usage_snapshot_at_peak);
   1723 }
   1724 
   1725 TEST_F(VirtualSchedulerTest, ComplexDependency) {
   1726   // Init.
   1727   CreateGrapplerItemWithBatchNorm();
   1728   InitScheduler();
   1729 
   1730   // Run the scheduler.
   1731   RunScheduler("bn");
   1732 
   1733   const auto& device_states = scheduler_->GetDeviceStates();
   1734   const auto& cpu_state = device_states->at(kCPU0);
   1735 
   1736   // The graph is
   1737   //  bn = FusedBatchNorm(x, scale, offset, mean, var)
   1738   //  z1 = bn.y + x
   1739   //  z2 = bn.var + bn.var
   1740   //  z3 = bn.var + bn.var
   1741   //  z4 = control dependency from bn.
   1742   //  Note that bn.mean doesn't have any consumer.
   1743   const int x_size = batch_size_ * width_ * height_ * depth_in_;
   1744   int64 expected_size =
   1745       4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
   1746            1 /* control dependency */);
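  // In words: x and bn.y (x_size elements each) are still live for z1, bn.var
  // (depth_in_ elements) for z2 and z3, plus one element for z4's control
  // dependency; each element is 4 bytes.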
   1747   EXPECT_EQ(expected_size, cpu_state.memory_usage);
   1748 
   1749   // Nodes currently in memory: bn's ports -1, 0, and 2, and x's port 0.
   1750   std::set<std::pair<string, int>> nodes_in_memory;
   1751   std::transform(
   1752       cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
   1753       std::inserter(nodes_in_memory, nodes_in_memory.begin()),
   1754       [](const std::pair<const NodeDef*, int>& node_port) {
   1755         return std::make_pair(node_port.first->name(), node_port.second);
   1756       });
   1757   std::set<std::pair<string, int>> expected = {
   1758       std::make_pair("bn", -1),
   1759       std::make_pair("bn", 0),
   1760       std::make_pair("bn", 2),
   1761       std::make_pair("x", 0),
   1762   };
   1763   ExpectSetEq(expected, nodes_in_memory);
   1764 
   1765   const auto* node_states = scheduler_->GetNodeStates();
   1766   const NodeState* bn_node = nullptr;
   1767   const NodeState* x_node = nullptr;
   1768   for (const auto& nodedef_node_state : *node_states) {
   1769     const NodeDef* node = nodedef_node_state.first;
   1770     const NodeState& node_state = nodedef_node_state.second;
   1771     if (node->name() == "bn") {
   1772       bn_node = &node_state;
   1773     }
   1774     if (node->name() == "x") {
   1775       x_node = &node_state;
   1776     }
   1777   }
   1778   CHECK_NOTNULL(bn_node);
   1779   CHECK_NOTNULL(x_node);
   1780 
   1781   ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
   1782   ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
   1783   ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
   1784   // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
   1785   ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
   1786 }
   1787 
   1788 TEST_F(VirtualSchedulerTest, Variable) {
   1789   // Init.
   1790   CreateGrapplerItemWithConv2DAndVariable();
   1791   InitScheduler();
   1792 
   1793   // Run the scheduler.
   1794   RunScheduler("");
   1795 
   1796   const auto* device_states = scheduler_->GetDeviceStates();
   1797   const auto& cpu_state = device_states->at(kCPU0);
   1798 
   1799   // There is one Conv2D that takes x and f, but f is a Variable, so it should
   1800   // be tracked in persistent nodes rather than in regular memory usage.
   1801   // f is the only persistent node.
   1802   ValidateMemoryUsageSnapshot({"f"}, 0 /* port_num_expected */,
   1803                               cpu_state.persistent_nodes);
   1804   // Only x in peak memory usage snapshot.
   1805   ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */,
   1806                               cpu_state.mem_usage_snapshot_at_peak);
   1807 }
   1808 
   1809 TEST_F(VirtualSchedulerTest, WhileLoop) {
   1810   // Init.
   1811   CreateGrapplerItemWithLoop();
   1812   InitScheduler();
   1813 
   1814   // Run the scheduler.
   1815   RunScheduler("");
   1816 
   1817   // Check the timeline
   1818   RunMetadata metadata;
   1819   scheduler_->Summary(&metadata);
   1820 
   1821   // Nodes in topological order:
   1822   // * const, ones
   1823   // * while/Enter, while/Enter_1
   1824   // * while/Merge, while/Merge_1
   1825   // * while/Less/y
   1826   // * while/Less
   1827   // * while/LoopCond
   1828   // * while/Switch, while/Switch_1
   1829   // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
   1830   // * while/add/y, while/concat/axis
   1831   // * while/add, while/concat
   1832   // * while/NextIteration, while/NextIteration_1
   1833 
   1834   int num_next_iteration = 0;
   1835   int num_next_iteration_1 = 0;
   1836   int num_exit = 0;
   1837   int num_exit_1 = 0;
   1838   int64 next_iter_start_micro;
   1839   int64 next_iter_1_start_micro;
   1840   int64 exit_start_micro;
   1841   int64 exit_1_start_micro;
   1842 
   1843   std::unordered_map<string, int64> start_times;
   1844   for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
   1845     for (const auto& stats : device_step_stats.node_stats()) {
   1846       start_times[stats.node_name()] = stats.all_start_micros();
   1847       if (stats.node_name() == "while/NextIteration") {
   1848         ++num_next_iteration;
   1849         next_iter_start_micro = stats.all_start_micros();
   1850       } else if (stats.node_name() == "while/NextIteration_1") {
   1851         ++num_next_iteration_1;
   1852         next_iter_1_start_micro = stats.all_start_micros();
   1853       } else if (stats.node_name() == "while/Exit") {
   1854         ++num_exit;
   1855         exit_start_micro = stats.all_start_micros();
   1856       } else if (stats.node_name() == "while/Exit_1") {
   1857         ++num_exit_1;
   1858         exit_1_start_micro = stats.all_start_micros();
   1859       }
   1860     }
   1861   }
   1862 
   1863   // Make sure we went through the body of the loop once, and that the output of
   1864   // the loop was scheduled as well.
   1865   EXPECT_EQ(1, num_next_iteration);
   1866   EXPECT_EQ(1, num_next_iteration_1);
   1867   EXPECT_EQ(1, num_exit);
   1868   EXPECT_EQ(1, num_exit_1);
   1869 
   1870   // Start times of while/NextIteration and while/NextIteration_1 should be
   1871   // different, as should those of while/Exit and while/Exit_1.
   1872   EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
   1873   EXPECT_NE(exit_start_micro, exit_1_start_micro);
   1874 
   1875   // Check dependencies among the nodes; no matter what scheduling mechanism we
   1876   // use, the scheduled ops should follow these dependency chains.
   1877   // Note that currently, VirtualScheduler executes while/Merge twice; hence,
   1878   // we're not testing dependency chains related to while/Merge.
   1879   // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
   1880   // order of Enter, Merge, ...loop condition ..., ... loop body ...,
   1881   // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
   1882   // chain test w/ Merge nodes.
   1883   ValidateDependencyChain(
   1884       start_times,
   1885       {"Const", "while/Enter",  // "while/Merge",
   1886        "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
   1887        "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
   1888   // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
   1889   ValidateDependencyChain(start_times,
   1890                           {"ones", "while/Enter_1",  // "while/Merge_1",
   1891                            "while/Switch_1", "while/Identity_1", "while/concat",
   1892                            "while/NextIteration_1"});
   1893   ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
   1894   ValidateDependencyChain(
   1895       start_times, {"while/Identity", "while/concat/axis", "while/concat"});
   1896   ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
   1897   ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
   1898 }
   1899 
   1900 TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
   1901   // Init.
   1902   CreateGrapplerItemWithInterDeviceTransfers();
   1903   InitScheduler();
   1904 
   1905   // Run the scheduler.
   1906   auto ops_executed = RunScheduler("");
   1907 
   1908   // Helper lambda to extract port num from _Send and _Recv op name.
   1909   auto get_port_num = [](const string& name) -> int {
   1910     if (name.find("bn_0") != std::string::npos) {
   1911       return 0;
   1912     } else if (name.find("bn_1") != std::string::npos) {
   1913       return 1;
   1914     } else if (name.find("bn_2") != std::string::npos) {
   1915       return 2;
   1916     } else if (name.find("bn_minus1") != std::string::npos) {
   1917       return -1;
   1918     }
   1919     return -999;
   1920   };
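  // The substrings above ("bn_0", "bn_1", "bn_2", "bn_minus1") are expected to
  // appear in the names of the _Send/_Recv nodes the scheduler creates: one
  // pair per transferred output port of bn, plus one for its control output.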
   1921 
   1922   // Reorganize ops_executed for further testing.
   1923   std::unordered_map<string, int> op_count;
   1924   std::unordered_map<int, string> recv_op_names;
   1925   std::unordered_map<int, string> send_op_names;
   1926   for (const auto& x : ops_executed) {
   1927     const auto& name = x.first;
   1928     const auto& node_info = x.second;
   1929     const auto& op = node_info.op_info.op();
   1930     if (op == kRecv) {
   1931       recv_op_names[get_port_num(name)] = name;
   1932     } else if (op == kSend) {
   1933       send_op_names[get_port_num(name)] = name;
   1934     }
   1935     op_count[op]++;
   1936   }
   1937 
   1938   // Same number of _Send and _Recv.
   1939   EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));
   1940 
   1941   // Expect 4 Sends and 4 Recvs: ports 0, 1, and 2, plus a control dependency.
   1942   EXPECT_EQ(op_count.at(kRecv), 4);
   1943   EXPECT_EQ(op_count.at(kSend), 4);
   1944 
   1945   // Helper lambda for extracting output Tensor size.
   1946   auto get_output_size = [this, ops_executed](const string& name) -> int64 {
   1947     const auto& output_properties_ = ops_executed.at(name).op_info.outputs();
   1948     std::vector<OpInfo::TensorProperties> output_properties;
   1949     for (const auto& output_property : output_properties_) {
   1950       output_properties.push_back(output_property);
   1951     }
   1952     return scheduler_->CalculateOutputSize(output_properties, 0);
   1953   };
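  // get_output_size returns CalculateOutputSize for output port 0 of the named
  // op, i.e., sizeof(dtype) * number of elements of its first output.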
   1954 
   1955   // Validate transfer size.
   1956   // BatchNorm output y is a 4-D tensor: batch x width x height x depth.
   1957   int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
   1958   EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
   1959   EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
   1960   // Mean and var are 1-D vectors of size depth_in_.
   1961   EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
   1962   EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
   1963   EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
   1964   EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
   1965   // Control dependency size is 4B.
   1966   EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
   1967   EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
   1968 }
   1969 
   1970 TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
   1971   // Init.
   1972   CreateGrapplerItemWithSendRecv();
   1973   InitScheduler();
   1974 
   1975   // Run the scheduler.
   1976   auto ops_executed = RunScheduler("");
   1977 
   1978   EXPECT_GT(ops_executed.count("Const"), 0);
   1979   EXPECT_GT(ops_executed.count("Send"), 0);
   1980   EXPECT_GT(ops_executed.count("Recv"), 0);
   1981 }
   1982 
   1983 TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
   1984   // Init.
   1985   CreateGrapplerItemWithSendRecv();
   1986   // Change Recv node's device so that Send and Recv are placed on different
   1987   // devices.
   1988   auto& graph = grappler_item_->graph;
   1989   const string recv_device = kCPU1;
   1990   for (int i = 0; i < graph.node_size(); i++) {
   1991     auto* node = graph.mutable_node(i);
   1992     if (node->name() == "Recv") {
   1993       node->set_device(recv_device);
   1994       auto* attr = node->mutable_attr();
   1995       (*attr)["recv_device"].set_s(recv_device);
   1996     } else if (node->name() == "Send") {
   1997       auto* attr = node->mutable_attr();
   1998       (*attr)["recv_device"].set_s(recv_device);
   1999     }
   2000   }
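  // With Send and Recv now placed on different devices, the scheduler is
  // expected to insert its own _Send/_Recv pair for the cross-device edge;
  // their auto-generated names (checked below) embed the sanitized device names.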
   2001   InitScheduler();
   2002 
   2003   // Run the scheduler.
   2004   auto ops_executed = RunScheduler("");
   2005 
   2006   // Expect Const, Send, Recv, and the _Send/_Recv ops the scheduler creates.
   2007   EXPECT_GT(ops_executed.count("Const"), 0);
   2008   EXPECT_GT(ops_executed.count("Send"), 0);
   2009   EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
   2010                                "task_0/cpu_0_to_/job_localhost"
   2011                                "/replica_0/task_0/cpu_1"),
   2012             0);
   2013   EXPECT_GT(ops_executed.count(
   2014                 "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
   2015             0);
   2016   EXPECT_GT(ops_executed.count("Recv"), 0);
   2017 }
   2018 }  // end namespace grappler
   2019 }  // end namespace tensorflow
   2020