/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

namespace {

// Key to the cached _Recv ops map, and its hash and predicate structures.
struct RecvNodeDescriptor {
  const NodeDef* node;
  const int port_num;
  const string device;

  RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
                     const string& device_)
      : node(node_), port_num(port_num_), device(device_) {}
};

struct RecvNodeDescriptorHash {
  std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
    return std::hash<const NodeDef*>()(recv_node.node) ^
           std::hash<int>()(recv_node.port_num) ^
           std::hash<string>()(recv_node.device);
  }
};
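
// Note: combining the field hashes with plain XOR, as above, is simple but
// order-insensitive, so certain field swaps hash identically. A boost-style
// mixing combiner is a common alternative (a sketch only; not what this file
// uses):
//
//   template <typename T>
//   void HashCombine(std::size_t* seed, const T& v) {
//     *seed ^= std::hash<T>()(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
//   }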

struct RecvNodeDescriptorEqual {
  bool operator()(const RecvNodeDescriptor& a,
                  const RecvNodeDescriptor& b) const {
    return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
  }
};
}  // namespace

// ReadyNodeManager
const NodeDef* LIFOManager::GetCurrNode() {
  CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
  if (curr_pos_ == nodes_.end()) {
    curr_pos_ = --(nodes_.rbegin().base());  // Last one in the list.
  }
  // Once curr_pos_ is set to a valid entry in the list, we keep using the
  // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not
  // change the GetCurrNode() return value.
  return *curr_pos_;
}

void LIFOManager::RemoveCurrNode() {
  // Make sure we have curr_pos_ ready to be removed.
  GetCurrNode();
  // Note that curr_pos_ may not be pointing to the last element if some nodes
  // were added after the last GetCurrNode() call.
  nodes_.erase(curr_pos_);

  curr_pos_ = nodes_.end();  // Reset curr_pos_.
}
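
// Illustrative LIFO behavior (a sketch; a and b are hypothetical NodeDefs,
// and AddNode() is defined in the header):
//   LIFOManager manager;
//   manager.AddNode(&a);
//   manager.AddNode(&b);
//   manager.GetCurrNode();     // Returns &b: last in, first out.
//   manager.RemoveCurrNode();
//   manager.GetCurrNode();     // Returns &a.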

FirstReadyManager::FirstReadyManager() : ReadyNodeManager() {
  std::make_heap(nodes_.begin(), nodes_.end());
}

void FirstReadyManager::Init(
    const std::unordered_map<const NodeDef*, NodeState>* node_state) {
  // Reset the node state since different instances of the scheduler can reuse
  // the same node_manager.
  node_state_ = node_state;
  nodes_.clear();
  waiting_queue_.clear();
  greater_ = [this](const NodeDef* a, const NodeDef* b) -> bool {
    if (node_state_->at(a).time_ready == node_state_->at(b).time_ready) {
      // Use the node name as the tie-breaker for deterministic node
      // scheduling.
      return a->name().compare(b->name()) > 0;
    } else {
      // Note: we need the node with the minimum time_ready, not the maximum;
      // hence, the comparison uses a > b.
      return node_state_->at(a).time_ready > node_state_->at(b).time_ready;
    }
  };
}
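
// Note: the standard heap functions maintain a max-heap with respect to the
// comparator, so using greater_ (a later time_ready compares as larger)
// keeps the node with the smallest time_ready at the front of nodes_.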

const NodeDef* FirstReadyManager::GetCurrNode() {
  if (nodes_.empty()) {
    // Nothing in nodes_; this is probably the very first call. Move
    // waiting_queue_ to nodes_.
    DrainWaitingQueue();
    CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
  }
  return nodes_.front();
}

void FirstReadyManager::RemoveCurrNode() {
  if (nodes_.empty()) {
    // Make sure that there is a node to be removed at the front of nodes_.
    GetCurrNode();
  }
  std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
  nodes_.pop_back();
  DrainWaitingQueue();
}

bool FirstReadyManager::Empty() const {
  return nodes_.empty() && waiting_queue_.empty();
}

void FirstReadyManager::DrainWaitingQueue() {
  for (const auto* node : waiting_queue_) {
    // The push_heap here and the pop_heap in RemoveCurrNode() guarantee that
    // the first element is the node with the minimum time_ready.
    nodes_.push_back(node);
    std::push_heap(nodes_.begin(), nodes_.end(), greater_);
  }
  waiting_queue_.clear();
}
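
// Design note (inferred from this file; AddNode() is defined elsewhere):
// nodes staged in waiting_queue_ are folded into the nodes_ heap only by
// DrainWaitingQueue(), which runs from GetCurrNode() (when nodes_ is empty)
// and from RemoveCurrNode(). This keeps the GetCurrNode() result stable while
// new nodes arrive, even if a newly added node has a smaller time_ready than
// the current front.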

CompositeNodeManager::CompositeNodeManager()
    : ReadyNodeManager(), send_manager_(), recv_manager_() {}

void CompositeNodeManager::Init(
    const std::unordered_map<const NodeDef*, NodeState>* node_state) {
  node_state_ = node_state;
  send_manager_.Init(node_state);
  recv_manager_.Init(node_state);
  curr_node_ = nullptr;
}

void CompositeNodeManager::AddNode(const NodeDef* node) {
  if (IsSend(*node)) {
    send_manager_.AddNode(node);
  } else if (IsRecv(*node)) {
    recv_manager_.AddNode(node);
  } else {
    const auto& device = node_state_->at(node).device_name;
    ops_lifo_map_[device].AddNode(node);
  }
}

const NodeDef* CompositeNodeManager::GetCurrNode() {
  if (curr_node_) return curr_node_;

  // Selection policy:
  // - Per-device LIFO for normal ops (not _Send / _Recv).
  // - FirstReady for _Send and _Recv (separately).
  // - Globally FirstReady among the LIFO-selected ops from each device and
  //   the _Send and _Recv candidates.
  // - Priority order if time_ready is equal: _Send, _Recv, then the rest.
  std::vector<std::pair<const NodeDef*, Costs::Duration>> candidates;
  for (auto& ops_lifo : ops_lifo_map_) {
    if (!ops_lifo.second.Empty()) {
      const auto* op = ops_lifo.second.GetCurrNode();
      candidates.emplace_back(op, node_state_->at(op).time_ready);
    }
  }
  if (!send_manager_.Empty()) {
    const auto* send = send_manager_.GetCurrNode();
    candidates.emplace_back(send, node_state_->at(send).time_ready);
  }
  if (!recv_manager_.Empty()) {
    const auto* recv = recv_manager_.GetCurrNode();
    candidates.emplace_back(recv, node_state_->at(recv).time_ready);
  }
  CHECK(!candidates.empty());
  auto first_ready = std::min_element(
      candidates.begin(), candidates.end(),
      [](const std::pair<const NodeDef*, Costs::Duration>& a,
         const std::pair<const NodeDef*, Costs::Duration>& b) {
        if (a.second == b.second) {
          // There can be at most one _Send and one _Recv in candidates;
          // hence, the score is 2 for _Send, 1 for _Recv, and 0 for a normal
          // op, and a_score and b_score are equal only if both are normal
          // ops.
          int a_score = 2 * IsSend(*a.first) + IsRecv(*a.first);
          int b_score = 2 * IsSend(*b.first) + IsRecv(*b.first);
          if (a_score == b_score) {
            // Both are normal ops; use the node name as the tie-breaker.
            return a.first->name().compare(b.first->name()) < 0;
          } else {
            // Prioritize by op type: _Send, _Recv, then normal ops.
            return a_score > b_score;
          }
        } else {
          return a.second < b.second;
        }
      });
  // Until RemoveCurrNode() is called, subsequent GetCurrNode() calls return
  // the cached curr_node_.
  curr_node_ = first_ready->first;

  return curr_node_;
}

void CompositeNodeManager::RemoveCurrNode() {
  const auto* node = GetCurrNode();
  if (IsSend(*node)) {
    send_manager_.RemoveCurrNode();
  } else if (IsRecv(*node)) {
    recv_manager_.RemoveCurrNode();
  } else {
    const auto device = node_state_->at(node).device_name;
    ops_lifo_map_[device].RemoveCurrNode();
  }
  // Reset curr_node_ so that GetCurrNode() finds another node.
  curr_node_ = nullptr;
}

bool CompositeNodeManager::Empty() const {
  // Empty if all the ready managers are empty.
  bool empty = true;
  for (const auto& ops_lifo : ops_lifo_map_) {
    empty &= ops_lifo.second.Empty();
  }
  return empty && send_manager_.Empty() && recv_manager_.Empty();
}

std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
    const string& ready_node_manager) {
  if (ready_node_manager == "FIFO") {
    return absl::make_unique<FIFOManager>();
  } else if (ready_node_manager == "LIFO") {
    return absl::make_unique<LIFOManager>();
  } else if (ready_node_manager == "FirstReady") {
    return absl::make_unique<FirstReadyManager>();
  } else if (ready_node_manager == "Composite") {
    return absl::make_unique<CompositeNodeManager>();
  }
  LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager;
  return nullptr;
}
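
// Sketch of typical factory usage (illustrative names, not from this file):
//   std::unique_ptr<ReadyNodeManager> manager =
//       ReadyNodeManagerFactory("Composite");
//   manager->Init(&node_states);  // node_states: map of NodeDef* -> NodeState
//   while (!manager->Empty()) {
//     const NodeDef* node = manager->GetCurrNode();
//     // ... account for node's cost ...
//     manager->RemoveCurrNode();
//   }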

VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
                                   const bool use_aggressive_shape_inference,
                                   Cluster* cluster,
                                   ReadyNodeManager* ready_nodes)
    : ready_nodes_(ready_nodes),
      graph_costs_(Costs::ZeroCosts()),
      cluster_(cluster),
      use_static_shapes_(use_static_shapes),
      use_aggressive_shape_inference_(use_aggressive_shape_inference),
      placer_(cluster) {
  graph_costs_.num_ops_total = 0;
  initialized_ = false;
  track_mem_usage_snapshot_ = VLOG_IS_ON(1);
}

Status VirtualScheduler::Init(const GrapplerItem* item) {
  grappler_item_ = item;
  graph_properties_ = absl::make_unique<GraphProperties>(*item);

  initialized_ = false;

  // Clear all internal states so that the VirtualScheduler is reusable for
  // different GrapplerItems.
  node_map_.clear();
  device_.clear();
  additional_nodes_.clear();

  graph_costs_ = Costs::ZeroCosts();
  graph_costs_.num_ops_total = 0;
  op_to_cost_.clear();

  op_counts_.clear();
  op_costs_.clear();

  // Init() preprocesses the input grappler_item and graph_properties to
  // extract the information necessary for emulating TensorFlow op scheduling,
  // and constructs internal data structures (NodeState and DeviceState) for
  // virtual scheduling.
  ready_nodes_->Init(GetNodeStates());

  // Construct graph properties.
  if (use_static_shapes_) {
    TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
        true, use_aggressive_shape_inference_));
  } else {
    TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
  }

  const auto& graph = grappler_item_->graph;
  const auto& fetch_nodes = grappler_item_->fetch;
  std::set<string> feed_nodes;
  for (const auto& f : grappler_item_->feed) {
    auto iter_and_inserted_flag = feed_nodes.insert(f.first);
    QCHECK(iter_and_inserted_flag.second)
        << "Duplicate feed node found: " << f.first;
  }

  // Get the nodes that would run to output fetch_nodes.
  bool ill_formed = false;
  const std::vector<const NodeDef*> fetch_fanin_nodes =
      ComputeTransitiveFanin(graph, fetch_nodes, &ill_formed);
  if (ill_formed) {
    return errors::InvalidArgument(
        "Ill formed graph or invalid set of fetch nodes specified");
  }

  // TODO(dyoon): this is a bit inefficient as name_to_node is already built
  // in ComputeTransitiveFanin().
  // Once ComputeTransitiveFanin() is complete, only the nodes that can be
  // reached from the fetch nodes are scheduled, so the scheduled nodes should
  // be exactly the same as those executed for real. One possible discrepancy
  // is control flow nodes, where TensorFlow executes only one path.
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (const auto& node : fetch_fanin_nodes) {
    name_to_node[node->name()] = node;
  }

  // Traverse the graph to record _Send nodes.
  // TODO(dyoon): Instead of identifying _Send nodes manually here, add _Send
  // to _Recv as a control dependency when creating the GrapplerItem.
  std::unordered_map<string, const NodeDef*> name_to_send;
  for (const auto& node : graph.node()) {
    if (IsSend(node)) {
      const auto& attr = node.attr();
      name_to_send[attr.at("tensor_name").s()] = &node;
    }
  }

  // To reuse _Recv ops.
  std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescriptorHash,
                     RecvNodeDescriptorEqual>
      cached_recv_nodes;

  // Build node_map; for each node, create its NodeState and connect its
  // inputs and outputs.
  for (const auto* curr_node : fetch_fanin_nodes) {
    auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
    const string curr_node_device = DeviceName(curr_node);
    std::vector<string> inputs;
    if (IsRecv(*curr_node)) {
      const auto& attr = curr_node->attr();
      if (attr.count("tensor_name")) {
        const auto& send_node_name = attr.at("tensor_name").s();
        auto it = name_to_send.find(send_node_name);
        // If there is a _Send associated with the curr_node (_Recv), add it
        // as an input.
        if (it != name_to_send.end()) {
          const NodeDef* send = it->second;
          inputs = {send->name()};
        }
      }
    } else {
      for (const string& input : curr_node->input()) {
        inputs.push_back(input);
      }
    }
    for (const string& input_node_name : inputs) {
      // Note that input_node_name may be in <prefix><node_name>:<port_num>
      // format, where <prefix> (e.g., "^" for control dependency) and
      // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
      const NodeDef* input_node = name_to_node[NodeName(input_node_name)];

      CHECK(input_node);
      const string in_device = DeviceName(input_node);
      const auto input_node_port_num = NodePosition(input_node_name);

      if (curr_node_device == in_device) {
        // Same device: connect input_node and curr_node directly.
        curr_node_state.inputs.push_back(
            std::make_pair(input_node, input_node_port_num));
        auto& input_node_state = GetNodeStateOrCreateIt(input_node);
        input_node_state.outputs[input_node_port_num].push_back(curr_node);
      } else {
        RecvNodeDescriptor recv_node(input_node, input_node_port_num,
                                     curr_node_device);
        auto it = cached_recv_nodes.find(recv_node);
        if (it != cached_recv_nodes.end()) {
          // Different device, but found an already-cached copy (a _Recv op);
          // connect the _Recv to curr_node.
          const NodeDef* recv_op = it->second;
          // recv_op's output port is hard-coded to zero.
          curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
          auto& input_node_state = node_map_.at(recv_op);
          input_node_state.outputs[0].push_back(curr_node);
        } else {
          // Different device, no cached copy; transfer input_node to
          // curr_node's device.
          auto send_and_recv = CreateSendRecv(input_node, curr_node, input_node,
                                              input_node_name);
          // Note that CreateSendRecv() already connected input/output between
          // the _Send and _Recv ops.
          const auto* send = send_and_recv.first;
          const auto* recv = send_and_recv.second;
          // recv_op's output port is hard-coded to zero.
          curr_node_state.inputs.push_back(std::make_pair(recv, 0));
          auto& input_node_state = GetNodeStateOrCreateIt(input_node);
          input_node_state.outputs[input_node_port_num].push_back(send);

          // Cache the _Recv op for future use.
          cached_recv_nodes[recv_node] = recv;
        }
      }
    }

    // Special case: given feed nodes are ready at time 0.
    const bool given_as_feed =
        feed_nodes.find(curr_node->name()) != feed_nodes.end();

    // Default case: nodes without inputs are ready at time 0.
    // Note that we check the inputs vector, which may differ from
    // curr_node->input(); e.g., we add _Send as an input to _Recv.
    const bool has_no_inputs = inputs.empty();

    if (given_as_feed || has_no_inputs) {
      curr_node_state.time_ready = Costs::Duration();
      ready_nodes_->AddNode(curr_node);
      VLOG(3) << "Added ready node: " << curr_node->name();
    }

    feed_nodes.erase(curr_node->name());

    if (IsPersistentNode(curr_node)) {
      auto& device_state = device_[curr_node_device];
      for (int port_num = 0;
           port_num < curr_node_state.output_properties.size(); ++port_num) {
        device_state.persistent_nodes.insert(
            std::make_pair(curr_node, port_num));
      }
    }
  }

  if (ready_nodes_->Empty()) {
    return errors::InvalidArgument("No ready nodes in the graph.");
  }

  if (!feed_nodes.empty()) {
    // This isn't always a bug: when the caller hasn't specified the exact
    // list of feed and fetch nodes, by default we consider all placeholders
    // as feed nodes, but some of them may not be needed for the default fetch
    // node.
    VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: "
            << str_util::Join(feed_nodes, ",");
  }

  initialized_ = true;
  return Status::OK();
}

void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
  CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
  // This method is called when a NodeState is created; it adds input and
  // output properties for a few exceptional cases where GraphProperties
  // cannot provide them.
  if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
    // _Send and _Recv ops created by the VirtualScheduler have the
    // kAttrInputSrc attr; normal _Send and _Recv ops (from the input graph)
    // do not have that attr.
    auto& node_state = node_map_[node];
    auto& inputs = node_state.input_properties;
    auto& outputs = node_state.output_properties;

    // These _Send and _Recv ops were created by the VirtualScheduler, so
    // there should be no input or output TensorProperties yet.
    CHECK(inputs.empty());
    CHECK(outputs.empty());
    const auto& attr = node->attr();
    // This is the original input source to the _Send and _Recv; the string
    // includes "^" if it was a control dependency, and an output port
    // (e.g., ":2") if the input source had multiple outputs.
    const auto& input_source_name = attr.at(kAttrInputSrc).s();
    if (IsControlInput(input_source_name)) {
      // Control dependency; regardless of the input source tensor size,
      // send 4B.
      OpInfo::TensorProperties control_message;
      control_message.set_dtype(DT_FLOAT);
      control_message.mutable_shape()->add_dim()->set_size(1);
      auto* value = control_message.mutable_value();
      value->add_float_val(1);
      inputs.push_back(control_message);
      outputs.push_back(control_message);
    } else {
      const auto& output_properties =
          graph_properties_->GetOutputProperties(NodeName(input_source_name));
      // Like with HasInputProperties, if a node does not have output
      // properties, it's likely it was pruned during the shape inference run.
      if (!output_properties.empty()) {
        const auto input_node_port_num = NodePosition(input_source_name);
        // Use the input source's output property as the _Send and _Recv's
        // input property.
        CHECK_GT(output_properties.size(), input_node_port_num);
        inputs.push_back(output_properties[input_node_port_num]);
        outputs.push_back(output_properties[input_node_port_num]);
      }
    }
  }
}

float VirtualScheduler::Round2(const float x) const {
  // Not using std::round from <cmath> here because not all platforms seem to
  // support that (specifically Android).
  return ::round(100.0 * x) / 100.0;
}
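
// For example, Round2(3.14159f) returns 3.14f: the value rounded to two
// decimal places.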

bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
  // Variables are persistent nodes.
  return IsVariable(*node);
}

string VirtualScheduler::DeviceName(const NodeDef* node) const {
  return placer_.get_canonical_device_name(*node);
}

string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
  // Replace the ":" characters that may be present in the device name with
  // "_". This makes it possible to then use the resulting string in a node
  // name.
  return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":",
                                 "_", true);
}

string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
                                           const NodeDef* to) const {
  CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
  return kChannelDevice + "_from_" + SanitizedDeviceName(from) + "_to_" +
         SanitizedDeviceName(to);
}
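
// As an illustration (hypothetical device names): for a "from" node on
// "/job:w/replica:0/task:0/cpu:0", SanitizedDeviceName() yields
// "/job_w/replica_0/task_0/cpu_0" (only ":" is replaced), and the channel
// device name becomes
// "<kChannelDevice>_from_/job_w/replica_0/task_0/cpu_0_to_<sanitized to>".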

std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
    const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
    const string& input_name) {
  CHECK(!initialized_) << "CreateSendRecv is called after Init().";

  // Connect the "from" node to the "to" node with _Send and _Recv such that
  // from -> _Send -> _Recv -> to.
  // _Send is placed on the "Channel" device, and _Recv is on the same device
  // as the "to" node.
  // input_name is the string the "to" node uses to identify which output
  // it gets from the "from" node.

  // Note that we use NodeState for scheduling, so the _Send and _Recv
  // NodeDefs created here need not be fully correct in terms of names,
  // input names, attrs, etc.

  auto input_node_port_num = NodePosition(input_name);
  string src_name;
  if (input_node_port_num >= 0) {
    src_name = strings::StrCat(from->name(), "_", input_node_port_num);
  } else {
    src_name = strings::StrCat(from->name(), "_minus1");
  }

  // _Send op.
  auto* send = new NodeDef();
  send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) +
                 "_to_" + SanitizedDeviceName(to));
  send->set_op("_Send");
  send->add_input(from->name());
  send->set_device(ChannelDeviceName(from, to));
  auto& send_attr = *(send->mutable_attr());
  send_attr[kAttrInputSrc].set_s(input_name);
  send_attr[kAttrSrcDevice].set_s(DeviceName(from));
  send_attr[kAttrDstDevice].set_s(DeviceName(to));
  // GraphDefs generated by AutoGrappler have the tensor_name field when
  // _Send/_Recv nodes are removed.
  if (input_node->attr().count(kAttrTensorName)) {
    send_attr[kAttrTensorName].set_s(
        input_node->attr().at(kAttrTensorName).s());
  }

  // _Recv op.
  auto* recv = new NodeDef();
  recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to));
  recv->set_op("_Recv");
  recv->add_input(send->name());
  recv->set_device(DeviceName(to));
  auto& recv_attr = *(recv->mutable_attr());
  recv_attr[kAttrInputSrc].set_s(input_name);
  if (input_node->attr().count(kAttrTensorName)) {
    recv_attr[kAttrTensorName].set_s(
        input_node->attr().at(kAttrTensorName).s());
  }

  // NodeState for the _Send op.
  auto& send_node_state = GetNodeStateOrCreateIt(send);
  send_node_state.device_name = send->device();  // Set the Channel device.
  send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
  send_node_state.outputs[0].push_back(recv);

  // NodeState for the _Recv op.
  auto& recv_node_state = GetNodeStateOrCreateIt(recv);
  recv_node_state.inputs.push_back(std::make_pair(send, 0));
  recv_node_state.outputs[0].push_back(to);

  // Keep the created nodes alive.
  additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
  additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(recv));

  // Return the _Send and _Recv pair.
  return std::make_pair(send, recv);
}
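
// The wiring created above, sketched for a hypothetical node A on device D1
// feeding node B on device D2:
//
//   A (D1) --> _Send (channel device from D1 to D2) --> _Recv (D2) --> B
//
// A later consumer of the same A output on D2 reuses the cached _Recv (see
// cached_recv_nodes in Init()) rather than creating another _Send/_Recv pair.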

OpContext VirtualScheduler::GetCurrNode() const {
  const NodeDef* node = ready_nodes_->GetCurrNode();

  // Get the device from the placer.
  DeviceProperties device;
  device = placer_.get_device(*node);

  // Special case for the _Send op.
  if (IsSend(*node)) {
    device.set_type(kChannelDevice);
  }

  // Construct the OpContext.
  OpContext op_context;
  const auto& node_state = node_map_.at(node);
  op_context.name = node->name();
  op_context.device_name = node_state.device_name;
  auto& op_info = op_context.op_info;
  op_info.set_op(node->op());
  *op_info.mutable_attr() = node->attr();
  for (auto& input : node_state.input_properties) {
    *op_info.add_inputs() = input;
  }
  for (auto& output : node_state.output_properties) {
    *op_info.add_outputs() = output;
  }
  op_info.mutable_device()->Swap(&device);

  if (grappler_item_->graph.has_library()) {
    op_context.function_library = &grappler_item_->graph.library();
  }
  return op_context;
}

NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
  CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";

  auto it = node_map_.find(node);
  if (it != node_map_.end()) {
    return it->second;
  }

  // Not found; create a NodeState for this node.
  it = node_map_.emplace(node, NodeState()).first;
  auto& node_state = it->second;
  node_state.input_properties =
      graph_properties_->GetInputProperties(node->name());
  node_state.output_properties =
      graph_properties_->GetOutputProperties(node->name());

  // Some ops (_Send and _Recv) may need further processing of the
  // input/output properties.
  MaybeUpdateInputOutput(node);

  if (!IsSend(*node)) {
    node_state.device_name = DeviceName(node);
    // For a _Send op, device_name will be set to Channel in CreateSendRecv().
  }

  // Initialize output-port-related data. Assume the size of OutputProperties
  // represents the number of output ports of this node.
  for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
    node_state.time_no_references[i] = Costs::Duration::max();
    node_state.num_outputs_executed[i] = 0;
    // Populate an empty vector for each port. The caller will add nodes
    // that use this port as input.
    node_state.outputs[i] = {};
  }
  // Port number -1 is for control dependencies.
  node_state.time_no_references[-1] = Costs::Duration::max();
  node_state.num_outputs_executed[-1] = 0;
  node_state.outputs[-1] = {};

  return it->second;
}

Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
                                          std::map<string, Costs>* op_cost) {
  auto it = op_cost->find(op_name);
  if (it == op_cost->end()) {
    // Note that the default constructor of Costs sets some memory-related
    // fields to unknown values, so we explicitly initialize with ZeroCosts.
    it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
  }
  return it->second;
}

void VirtualScheduler::AddOutputNodesToReadyQueue(
    const NodeDef* node, const Costs::Duration& curr_time) {
  // Check whether the Switch's output slots change over iterations.
  int slot = -1;
  if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
      node->attr().at(kOutputSlots).list().i_size() > 0) {
    slot = node->attr().at(kOutputSlots).list().i(0);
    for (int i = 1; i < node->attr().at(kOutputSlots).list().i_size(); ++i) {
      if (slot != node->attr().at(kOutputSlots).list().i(i)) {
        slot = -1;
        break;
      }
    }
  }
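
  // In other words, if the kOutputSlots annotation recorded, say, {1, 1, 1}
  // across iterations, slot stays 1 and only port 1's consumers are
  // scheduled below; an annotation like {0, 1, 0} resets slot to -1 and all
  // output ports are scheduled. (The values here are illustrative.)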

  // Increment num_inputs_ready of the output nodes and maybe add them to the
  // ready queue.
  auto& node_state = node_map_[node];
  for (const auto& port_num_output_pair : node_state.outputs) {
    // If the Switch is annotated and its output slots are always the same, we
    // only schedule the slot that was executed. Otherwise, schedule both
    // slots.
    if (slot >= 0 && port_num_output_pair.first != slot) continue;

    for (auto* output_node : port_num_output_pair.second) {
      auto& output_state = node_map_[output_node];
      output_state.num_inputs_ready++;
      // Execute a node as soon as all its inputs are ready. Merge nodes are
      // special since they run as soon as one of their inputs becomes
      // available.
      if (output_state.num_inputs_ready == output_state.inputs.size() ||
          IsMerge(*output_node)) {
        // This output node is now ready.
        output_state.time_ready = curr_time;
        ready_nodes_->AddNode(output_node);
        VLOG(3) << "  Add output: " << output_node->name();
      }
    }
  }
}

bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
  // Update graph_costs_ and per-op costs.
  const NodeDef* node = ready_nodes_->GetCurrNode();
  auto& node_state = node_map_[node];
  // If there is an annotation in the graph about execution times, we use that
  // number; otherwise, we assume the node is executed once.
  node_state.execution_count = node->attr().count(kExecutionCount) == 0
                                   ? 1
                                   : node->attr().at(kExecutionCount).i();
  Costs total_node_costs =
      MultiplyCosts(node_costs, node_state.execution_count);
  graph_costs_ = CombineCosts(graph_costs_, total_node_costs);
  const string& op_name = node->op();

  auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
  op_cost = CombineCosts(op_cost, total_node_costs);

  if (VLOG_IS_ON(2)) {
    // Also keep track of op counts and costs per op (with their shapes).
    OpContext op_context = GetCurrNode();

    string node_description = GetOpDescription(op_context.op_info);
    op_counts_[node_description] += 1;
    op_costs_[node_description] =
        std::make_pair(node_costs.execution_time.asMicroSeconds().count(),
                       !node_costs.inaccurate);
  }

  // Update node and device states.
  auto& device = device_[node_state.device_name];
  device.nodes_executed.push_back(node);
  // A node is scheduled when the device is available AND all its inputs are
  // ready; hence, time_scheduled is time_ready if time_ready > device current
  // time.
  node_state.time_scheduled =
      std::max(device.GetCurrTime(), node_state.time_ready);
  // Override the device current time with time_scheduled.
  device.device_costs.execution_time = node_state.time_scheduled;
  device.device_costs = CombineCosts(device.device_costs, total_node_costs);
  auto curr_time = device.GetCurrTime();
  node_state.time_finished = curr_time;

  // Update device memory usage.
  if (!IsPersistentNode(node)) {
    for (const auto& port_num_output_pair : node_state.outputs) {
      int port_num = port_num_output_pair.first;
      // There's a chance that a specific output is not used at all.
      if (node_state.outputs[port_num].empty()) {
        node_state.time_no_references[port_num] = curr_time;
      } else {
        device.memory_usage +=
            CalculateOutputSize(node_state.output_properties, port_num) *
            node_state.execution_count;
        device.nodes_in_memory.insert(std::make_pair(node, port_num));
      }
    }
  }

  // Update the device's per-op cost.
  auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
  device_op_cost = CombineCosts(device_op_cost, total_node_costs);

  VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op()
          << ", device: " << node->device()
          << ", execution_count: " << node_state.execution_count
          << ", ready: " << node_state.time_ready.count()
          << ", scheduled: " << node_state.time_scheduled.count()
          << ", finished: " << node_state.time_finished.count();

  // Check outputs, and add ready nodes to the queue.
  AddOutputNodesToReadyQueue(node, curr_time);

  // Increment num_outputs_executed of the input nodes and maybe update memory.
  for (const auto& input_port : node_state.inputs) {
    auto* input = input_port.first;
    auto port = input_port.second;
    auto& input_state = node_map_[input];
    input_state.num_outputs_executed[port]++;
    if (input_state.num_outputs_executed[port] ==
            input_state.outputs[port].size() &&
        !IsPersistentNode(input)) {
      // All the outputs are executed; there are no more references to this
      // output port of the input node.
      input_state.time_no_references[port] = curr_time;
      auto& input_device = device_[input_state.device_name];
      input_device.memory_usage -=
          CalculateOutputSize(input_state.output_properties, port) *
          node_state.execution_count;

      input_device.nodes_in_memory.erase(std::make_pair(input, port));
    }
  }

  if (!IsPersistentNode(node)) {
    // Now that output memory is added and used-up nodes are deallocated,
    // check max memory usage.
    if (device.memory_usage > device.max_memory_usage) {
      device.max_memory_usage = device.memory_usage;

      if (track_mem_usage_snapshot_) {
        device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
      }
    }
  }

  ready_nodes_->RemoveCurrNode();

  return !ready_nodes_->Empty();
}

Costs VirtualScheduler::Summary() const {
  // Overall statement about accuracy.
  VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
          << graph_costs_.num_ops_with_unknown_shapes
          << " having unknown shapes";

  // Print out a basic execution summary.
  VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
  VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count();
  VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count();
  VLOG(1) << "Expected intermediate memory time: "
          << graph_costs_.intermediate_memory_time.count();
  VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
  VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
  VLOG(1) << "Expected max per-op streaming buffers: "
          << graph_costs_.max_per_op_streaming;

  VLOG(1) << "Per-op execution time / compute time / memory time"
          << " / intermediate memory time:";
  for (const auto& op_cost_pair : op_to_cost_) {
    const auto& op = op_cost_pair.first;
    const auto& cost = op_cost_pair.second.execution_time.count();
    const auto& compute_cost = op_cost_pair.second.compute_time.count();
    const auto& memory_cost = op_cost_pair.second.memory_time.count();
    const auto& intermediate_memory_cost =
        op_cost_pair.second.intermediate_memory_time.count();
    const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
    if (cost) {  // Skip printing out zero-cost ops.
      VLOG(1) << strings::Printf(
          " + %30s : %c %10lld / %10lld / %10lld / %10lld", op.c_str(),
          (is_op_cost_accurate ? ' ' : '~'), static_cast<int64>(cost),
          static_cast<int64>(compute_cost), static_cast<int64>(memory_cost),
          static_cast<int64>(intermediate_memory_cost));
    }
  }

  // Print a per-device summary.
  VLOG(1) << "Devices:";
  Costs critical_path_costs = Costs::ZeroCosts();
  std::vector<string> device_names;
  device_names.reserve(device_.size());
  for (auto& it : device_) {
    device_names.push_back(it.first);
  }
  std::sort(device_names.begin(), device_names.end());

  for (const auto& name : device_names) {
    const auto& state = device_.at(name);

    std::map<string, int64> op_to_memory;
    // First profile only persistent memory usage.
    int64 persistent_memory_usage = 0;
    std::set<string> persistent_ops;
    for (const auto& node_port : state.persistent_nodes) {
      const auto* node = node_port.first;
      const auto port = node_port.second;
      const auto output_size =
          CalculateOutputSize(node_map_.at(node).output_properties, port);
      persistent_memory_usage += output_size;
      op_to_memory[node->op()] += output_size;
      persistent_ops.insert(node->op());
    }
    int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
    critical_path_costs.estimated_max_memory_per_device[name] =
        max_memory_usage;

    const Costs::NanoSeconds wall_time_ns = state.GetCurrTime();
    VLOG(1) << "Device = " << name
            << ", num_nodes = " << state.nodes_executed.size()
            << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: "
            << "persistent = "
            << strings::HumanReadableNumBytes(persistent_memory_usage)
            << ", peak = "
            << strings::HumanReadableNumBytes(state.max_memory_usage)
            << ", total = " << strings::HumanReadableNumBytes(max_memory_usage)
            << ", at the end: "
            << strings::HumanReadableNumBytes(state.memory_usage);

    // Overall statement about accuracy.
    VLOG(1) << state.device_costs.num_ops_total
            << " ops processed in total, with "
            << state.device_costs.num_ops_with_unknown_shapes
            << " having unknown shapes";

    VLOG(1) << "Per-op execution time / compute time / memory time "
            << " / intermediate memory time"
            << " (and memory usage at peak memory usage):";

    // Profile non-persistent op memory usage.
    for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
      const auto* node = node_port.first;
      const auto port = node_port.second;
      op_to_memory[node->op()] +=
          CalculateOutputSize(node_map_.at(node).output_properties, port);
    }
    Costs::NanoSeconds total_compute_time_ns;
    bool is_total_cost_accurate = true;
    for (const auto& op_cost_pair : state.op_to_cost) {
      const auto& op = op_cost_pair.first;
      const auto& cost = op_cost_pair.second.execution_time.count();
      const auto& compute_cost = op_cost_pair.second.compute_time.count();
      const auto& memory_cost = op_cost_pair.second.memory_time.count();
      const auto& intermediate_memory_cost =
          op_cost_pair.second.intermediate_memory_time.count();
      total_compute_time_ns += op_cost_pair.second.execution_time;
      const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
      if (!is_op_cost_accurate) {
        is_total_cost_accurate = false;
      }

      int64 op_mem_usage = 0;
      auto it = op_to_memory.find(op);
      if (it != op_to_memory.end()) {
        op_mem_usage = it->second;
      }

      const float mem_usage_percent =
          max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage)
                               : 0.0;
      if (cost || mem_usage_percent > 1.0) {
        // Print out only non-zero-cost ops or ops with > 1% memory usage.
        VLOG(1) << strings::Printf(
                       " + %30s : %c %10lld / %10lld / %10lld / %10lld",
                       op.c_str(), (is_op_cost_accurate ? ' ' : '~'),
                       static_cast<int64>(cost),
                       static_cast<int64>(compute_cost),
                       static_cast<int64>(memory_cost),
                       static_cast<int64>(intermediate_memory_cost))
                << " (" << strings::HumanReadableNumBytes(op_mem_usage) << " ["
                << mem_usage_percent << "%] "
                << (persistent_ops.count(op) > 0 ? ": persistent op)" : ")");
      }
    }

    int utilization = 0;
    if (wall_time_ns.count() > 0) {
      utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count();
    }
    VLOG(1) << "Device = " << name << ", total_compute_time_ns = "
            << (is_total_cost_accurate ? "" : "~")
            << total_compute_time_ns.count()
            << ", utilization = " << utilization << "%";

    if (critical_path_costs.execution_time <= state.GetCurrTime()) {
      critical_path_costs = state.device_costs;
    }
  }

  if (VLOG_IS_ON(2)) {
    // Also log the op descriptions and their corresponding counts.
    VLOG(2) << "Node description, counts, cost:";
    for (const auto& item : op_counts_) {
      int cost;
      bool is_cost_accurate;
      std::tie(cost, is_cost_accurate) = op_costs_.at(item.first);
      VLOG(2) << "Node: " << item.first << ", Count: " << item.second
              << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost
              << " us";
    }
  }

  VLOG(1) << "Critical path execution time: "
          << critical_path_costs.execution_time.count();
  return critical_path_costs;
}

Costs VirtualScheduler::Summary(RunMetadata* metadata) {
  if (metadata) GenerateRunMetadata(metadata);
  return Summary();
}

void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
  // Fill RunMetadata's step_stats and partition_graphs fields.
  StepStats* stepstats = metadata->mutable_step_stats();
  for (const auto& device : device_) {
    GraphDef* device_partition_graph = metadata->add_partition_graphs();
    DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
    device_stepstats->set_device(device.first);
    for (const auto& node_def : device.second.nodes_executed) {
      const NodeState& nodestate = node_map_.at(node_def);
      NodeExecStats* node_stats = device_stepstats->add_node_stats();
      uint64 total_output_size = 0;
      for (int slot = 0; slot < nodestate.output_properties.size(); slot++) {
        const auto& properties = nodestate.output_properties[slot];
        NodeOutput* no = node_stats->add_output();
        no->set_slot(slot);
        TensorDescription* tensor_descr = no->mutable_tensor_description();
        tensor_descr->set_dtype(properties.dtype());
        *tensor_descr->mutable_shape() = properties.shape();
        // Optional allocation description.
        const auto tensor_size =
            CalculateOutputSize(nodestate.output_properties, slot);
        total_output_size += tensor_size;
        tensor_descr->mutable_allocation_description()->set_requested_bytes(
            tensor_size);
        tensor_descr->mutable_allocation_description()->set_allocated_bytes(
            tensor_size);
      }
      node_stats->set_timeline_label(node_def->op());
      node_stats->set_node_name(node_def->name());
      node_stats->set_op_start_rel_micros(0);
      node_stats->set_all_start_micros(
          nodestate.time_scheduled.asMicroSeconds().count());
      node_stats->set_op_end_rel_micros(
          nodestate.time_finished.asMicroSeconds().count() -
          nodestate.time_scheduled.asMicroSeconds().count());
      node_stats->set_all_end_rel_micros(
          nodestate.time_finished.asMicroSeconds().count() -
          nodestate.time_scheduled.asMicroSeconds().count());
      auto* mem_stats = node_stats->mutable_memory_stats();
      // The VirtualScheduler does not specify scratch pad memory usage.
      mem_stats->set_temp_memory_size(0);
      int64 persistent_memory_size = 0;
      if (IsPersistentNode(node_def)) {
        persistent_memory_size = total_output_size;
      }
      mem_stats->set_persistent_memory_size(persistent_memory_size);
      *device_partition_graph->add_node() = *node_def;
    }
  }
}

const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
    const {
  std::unordered_map<string, int64> result;
  for (const auto& device : device_) {
    const string& name = device.first;
    const DeviceState& state = device.second;
    result[name] = state.max_memory_usage;
  }
  return result;
}
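
// Note (inferred from Summary() above, where the reported total is
// persistent_memory_usage + state.max_memory_usage): the per-device peak
// returned by GetPeakMemoryUsage() tracks non-persistent tensors only;
// persistent (e.g., Variable) memory is accounted for separately.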

}  // end namespace grappler
}  // end namespace tensorflow