Home | History | Annotate | Download | only in costs
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
     17 
     18 #include <math.h>
     19 
     20 #include "tensorflow/core/framework/allocation_description.pb.h"
     21 #include "tensorflow/core/framework/attr_value.pb.h"
     22 #include "tensorflow/core/framework/node_def.pb.h"
     23 #include "tensorflow/core/framework/tensor.pb.h"
     24 #include "tensorflow/core/framework/tensor_description.pb.h"
     25 #include "tensorflow/core/framework/tensor_shape.pb.h"
     26 #include "tensorflow/core/grappler/clusters/utils.h"
     27 #include "tensorflow/core/grappler/costs/utils.h"
     28 #include "tensorflow/core/grappler/op_types.h"
     29 #include "tensorflow/core/grappler/utils.h"
     30 #include "tensorflow/core/lib/core/errors.h"
     31 #include "tensorflow/core/lib/strings/numbers.h"
     32 #include "tensorflow/core/lib/strings/str_util.h"
     33 #include "tensorflow/core/platform/logging.h"
     34 #include "tensorflow/core/util/device_name_utils.h"
     35 
     36 namespace tensorflow {
     37 namespace grappler {
     38 namespace {
     39 
     40 Costs CombineCosts(const Costs& left, const Costs& right) {
     41   CHECK_NE(left.max_memory, kMemoryUnknown);
     42   CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);
     43   CHECK_NE(left.max_per_op_streaming, kMemoryUnknown);
     44 
     45   Costs result = left;
     46   result.execution_time += right.execution_time;
     47   if (right.inaccurate) {
     48     result.inaccurate = true;
     49   }
     50   if (right.max_memory != kMemoryUnknown) {
     51     result.max_memory += right.max_memory;
     52   }
     53   if (right.max_per_op_buffers != kMemoryUnknown) {
     54     result.max_per_op_buffers =
     55         std::max(left.max_per_op_buffers, right.max_per_op_buffers);
     56   }
     57   if (right.max_per_op_streaming != kMemoryUnknown) {
     58     result.max_per_op_streaming =
     59         std::max(left.max_per_op_streaming, right.max_per_op_streaming);
     60   }
     61   VLOG(4) << "costs execution_time=" << result.execution_time.count()
     62           << " max_memory=" << result.max_memory
     63           << " max_per_op_buffers=" << result.max_per_op_buffers
     64           << " max_per_op_streaming=" << result.max_per_op_streaming;
     65   return result;
     66 }
     67 
     68 // Key to the cached _Recv ops map, and its hash and predicate structures.
     69 struct RecvNodeDescriptor {
     70   const NodeDef* node;
     71   const int port_num;
     72   const string device;
     73 
     74   RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
     75                      const string& device_)
     76       : node(node_), port_num(port_num_), device(device_) {}
     77 };
     78 
     79 struct RecvNodeDescriptorHash {
     80   std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
     81     return std::hash<const NodeDef*>()(recv_node.node) ^
     82            std::hash<int>()(recv_node.port_num) ^
     83            std::hash<string>()(recv_node.device);
     84   }
     85 };
     86 
     87 struct RecvNodeDescriptorEqual {
     88   bool operator()(const RecvNodeDescriptor& a,
     89                   const RecvNodeDescriptor& b) const {
     90     return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
     91   }
     92 };
     93 }  // namespace
     94 
     95 // ReadyNodeManager
     96 const NodeDef* LIFOManager::GetCurrNode() {
     97   CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
     98   if (curr_pos_ == nodes_.end()) {
     99     curr_pos_ = --(nodes_.rbegin().base());  // Last one in the list.
    100   }
    101   // Once curr_pos_ is set to a valid entry in the list, we keep using the
    102   // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not
    103   // change the GetCurrNode() return value.
    104   return *curr_pos_;
    105 }
    106 
    107 void LIFOManager::RemoveCurrNode() {
    108   // Make sure we have curr_pos_ ready to be removed.
    109   GetCurrNode();
    110   // Note curr_pos_ may not be pointing the last element if some nodes are
    111   // added.
    112   nodes_.erase(curr_pos_);
    113 
    114   curr_pos_ = nodes_.end();  // Reset curr_pos_.
    115 }
    116 
    117 FirstReadyManager::FirstReadyManager() : ReadyNodeManager() {
    118   std::make_heap(nodes_.begin(), nodes_.end());
    119 }
    120 
    121 void FirstReadyManager::Init(
    122     const std::unordered_map<const NodeDef*, NodeState>* node_state) {
    123   // Reset the node state since different instances of the scheduler can reuse
    124   // the same node_manager.
    125   node_state_ = node_state;
    126   nodes_.clear();
    127   waiting_queue_.clear();
    128   greater_ = [this](const NodeDef* a, const NodeDef* b) -> bool {
    129     if (node_state_->at(a).time_ready == node_state_->at(b).time_ready) {
    130       // Use Node name as tie-breaker for deterministic node scheduling.
    131       return a->name().compare(b->name()) > 0;
    132     } else {
    133       // Note: we need a node with minimum time_ready, not
    134       // maximum; hence, using a > b for comparison function.
    135       return node_state_->at(a).time_ready > node_state_->at(b).time_ready;
    136     }
    137   };
    138 }
    139 
    140 const NodeDef* FirstReadyManager::GetCurrNode() {
    141   if (nodes_.empty()) {
    142     // Nothing in the node_; probably, the very first call. Move
    143     // waiting_queue_ to node_.
    144     DrainWaitingQueue();
    145     CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
    146   }
    147   return nodes_.front();
    148 }
    149 
    150 void FirstReadyManager::RemoveCurrNode() {
    151   if (nodes_.empty()) {
    152     // Make sure that there is a node to be removed at the front of nodes_.
    153     GetCurrNode();
    154   }
    155   std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
    156   nodes_.pop_back();
    157   DrainWaitingQueue();
    158 }
    159 
    160 bool FirstReadyManager::Empty() const {
    161   return nodes_.empty() && waiting_queue_.empty();
    162 }
    163 
    164 void FirstReadyManager::DrainWaitingQueue() {
    165   for (const auto* node : waiting_queue_) {
    166     // push_heap in AddNode() and pop_heap in RemoveCurrNode() guarantees that
    167     // the first element is the node with minimum time_ready.
    168     nodes_.push_back(node);
    169     std::push_heap(nodes_.begin(), nodes_.end(), greater_);
    170   }
    171   waiting_queue_.clear();
    172 }
    173 
    174 CompositeNodeManager::CompositeNodeManager()
    175     : ReadyNodeManager(), send_manager_(), recv_manager_() {}
    176 
    177 void CompositeNodeManager::Init(
    178     const std::unordered_map<const NodeDef*, NodeState>* node_state) {
    179   node_state_ = node_state;
    180   send_manager_.Init(node_state);
    181   recv_manager_.Init(node_state);
    182   curr_node_ = nullptr;
    183 }
    184 
    185 void CompositeNodeManager::AddNode(const NodeDef* node) {
    186   if (IsSend(*node)) {
    187     send_manager_.AddNode(node);
    188   } else if (IsRecv(*node)) {
    189     recv_manager_.AddNode(node);
    190   } else {
    191     const auto& device = node_state_->at(node).device_name;
    192     ops_lifo_map_[device].AddNode(node);
    193   }
    194 }
    195 
    196 const NodeDef* CompositeNodeManager::GetCurrNode() {
    197   if (curr_node_) return curr_node_;
    198 
    199   // Per-device LIFO for normal ops (not _Send / _Recv),
    200   // FirstReady for _Send and _Recv (separately),
    201   // Globally (among the LIFO-selected ops from each device and _Send and
    202   // _Recv) FirstReady,
    203   // Priorty order: _Send, _Recv, and then the rest, if time_ready is equal.
    204   std::vector<std::pair<const NodeDef*, Costs::Duration>> candidates;
    205   for (auto& ops_lifo : ops_lifo_map_) {
    206     if (!ops_lifo.second.Empty()) {
    207       const auto* op = ops_lifo.second.GetCurrNode();
    208       candidates.emplace_back(op, node_state_->at(op).time_ready);
    209     }
    210   }
    211   if (!send_manager_.Empty()) {
    212     const auto* send = send_manager_.GetCurrNode();
    213     candidates.emplace_back(send, node_state_->at(send).time_ready);
    214   }
    215   if (!recv_manager_.Empty()) {
    216     const auto* recv = recv_manager_.GetCurrNode();
    217     candidates.emplace_back(recv, node_state_->at(recv).time_ready);
    218   }
    219   CHECK(!candidates.empty());
    220   auto first_ready = std::min_element(
    221       candidates.begin(), candidates.end(),
    222       [](const std::pair<const NodeDef*, Costs::Duration>& a,
    223          const std::pair<const NodeDef*, Costs::Duration>& b) {
    224         if (a.second == b.second) {
    225           // Note that there can be only 1 Send and only 1 Recv in candidates,
    226           // at most; hence, score is 2 for Send, 1 for Recv, and 0 for a
    227           // normap op, and a_score and b_score are equal only if both are
    228           // normal ops.
    229           int a_score = 2 * IsSend(*a.first) + IsRecv(*a.first);
    230           int b_score = 2 * IsSend(*b.first) + IsRecv(*b.first);
    231           if (a_score == b_score) {
    232             // Both are normal ops; use node name as tie breaker.
    233             return a.first->name().compare(b.first->name()) < 0;
    234           } else {
    235             // Priortize by op type: _Send, _Recv, and normap ops.
    236             return a_score > b_score;
    237           }
    238         } else {
    239           return a.second < b.second;
    240         }
    241       });
    242   // Next time we call GetCurrNode(), it just returns the cached one,
    243   // curr_node_ until we call RemovCurrNode().
    244   curr_node_ = first_ready->first;
    245 
    246   return curr_node_;
    247 }
    248 
    249 void CompositeNodeManager::RemoveCurrNode() {
    250   const auto* node = GetCurrNode();
    251   if (IsSend(*node)) {
    252     send_manager_.RemoveCurrNode();
    253   } else if (IsRecv(*node)) {
    254     recv_manager_.RemoveCurrNode();
    255   } else {
    256     const auto device = node_state_->at(node).device_name;
    257     ops_lifo_map_[device].RemoveCurrNode();
    258   }
    259   // Reset curr_node_ so that GetCurrNode() finds another node.
    260   curr_node_ = nullptr;
    261 }
    262 
    263 bool CompositeNodeManager::Empty() const {
    264   // Empty if all the ready managers are empty.
    265   bool empty = true;
    266   for (const auto& ops_lifo : ops_lifo_map_) {
    267     empty &= ops_lifo.second.Empty();
    268   }
    269   return empty && send_manager_.Empty() && recv_manager_.Empty();
    270 }
    271 
    272 VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
    273                                    const bool use_static_shapes,
    274                                    Cluster* cluster,
    275                                    ReadyNodeManager* ready_nodes)
    276     : ready_nodes_(ready_nodes),
    277       graph_costs_(Costs::ZeroCosts()),
    278       graph_properties_(*grappler_item),
    279       cluster_(cluster),
    280       grappler_item_(grappler_item),
    281       use_static_shapes_(use_static_shapes),
    282       placer_(cluster) {
    283   initialized_ = false;
    284 }
    285 
    286 ReadyNodeManager* VirtualScheduler::ReadyNodeManagerFactory(
    287     const string& ready_node_manager) {
    288   if (ready_node_manager == "FIFO") {
    289     return new FIFOManager();
    290   } else if (ready_node_manager == "LIFO") {
    291     return new LIFOManager();
    292   } else if (ready_node_manager == "FirstReady") {
    293     return new FirstReadyManager();
    294   } else if (ready_node_manager == "Composite") {
    295     return new CompositeNodeManager();
    296   }
    297   LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager;
    298 }
    299 
    300 Status VirtualScheduler::Init() {
    301   // Init() preprocesses the input grappler_item and graph_properties to extract
    302   // necessary information for emulating tensorflow op scheduling and
    303   // construct internal data structures (NodeState and DeviceState) for virtual
    304   // scheduling.
    305   ready_nodes_->Init(GetNodeStates());
    306   // Construct graph properties.
    307   Status status;
    308   if (use_static_shapes_) {
    309     status = graph_properties_.InferStatically(true);
    310   } else {
    311     status = graph_properties_.InferDynamically(cluster_);
    312   }
    313   if (!status.ok()) {
    314     return status;
    315   }
    316 
    317   const auto& graph = grappler_item_->graph;
    318   const auto& fetch_nodes = grappler_item_->fetch;
    319   std::set<string> feed_nodes;
    320   for (const auto& f : grappler_item_->feed) {
    321     auto iter_and_inserted_flag = feed_nodes.insert(f.first);
    322     QCHECK(iter_and_inserted_flag.second)
    323         << "Duplicate feed node found: " << f.first;
    324   }
    325 
    326   // Get the nodes that would run to output fetch_nodes.
    327   bool ill_formed = false;
    328   std::vector<const NodeDef*> nodes =
    329       ComputeTransitiveFanin(graph, fetch_nodes, &ill_formed);
    330   if (ill_formed) {
    331     return errors::InvalidArgument(
    332         "Ill formed graph or invalid set of fetch nodes specified");
    333   }
    334 
    335   // TODO(dyoon): this is a bit inefficient as name_to_node is already built in
    336   // ComputeTransitiveFanin().
    337   // Once ComputeTransitiveFanin is complete, only the nodes that can be reached
    338   // from the fetch nodes are scheduled. So the scheduled nodes should be
    339   // exactly the same as those executed for real. One possible discrepancy could
    340   // be the control flow nodes, where tf only executes one path.
    341   std::unordered_map<string, const NodeDef*> name_to_node;
    342   for (const auto& node : nodes) {
    343     name_to_node[node->name()] = node;
    344   }
    345 
    346   // TODO(dyoon): Instead of identifying _Send node here manually, add _Send
    347   // to _Recv as control dependency when creating GrapplerItem.
    348   std::unordered_map<string, const NodeDef*> name_to_send;
    349   for (const auto& node : graph.node()) {
    350     if (IsSend(node)) {
    351       const auto& attr = node.attr();
    352       name_to_send[attr.at("tensor_name").s()] = &node;
    353     }
    354   }
    355 
    356   // To reuse _Recv ops.
    357   std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescriptorHash,
    358                      RecvNodeDescriptorEqual>
    359       cached_recv_nodes;
    360 
    361   // Build node_map; for each node, create its NodeState and connect its inputs
    362   // and outputs.
    363   for (const auto* curr_node : nodes) {
    364     auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
    365     const string curr_node_device = DeviceName(curr_node);
    366     std::vector<string> inputs;
    367     if (IsRecv(*curr_node)) {
    368       const auto& attr = curr_node->attr();
    369       const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
    370       inputs = {send->name()};
    371     } else {
    372       for (const string& input : curr_node->input()) {
    373         inputs.push_back(input);
    374       }
    375     }
    376     for (const string& input_node_name : inputs) {
    377       // Note that input_node_name may be in <prefix><node_name>:<port_num>
    378       // format, where <prefix> (e.g., "^" for control dependency) and
    379       // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
    380       const NodeDef* input_node = name_to_node[NodeName(input_node_name)];
    381 
    382       CHECK(input_node);
    383       const string in_device = DeviceName(input_node);
    384       const auto input_node_port_num = NodePosition(input_node_name);
    385 
    386       if (curr_node_device == in_device) {
    387         // Same device: connect input_node and curr_node directly.
    388         curr_node_state.inputs.push_back(
    389             std::make_pair(input_node, input_node_port_num));
    390         auto& input_node_state = GetNodeStateOrCreateIt(input_node);
    391         input_node_state.outputs[input_node_port_num].push_back(curr_node);
    392       } else {
    393         RecvNodeDescriptor recv_node(input_node, input_node_port_num,
    394                                      curr_node_device);
    395         auto it = cached_recv_nodes.find(recv_node);
    396         if (it != cached_recv_nodes.end()) {
    397           // Different device, but found an already-cached copy (a _Recv op);
    398           // connect the _Recv to curr_node.
    399           const NodeDef* recv_op = it->second;
    400           // recv_op's output port is hard-coded to zero.
    401           curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
    402           auto& input_node_state = node_map_.at(recv_op);
    403           input_node_state.outputs[0].push_back(curr_node);
    404         } else {
    405           // Different device, no cached copy; transfer input_node to the
    406           // curr_node's device.
    407           auto send_and_recv =
    408               CreateSendRecv(input_node, curr_node, input_node_name);
    409           // Note that CreateSendRecv() already connected input/output between
    410           // _Send and _Recv ops.
    411           const auto* send = send_and_recv.first;
    412           const auto* recv = send_and_recv.second;
    413           // recv_op's output port is hard-coded to zero.
    414           curr_node_state.inputs.push_back(std::make_pair(recv, 0));
    415           auto& input_node_state = GetNodeStateOrCreateIt(input_node);
    416           input_node_state.outputs[input_node_port_num].push_back(send);
    417 
    418           // Cache the _Recv op for future use.
    419           cached_recv_nodes[recv_node] = recv;
    420         }
    421       }
    422     }
    423 
    424     // Special case: given feed nodes are ready at time 0.
    425     const bool given_as_feed =
    426         feed_nodes.find(curr_node->name()) != feed_nodes.end();
    427 
    428     // Default case: node without inputs are ready at time 0.
    429     const bool has_no_inputs = curr_node->input().empty();
    430 
    431     if (!IsRecv(*curr_node) && (given_as_feed || has_no_inputs)) {
    432       curr_node_state.time_ready = Costs::Duration();
    433       ready_nodes_->AddNode(curr_node);
    434       VLOG(3) << "Added ready node: " << curr_node->name();
    435     }
    436 
    437     feed_nodes.erase(curr_node->name());
    438 
    439     if (IsPersistentNode(curr_node)) {
    440       auto& device_state = device_[curr_node_device];
    441       for (int port_num = 0;
    442            port_num < curr_node_state.output_properties.size(); ++port_num) {
    443         device_state.persistent_nodes.insert(
    444             std::make_pair(curr_node, port_num));
    445       }
    446     }
    447   }
    448 
    449   if (ready_nodes_->Empty()) {
    450     return errors::InvalidArgument("No ready nodes in the graph.");
    451   }
    452 
    453   if (!feed_nodes.empty()) {
    454     return errors::InvalidArgument(
    455         strings::StrCat("Some feed nodes were not found in the graph: ",
    456                         str_util::Join(feed_nodes, ",")));
    457   }
    458   initialized_ = true;
    459   return Status::OK();
    460 }
    461 
    462 void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
    463   CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
    464   // This method is called when NodeState is created and adds input and output
    465   // properties for a few exceptional cases that GraphProperties cannot provide
    466   // input/output properties.
    467   if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
    468     // _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc
    469     // attr; normal _Send and _Recv ops (from the input graph) do not have that
    470     // attr.
    471     auto& node_state = node_map_[node];
    472     auto& inputs = node_state.input_properties;
    473     auto& outputs = node_state.output_properties;
    474 
    475     // _Send and _Recv ops are created from VirtualScheduler, so
    476     // there should be no inputs TensorProperties.
    477     CHECK(inputs.empty());
    478     CHECK(outputs.empty());
    479     const auto& attr = node->attr();
    480     // This is the original input source to the _Send and _Recv, and this
    481     // string includes "^" if it was control dependency, and output port
    482     /// (e.g., ":2") if the input source had multiple outputs.
    483     const auto& input_source_name = attr.at(kAttrInputSrc).s();
    484     if (IsControlInput(input_source_name)) {
    485       // Control dependency; regardless of the input source tensor size,
    486       // send 4B.
    487       OpInfo::TensorProperties control_message;
    488       control_message.set_dtype(DT_FLOAT);
    489       control_message.mutable_shape()->add_dim()->set_size(1);
    490       auto* value = control_message.mutable_value();
    491       value->add_float_val(1);
    492       inputs.push_back(control_message);
    493       outputs.push_back(control_message);
    494     } else {
    495       auto output_properties =
    496           graph_properties_.GetOutputProperties(NodeName(input_source_name));
    497       // Like with HasInputProperties, if a node does not have output
    498       // properties, it's likely it was pruned during the shape inference run.
    499       if (!output_properties.empty()) {
    500         const auto input_node_port_num = NodePosition(input_source_name);
    501         // Use the input source's output property as _Send and _Recv's input
    502         // property.
    503         CHECK_GT(output_properties.size(), input_node_port_num);
    504         inputs.push_back(output_properties[input_node_port_num]);
    505         outputs.push_back(output_properties[input_node_port_num]);
    506       }
    507     }
    508   }
    509 }
    510 
    511 float VirtualScheduler::Round2(const float x) const {
    512   // Not using std::round from <cmath> here because not all platforms seem to
    513   // support that (specifically Android).
    514   return ::round(100.0 * x) / 100.0;
    515 }
    516 
    517 bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
    518   // Variables are persistent nodes.
    519   return IsVariable(*node);
    520 }
    521 
    522 string VirtualScheduler::DeviceName(const NodeDef* node) const {
    523   return placer_.get_canonical_device_name(*node);
    524 }
    525 
    526 string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
    527   // Replace the ":" characters that may be present in the device name with "_".
    528   // This makes it possible to then use the resulting string in a node name.
    529   return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":",
    530                                  "_", true);
    531 }
    532 
    533 string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
    534                                            const NodeDef* to) const {
    535   CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
    536   return kChannelDevice + "_from_" + SanitizedDeviceName(from) + "_to_" +
    537          SanitizedDeviceName(to);
    538 }
    539 
    540 std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
    541     const NodeDef* from, const NodeDef* to, const string& input_name) {
    542   CHECK(!initialized_) << "CreateSendRecv is called after Init().";
    543 
    544   // Connect "from" node to "to" node with _Send and _Recv such that
    545   // from -> _Send -> _Recv -> to.
    546   // _Send is placed on "Channel" device, and _Recv is on the same device
    547   // as "to" node.
    548   // input_node_name is the string from the "to" node to identify which output
    549   // we get from the "from" node.
    550 
    551   // Note that we use NodeState for scheduling, so _Send and _Recv
    552   // NodeDefs created here need not be correct: in terms of name,
    553   // input names, attrs, etc.
    554 
    555   auto input_node_port_num = NodePosition(input_name);
    556   string src_name;
    557   if (input_node_port_num >= 0) {
    558     src_name = strings::StrCat(from->name(), "_", input_node_port_num);
    559   } else {
    560     src_name = strings::StrCat(from->name(), "_minus1");
    561   }
    562 
    563   // _Send op.
    564   auto* send = new NodeDef();
    565   send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) +
    566                  "_to_" + SanitizedDeviceName(to));
    567   send->set_op("_Send");
    568   send->add_input(from->name());
    569   send->set_device(ChannelDeviceName(from, to));
    570   auto& send_attr = *(send->mutable_attr());
    571   send_attr[kAttrInputSrc].set_s(input_name);
    572   send_attr[kAttrSrcDevice].set_s(DeviceName(from));
    573   send_attr[kAttrDstDevice].set_s(DeviceName(to));
    574 
    575   // _Recv op.
    576   auto* recv = new NodeDef();
    577   recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to));
    578   recv->set_op("_Recv");
    579   recv->add_input(send->name());
    580   recv->set_device(DeviceName(to));
    581   auto& recv_attr = *(recv->mutable_attr());
    582   recv_attr[kAttrInputSrc].set_s(input_name);
    583 
    584   // NodeState for _Send op.
    585   auto& send_node_state = GetNodeStateOrCreateIt(send);
    586   send_node_state.device_name = send->device();  // Set Channel device.
    587   send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
    588   send_node_state.outputs[0].push_back(recv);
    589 
    590   // NodeState for _Recv op.
    591   auto& recv_node_state = GetNodeStateOrCreateIt(recv);
    592   recv_node_state.inputs.push_back(std::make_pair(send, 0));
    593   recv_node_state.outputs[0].push_back(to);
    594 
    595   // Keep the created nodes.
    596   additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
    597   additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(recv));
    598 
    599   // Return _Send and _Recv.
    600   return std::make_pair(send, recv);
    601 }
    602 
    603 OpContext VirtualScheduler::GetCurrNode() const {
    604   const NodeDef* node = ready_nodes_->GetCurrNode();
    605 
    606   // Get the device from the placer.
    607   DeviceProperties device;
    608   device = placer_.get_device(*node);
    609 
    610   // Special case for _Send op.
    611   if (IsSend(*node)) {
    612     device.set_type(kChannelDevice);
    613   }
    614 
    615   // Construct OpContext.
    616   OpContext op_context;
    617   const auto& node_state = node_map_.at(node);
    618   op_context.name = node->name();
    619   op_context.device_name = node_state.device_name;
    620   auto& op_info = op_context.op_info;
    621   op_info.set_op(node->op());
    622   *op_info.mutable_attr() = node->attr();
    623   for (auto& input : node_state.input_properties) {
    624     *op_info.add_inputs() = input;
    625   }
    626   for (auto& output : node_state.output_properties) {
    627     *op_info.add_outputs() = output;
    628   }
    629   op_info.mutable_device()->Swap(&device);
    630 
    631   if (grappler_item_->graph.has_library()) {
    632     op_context.function_library = &grappler_item_->graph.library();
    633   }
    634   return op_context;
    635 }
    636 
    637 NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
    638   CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
    639 
    640   auto it = node_map_.find(node);
    641   if (it == node_map_.end()) {
    642     // Not found; create a NodeState for this node.
    643     it = node_map_.emplace(node, NodeState()).first;
    644     auto& node_state = it->second;
    645     node_state.input_properties =
    646         graph_properties_.GetInputProperties(node->name());
    647     node_state.output_properties =
    648         graph_properties_.GetOutputProperties(node->name());
    649 
    650     // Some ops may need further processing to the input / output properties:
    651     // _Send and _Recv.
    652     MaybeUpdateInputOutput(node);
    653 
    654     if (!IsSend(*node)) {
    655       node_state.device_name = DeviceName(node);
    656       // For _Send op, device_name will be set to Channel in CreateSendRecv().
    657     }
    658 
    659     // Initialize output port related data:
    660     // Assume the size of OutputProperties represents the number of output ports
    661     // of this node.
    662     for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
    663       node_state.time_no_references[i] = Costs::Duration::max();
    664       node_state.num_outputs_executed[i] = 0;
    665       // Populate an empty vector for each port. The caller will add nodes
    666       // that use this port as input.
    667       node_state.outputs[i] = {};
    668     }
    669     // Port_num -1 is for control dependency.
    670     node_state.time_no_references[-1] = Costs::Duration::max();
    671     node_state.num_outputs_executed[-1] = 0;
    672     node_state.outputs[-1] = {};
    673   }
    674   return it->second;
    675 }
    676 
    677 int64 VirtualScheduler::CalculateOutputSize(
    678     const std::vector<OpInfo::TensorProperties>& output_properties,
    679     const int port_num) const {
    680   if (port_num < 0) {
    681     return 4;  // 4B for control dependency.
    682   }
    683 
    684   if (port_num >= output_properties.size()) {
    685     VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
    686             << "port_num: " << port_num
    687             << " >= output_properties.size(): " << output_properties.size();
    688     return 0;
    689   }
    690 
    691   const auto& output = output_properties[port_num];
    692   int64 output_size = DataTypeSize(BaseType(output.dtype()));
    693 
    694   for (const auto& dim : output.shape().dim()) {
    695     auto dim_size = dim.size();
    696     if (dim_size < 0) {
    697       // Zero output size if there's any unknown dim.
    698       output_size = 0;
    699       VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
    700               << "unknown dim: " << output_size;
    701       break;
    702     }
    703     output_size *= dim_size;
    704   }
    705 
    706   return output_size;
    707 }
    708 
    709 Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
    710                                           std::map<string, Costs>* op_cost) {
    711   auto it = op_cost->find(op_name);
    712   if (it == op_cost->end()) {
    713     // Note that default constructor of Costs sets some memory related fields
    714     // to unknown values so we should explicitly initialize it with ZeroCosts.
    715     it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
    716   }
    717   return it->second;
    718 }
    719 
    720 bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
    721   // Update graph_costs_ and per-op costs.
    722   graph_costs_ = CombineCosts(graph_costs_, node_costs);
    723   const NodeDef* node = ready_nodes_->GetCurrNode();
    724   const string& op_name = node->op();
    725 
    726   // Also keep track of op counts and times per op (with their shapes).
    727   OpContext op_context = GetCurrNode();
    728   string node_description = GetOpDescription(op_context.op_info);
    729   op_counts_[node_description] += 1;
    730   op_costs_[node_description] =
    731       std::make_pair(node_costs.execution_time.asMicroSeconds().count(),
    732                      !node_costs.inaccurate);
    733 
    734   auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
    735   op_cost = CombineCosts(op_cost, node_costs);
    736 
    737   // Update node and device states.
    738   auto& node_state = node_map_[node];
    739   auto& device = device_[node_state.device_name];
    740   device.nodes_executed.push_back(node);
    741   // Node is scheduled when the device is available AND all the inputs are
    742   // ready; hence, time_scheduled is time_ready if time_ready > device curr
    743   // time.
    744   node_state.time_scheduled =
    745       std::max(device.GetCurrTime(), node_state.time_ready);
    746   // Override device curr time with the time_scheduled.
    747   device.device_costs.execution_time = node_state.time_scheduled;
    748   device.device_costs = CombineCosts(device.device_costs, node_costs);
    749   auto curr_time = device.GetCurrTime();
    750   node_state.time_finished = curr_time;
    751 
    752   // Update device memory usage.
    753   if (!IsPersistentNode(node)) {
    754     for (const auto& port_num_output_pair : node_state.outputs) {
    755       int port_num = port_num_output_pair.first;
    756       // There's a chance that a specific output is not used at all.
    757       if (node_state.outputs[port_num].empty()) {
    758         node_state.time_no_references[port_num] = curr_time;
    759       } else {
    760         device.memory_usage +=
    761             CalculateOutputSize(node_state.output_properties, port_num);
    762         device.nodes_in_memory.insert(std::make_pair(node, port_num));
    763       }
    764     }
    765   }
    766 
    767   // Update device's per-op cost.
    768   auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
    769   device_op_cost = CombineCosts(device_op_cost, node_costs);
    770 
    771   VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op()
    772           << ", device: " << node->device()
    773           << ", ready: " << node_state.time_ready.count()
    774           << ", scheduled: " << node_state.time_scheduled.count()
    775           << ", finished: " << node_state.time_finished.count();
    776 
    777   // Increment num_inputs_ready of the output nodes
    778   for (const auto& port_num_output_pair : node_state.outputs) {
    779     for (auto* output_node : port_num_output_pair.second) {
    780       auto& output_state = node_map_[output_node];
    781       output_state.num_inputs_ready++;
    782       // Execute a node as soon as all its inputs are ready. Merge nodes are
    783       // special since they run as soon as one of their inputs becomes
    784       // available.
    785       if (output_state.num_inputs_ready == output_state.inputs.size() ||
    786           IsMerge(*output_node)) {
    787         // This output node is now ready.
    788         output_state.time_ready = curr_time;
    789         ready_nodes_->AddNode(output_node);
    790       }
    791     }
    792   }
    793 
    794   // Increment num_outputs_executed of the input nodes.
    795   for (const auto& input_port : node_state.inputs) {
    796     auto* input = input_port.first;
    797     auto port = input_port.second;
    798     auto& input_state = node_map_[input];
    799     input_state.num_outputs_executed[port]++;
    800     if (input_state.num_outputs_executed[port] ==
    801             input_state.outputs[port].size() &&
    802         !IsPersistentNode(input)) {
    803       // All the outputs are executed; no reference to this output port of
    804       // input node.
    805       input_state.time_no_references[port] = curr_time;
    806       auto& input_device = device_[input_state.device_name];
    807       input_device.memory_usage -=
    808           CalculateOutputSize(input_state.output_properties, port);
    809 
    810       input_device.nodes_in_memory.erase(std::make_pair(input, port));
    811     }
    812   }
    813 
    814   if (!IsPersistentNode(node)) {
    815     // Now that output memory is added and used up nodes are deallocated,
    816     // check max memory usage.
    817     if (device.memory_usage > device.max_memory_usage) {
    818       device.max_memory_usage = device.memory_usage;
    819       device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
    820     }
    821   }
    822 
    823   // Remove the current node; assume FIFO.
    824   ready_nodes_->RemoveCurrNode();
    825 
    826   return !ready_nodes_->Empty();
    827 }
    828 
    829 Costs VirtualScheduler::Summary() const {
    830   // Print out basic execution summary.
    831   VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
    832   VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
    833   VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
    834   VLOG(1) << "Expected max per-op streaming buffers: "
    835           << graph_costs_.max_per_op_streaming;
    836 
    837   VLOG(1) << "Per-op execution time:";
    838   for (const auto& op_cost_pair : op_to_cost_) {
    839     const auto& op = op_cost_pair.first;
    840     const auto& cost = op_cost_pair.second.execution_time.count();
    841     const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
    842     if (cost) {  // Skip printing out zero-cost ops.
    843       VLOG(1) << " + " << op << " : " << (is_op_cost_accurate ? "" : "~")
    844               << cost;
    845     }
    846   }
    847 
    848   // Print per device summary
    849   VLOG(1) << "Devices:";
    850   Costs critical_path_costs = Costs::ZeroCosts();
    851 
    852   for (const auto& device : device_) {
    853     const auto& name = device.first;
    854     const auto& state = device.second;
    855 
    856     std::map<string, int64> op_to_memory;
    857     // First profile only persistent memory usage.
    858     int64 persistent_memory_usage = 0;
    859     std::set<string> persisent_ops;
    860     for (const auto& node_port : state.persistent_nodes) {
    861       const auto* node = node_port.first;
    862       const auto port = node_port.second;
    863       const auto output_size =
    864           CalculateOutputSize(node_map_.at(node).output_properties, port);
    865       persistent_memory_usage += output_size;
    866       op_to_memory[node->op()] += output_size;
    867       persisent_ops.insert(node->op());
    868     }
    869     int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
    870     critical_path_costs.estimated_max_memory_per_device[name] =
    871         max_memory_usage;
    872 
    873     const Costs::NanoSeconds wall_time_ns = state.GetCurrTime();
    874     VLOG(1) << "Device = " << name
    875             << ", num_nodes = " << state.nodes_executed.size()
    876             << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: "
    877             << "persistent = "
    878             << strings::HumanReadableNumBytes(persistent_memory_usage)
    879             << ", peak = "
    880             << strings::HumanReadableNumBytes(state.max_memory_usage)
    881             << ", total = " << strings::HumanReadableNumBytes(max_memory_usage)
    882             << ", at the end: "
    883             << strings::HumanReadableNumBytes(state.memory_usage);
    884 
    885     VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):";
    886 
    887     // Profile non-persistent op memory usage.
    888     for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
    889       const auto* node = node_port.first;
    890       const auto port = node_port.second;
    891       op_to_memory[node->op()] +=
    892           CalculateOutputSize(node_map_.at(node).output_properties, port);
    893     }
    894     Costs::NanoSeconds total_compute_time_ns;
    895     bool is_total_cost_accurate = true;
    896     for (const auto& op_cost_pair : state.op_to_cost) {
    897       const auto& op = op_cost_pair.first;
    898       const auto& cost = op_cost_pair.second.execution_time.count();
    899       total_compute_time_ns += op_cost_pair.second.execution_time;
    900       const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
    901       if (!is_op_cost_accurate) {
    902         is_total_cost_accurate = false;
    903       }
    904 
    905       int64 op_mem_usage = 0;
    906       auto it = op_to_memory.find(op);
    907       if (it != op_to_memory.end()) {
    908         op_mem_usage = it->second;
    909       }
    910 
    911       const float mem_usage_percent =
    912           max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage)
    913                                : 0.0;
    914       if (cost || mem_usage_percent > 1.0) {
    915         // Print out only non-zero cost ops or ops with > 1% memory usage.
    916         VLOG(1) << " + " << op << " : " << (is_op_cost_accurate ? "" : "~")
    917                 << cost << " (" << strings::HumanReadableNumBytes(op_mem_usage)
    918                 << " [" << mem_usage_percent << "%] "
    919                 << (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");
    920       }
    921     }
    922 
    923     int utilization = 0;
    924     if (wall_time_ns.count() > 0) {
    925       utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count();
    926     }
    927     VLOG(1) << "Device = " << name << ", total_compute_time_ns = "
    928             << (is_total_cost_accurate ? "" : "~")
    929             << total_compute_time_ns.count()
    930             << ", utilization = " << utilization << "%";
    931 
    932     if (critical_path_costs.execution_time <= state.GetCurrTime()) {
    933       critical_path_costs = state.device_costs;
    934     }
    935   }
    936 
    937   if (VLOG_IS_ON(2)) {
    938     // Also log the op description and their corresponding counts.
    939     VLOG(2) << "Node description, counts, cost:";
    940     for (const auto& item : op_counts_) {
    941       int cost;
    942       bool is_cost_accurate;
    943       std::tie(cost, is_cost_accurate) = op_costs_.at(item.first);
    944       VLOG(2) << "Node: " << item.first << ", Count: " << item.second
    945               << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost;
    946     }
    947   }
    948 
    949   VLOG(1) << "Critical path execution time: "
    950           << critical_path_costs.execution_time.count();
    951   return critical_path_costs;
    952 }
    953 
    954 Costs VirtualScheduler::Summary(RunMetadata* metadata) {
    955   if (metadata != nullptr) {
    956     StepStats* stepstats = metadata->mutable_step_stats();
    957     for (const auto& device : device_) {
    958       GraphDef* device_partition_graph = metadata->add_partition_graphs();
    959       DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
    960       device_stepstats->set_device(device.first);
    961       for (const auto& node_def : device.second.nodes_executed) {
    962         const NodeState& nodestate = node_map_.at(node_def);
    963         NodeExecStats* node_stats = device_stepstats->add_node_stats();
    964         uint64 total_output_size = 0;
    965         for (int slot = 0; slot < nodestate.output_properties.size(); slot++) {
    966           const auto& properties = nodestate.output_properties[slot];
    967           NodeOutput* no = node_stats->add_output();
    968           no->set_slot(slot);
    969           TensorDescription* tensor_descr = no->mutable_tensor_description();
    970           tensor_descr->set_dtype(properties.dtype());
    971           *tensor_descr->mutable_shape() = properties.shape();
    972           // Optional allocation description.
    973           const auto tensor_size =
    974               CalculateOutputSize(nodestate.output_properties, slot);
    975           total_output_size += tensor_size;
    976           tensor_descr->mutable_allocation_description()->set_requested_bytes(
    977               tensor_size);
    978           tensor_descr->mutable_allocation_description()->set_allocated_bytes(
    979               tensor_size);
    980         }
    981         node_stats->set_timeline_label(node_def->op());
    982         node_stats->set_node_name(node_def->name());
    983         node_stats->set_op_start_rel_micros(0);
    984         node_stats->set_all_start_micros(
    985             nodestate.time_scheduled.asMicroSeconds().count());
    986         node_stats->set_op_end_rel_micros(
    987             nodestate.time_finished.asMicroSeconds().count() -
    988             nodestate.time_scheduled.asMicroSeconds().count());
    989         node_stats->set_all_end_rel_micros(
    990             nodestate.time_finished.asMicroSeconds().count() -
    991             nodestate.time_scheduled.asMicroSeconds().count());
    992         auto* mem_stats = node_stats->mutable_memory_stats();
    993         // VirtualScheduler does not specify scratch pad memory usage.
    994         mem_stats->set_temp_memory_size(0);
    995         int64 persistent_memory_size = 0;
    996         if (IsPersistentNode(node_def)) {
    997           persistent_memory_size = total_output_size;
    998         }
    999         mem_stats->set_persistent_memory_size(persistent_memory_size);
   1000         *device_partition_graph->add_node() = *node_def;
   1001       }
   1002     }
   1003   }
   1004   return Summary();
   1005 }
   1006 
   1007 const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
   1008     const {
   1009   std::unordered_map<string, int64> result;
   1010   for (const auto& device : device_) {
   1011     const string& name = device.first;
   1012     const DeviceState& state = device.second;
   1013     result[name] = state.max_memory_usage;
   1014   }
   1015   return result;
   1016 }
   1017 
   1018 }  // end namespace grappler
   1019 }  // end namespace tensorflow
   1020