Home | History | Annotate | Download | only in costs
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
     17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
     18 
     19 #include <list>
     20 #include <memory>
     21 #include <unordered_map>
     22 #include <unordered_set>
     23 
     24 #include "tensorflow/core/framework/node_def.pb.h"
     25 #include "tensorflow/core/framework/step_stats.pb.h"
     26 #include "tensorflow/core/grappler/costs/cost_estimator.h"
     27 #include "tensorflow/core/grappler/costs/graph_properties.h"
     28 #include "tensorflow/core/grappler/costs/op_context.h"
     29 #include "tensorflow/core/grappler/costs/virtual_placer.h"
     30 #include "tensorflow/core/grappler/grappler_item.h"
     31 
     32 namespace tensorflow {
     33 namespace grappler {
     34 
     35 struct NodeState {
     36   // A node (i.e., an op) takes a set of input:port pairs and produces
     37   // a set of output ports.
     38 
     39   // Cross references to input and output nodes from graphdef.
     40   std::vector<std::pair<const NodeDef*, int>> inputs;  // Input, port pairs.
     41   // List of output nodes (a list of nodes that takes this output port as input)
     42   // keyed by port_num. Note that port_num -1 is used for control dependency.
     43   std::unordered_map<int, std::vector<const NodeDef*>> outputs;
     44 
     45   // Info from GraphProperties.
     46   std::vector<OpInfo::TensorProperties> input_properties;
     47   std::vector<OpInfo::TensorProperties> output_properties;
     48 
     49   // Canonical device name used within VirtualScheduler.
     50   string device_name;
     51 
     52   // States updated as scheduling nodes.
     53   int num_inputs_ready;
     54   std::unordered_map<int, int> num_outputs_executed;
     55   Costs::Duration time_ready;
     56   Costs::Duration time_scheduled;
     57   Costs::Duration time_finished;
     58   // Time that all the consumers are executed (hence, no need to keep this
     59   // output in memory), keyed by port_num.
     60   std::unordered_map<int, Costs::Duration> time_no_references;
     61 
     62   // Note that a node may have multiple output ports. The length of outputs,
     63   // num_outputs_executed, and time_no_references should be
     64   // identical when a NodeState is fully initialized.
     65   // They should be 1 + output_properties.size() as we add [-1] for control
     66   // dependency.
     67 
     68   // Node will be ready to be executed at time_ready, scheduled at
     69   // time_scheduled, and finishes execution at time_finished.
     70   // Each output port uses up memory space from time_scheduled to its
     71   // time_no_references.
     72 
     73   NodeState() {
     74     num_inputs_ready = 0;
     75     time_ready = Costs::Duration::max();
     76     time_scheduled = Costs::Duration::max();
     77     time_finished = Costs::Duration::max();
     78     // Note that num_outputs_executed and time_no_references are not initialized
     79     // here, since we don't know the size (i.e., # outputs for this node).
     80   }
     81 };
     82 
     83 struct DeviceState {
     84   // Nodes executed on this device in execution order.
     85   std::vector<const NodeDef*> nodes_executed;
     86 
     87   struct NodePairHash {
     88    public:
     89     const std::size_t operator()(
     90         const std::pair<const NodeDef*, int>& element) const {
     91       return std::hash<const NodeDef*>()(element.first);
     92     }
     93   };
     94 
     95   // Nodes currently allocated in memory: set of NodeDef* and port_num pairs
     96   // so that we can track which output of the node is in memory.
     97   std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
     98       nodes_in_memory;
     99 
    100   // Nodes allocated in memory persistently: e.g., Variables.
    101   std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
    102       persistent_nodes;
    103 
    104   // Snapshot of nodes_in_memory, when memory usage is at peak.
    105   // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs.
    106   std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
    107       mem_usage_snapshot_at_peak;
    108 
    109   Costs device_costs;
    110   std::map<string, Costs> op_to_cost;    // Per-op cost.
    111   std::map<string, int64> op_to_memory;  // Per-op memory usage at peak usage.
    112   int64 memory_usage;
    113   int64 max_memory_usage;
    114 
    115   DeviceState() {
    116     device_costs = Costs::ZeroCosts();
    117     memory_usage = 0;
    118     max_memory_usage = 0;
    119   }
    120 
    121   Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
    122 };
    123 
    124 // ReadyNodeManager (abstract class):
    125 // Keeps ready nodes and picks the best one to be scheduled.
    126 class ReadyNodeManager {
    127  public:
    128   ReadyNodeManager() {}
    129   virtual ~ReadyNodeManager() {}
    130   virtual void Init(
    131       const std::unordered_map<const NodeDef*, NodeState>* node_state) {}
    132   virtual void AddNode(const NodeDef* node) = 0;
    133   virtual const NodeDef* GetCurrNode() = 0;
    134   virtual void RemoveCurrNode() = 0;
    135   virtual bool Empty() const = 0;
    136 };
    137 
    138 class FIFOManager : public ReadyNodeManager {
    139  public:
    140   FIFOManager() : ReadyNodeManager() {}
    141   ~FIFOManager() override {}
    142   void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
    143       override {}
    144   void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
    145   const NodeDef* GetCurrNode() override {
    146     CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
    147     return nodes_.front();
    148   }
    149   void RemoveCurrNode() override { nodes_.pop_front(); }
    150   bool Empty() const override { return nodes_.empty(); }
    151 
    152  private:
    153   std::list<const NodeDef*> nodes_;
    154 };
    155 
    156 // The LIFOManager schedules nodes by returning the last one added to the
    157 // scheduler. A node is executed and then its ready outputs are newly added to
    158 // the scheduler, so the LIFOManager will return outputs to a node following
    159 // that node's execution.
    160 class LIFOManager : public ReadyNodeManager {
    161  public:
    162   LIFOManager() : ReadyNodeManager() {}
    163   ~LIFOManager() override {}
    164   void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
    165       override {}
    166   void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
    167   const NodeDef* GetCurrNode() override;
    168   void RemoveCurrNode() override;
    169   bool Empty() const override { return nodes_.empty(); }
    170 
    171  private:
    172   std::list<const NodeDef*> nodes_;
    173   // Keep track of the current node being executed by saving its position.
    174   // Necessary because nodes may be added to the end of the list while a node is
    175   // executing, and we want to remove the correct node (the one that is
    176   // executing) rather than the new ones being added.
    177   std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end();
    178 };
    179 
// FirstReadyManager picks a node with the minimum time_ready value.
// Behavior is unspecified if more than one node has the minimum
// time_ready value (it depends on C++ STL push_heap and pop_heap).
    183 class FirstReadyManager : public ReadyNodeManager {
    184  public:
    185   FirstReadyManager();
    186   void Init(
    187       const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
    188   ~FirstReadyManager() override {}
    189   void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); }
    190   const NodeDef* GetCurrNode() override;
    191   void RemoveCurrNode() override;
    192   bool Empty() const override;
    193 
    194  private:
    195   // Move all the nodes in the waiting_queue_ to nodes_.
    196   void DrainWaitingQueue();
    197 
    198   // nodes_ is the main queue, where we construct heap, and the front is the
    199   // current node.
    200   std::vector<const NodeDef*> nodes_;
    201   // Newly added nodes are added to waiting_queue_. That way, GetCurrNode(),
    202   // wihch returns the front of the nodes_, always returns the same node,
    203   // even if any of new nodes has time_ready smaller than the current node's.
    204   std::vector<const NodeDef*> waiting_queue_;
    205   // Comparator functor for heap; stl heap is max heap, so we use "greater than"
    206   // functor for keeping the smallest time_ready node at the front of heap.
    207   std::function<bool(const NodeDef*, const NodeDef*)> greater_;
    208 
    209   // NodeState structure from VirtualScheduler to get time_ready of ready nodes.
    210   // Not owned by FirstReadyManager.
    211   const std::unordered_map<const NodeDef*, NodeState>* node_state_;
    212 };
    213 
// CompositeNodeManager has a few other NodeManagers: per-device LIFO for normal
// ops (neither _Send nor _Recv) and FirstReadyManagers for _Send ops and _Recv
// ops, and then it chooses FirstReady among the ops chosen from each
// internal NodeManager. The objective is to maximize producer-consumer
// locality within device, while processing nodes across devices, including
// _Send and _Recv, fairly, in terms of their time_ready.
    220 class CompositeNodeManager : public ReadyNodeManager {
    221  public:
    222   CompositeNodeManager();
    223   ~CompositeNodeManager() override {}
    224 
    225   void Init(
    226       const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
    227   void AddNode(const NodeDef* node) override;
    228   const NodeDef* GetCurrNode() override;
    229   void RemoveCurrNode() override;
    230   bool Empty() const override;
    231 
    232  private:
    233   // Internal ready node managers:
    234   // LIFO for normal ops to maximize producer consumer locality.
    235   // One LIFO per device.
    236   std::unordered_map<string, LIFOManager> ops_lifo_map_;
    237   // FirstReady for send and recv. Handle send and recv separately ensures that
    238   // send and recv do not block previously read ops with LIFO schedule.
    239   FirstReadyManager send_manager_;
    240   FirstReadyManager recv_manager_;
    241 
    242   // NodeState structure from VirtualScheduler to get time_ready of ready nodes.
    243   // Not owned by FirstReadyManager.
    244   const std::unordered_map<const NodeDef*, NodeState>* node_state_;
    245 
    246   // Cached curr node. Set back to nullptr from RemoveCurrNode().
    247   const NodeDef* curr_node_;
    248 };
    249 
    250 // The virtual scheduler emulates execution of nodes in a graph, considering
    251 // dependencies, device, etc.
    252 class VirtualScheduler {
    253  public:
    254   VirtualScheduler(const GrapplerItem* grappler_item,
    255                    const bool use_static_shapes, Cluster* cluster,
    256                    ReadyNodeManager* ready_nodes);
    257   // Initializes NodeState and DeviceState from grappler_item_ and
    258   // graph_properties_.
    259   Status Init();
    260 
    261   OpContext GetCurrNode() const;
    262 
    263   // Returns true if there is any node to be scheduled.
    264   bool MarkCurrNodeExecuted(const Costs& node_costs);
    265 
    266   // Prints out summary of execution (timing, memory usage, etc.)
    267   Costs Summary() const;
    268   // Like the above, but writes detailed stats to RunMetadata.
    269   // If metadata is nullptr, then just calls and return Summary().
    270   Costs Summary(RunMetadata* metadata);
    271   // Methods called from constructor.
    272   static ReadyNodeManager* ReadyNodeManagerFactory(
    273       const string& ready_node_manager);
    274 
    275   // Return per device peak memory usage.
    276   const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
    277 
    278  protected:
    279   const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
    280     return &device_;
    281   }
    282   const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
    283     return &node_map_;
    284   }
    285 
    286   // Returns the size of output at port_num (unit: bytes). A special case is
    287   // port_num -1, which is for control dependency and assumed to be 4 bytes.
    288   int64 CalculateOutputSize(
    289       const std::vector<OpInfo::TensorProperties>& output_properties,
    290       const int port_num) const;
    291 
    292  private:
    293   // Constants.
    294   const string kAttrInputSrc = "input_source_";
    295   const string kAttrSrcDevice = "src_device_";
    296   const string kAttrDstDevice = "dst_device_";
    297   const string kChannelDevice = "Channel";
    298 
    299   // Methods called from Init(). Fails if initialize_ is set.
    300   void MaybeUpdateInputOutput(const NodeDef* node);
    301   NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
    302   std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
    303       const NodeDef* from, const NodeDef* to, const string& input_name);
    304   string DeviceName(const NodeDef* node) const;
    305   string SanitizedDeviceName(const NodeDef* node) const;
    306   string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
    307 
    308   // Helper methods.
    309   Costs& FindOrCreateZero(const string& op_name,
    310                           std::map<string, Costs>* op_cost);
    311   float Round2(const float x) const;
    312   bool IsPersistentNode(const NodeDef* node) const;
    313 
    314   // Scheduler states:
    315   ReadyNodeManager* ready_nodes_;  // Not owned.
    316   std::unordered_map<const NodeDef*, NodeState> node_map_;
    317   std::unordered_map<string, DeviceState> device_;
    318 
    319   // Pool of NodeDefs for SendRecv and Identity ops created.
    320   std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
    321 
    322   // Stats:
    323   std::map<string, int> op_counts_;  // Op counts with key with input shape.
    324   // Individual op costs (with input shapes).
    325   // Boolean field for whether the cost is accurate.
    326   std::map<string, std::pair<int, bool>> op_costs_;
    327 
    328   Costs graph_costs_;                   // Graph cost.
    329   std::map<string, Costs> op_to_cost_;  // Per-op cost.
    330 
    331   // Auxilliary data structures for constructing NodeState and DeviceState.
    332   GraphProperties graph_properties_;
    333   Cluster* cluster_;  // Not owned.
    334 
    335   const GrapplerItem* grappler_item_;  // Not owned.
    336   bool use_static_shapes_;
    337   bool initialized_;
    338 
    339   VirtualPlacer placer_;  // owned.
    340 };
    341 
    342 }  // namespace grappler
    343 }  // end namespace tensorflow
    344 
    345 #endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
    346