/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_

#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/grappler_item.h"

namespace tensorflow {
namespace grappler {

struct NodeState {
  // A node (i.e., an op) takes a set of input:port pairs and produces
  // a set of output ports.

  // Cross-references to input and output nodes from the GraphDef.
  std::vector<std::pair<const NodeDef*, int>> inputs;  // Input, port pairs.
  // List of output nodes (the nodes that take this output port as input),
  // keyed by port_num. Note that port_num -1 is used for control dependencies.
  std::unordered_map<int, std::vector<const NodeDef*>> outputs;

  // Info from GraphProperties.
  std::vector<OpInfo::TensorProperties> input_properties;
  std::vector<OpInfo::TensorProperties> output_properties;

  // Canonical device name used within VirtualScheduler.
  string device_name;

  // State updated as nodes are scheduled.
  int num_inputs_ready;
  std::unordered_map<int, int> num_outputs_executed;
  Costs::Duration time_ready;
  Costs::Duration time_scheduled;
  Costs::Duration time_finished;
  // Time at which all the consumers have executed (hence there is no need to
  // keep this output in memory), keyed by port_num.
  std::unordered_map<int, Costs::Duration> time_no_references;

  // Note that a node may have multiple output ports. The sizes of outputs,
  // num_outputs_executed, and time_no_references should be identical when a
  // NodeState is fully initialized: 1 + output_properties.size(), since
  // port -1 is added for control dependencies.

  // A node becomes ready to execute at time_ready, is scheduled at
  // time_scheduled, and finishes execution at time_finished.
  // Each output port occupies memory from time_scheduled until its
  // time_no_references.

  // How many times this node has been executed, e.g., in a while loop.
  int execution_count;

  NodeState() {
    num_inputs_ready = 0;
    time_ready = Costs::Duration::max();
    time_scheduled = Costs::Duration::max();
    time_finished = Costs::Duration::max();
    execution_count = 0;
    // Note that num_outputs_executed and time_no_references are not
    // initialized here, since we don't know their size yet (i.e., the number
    // of outputs of this node).
  }
};
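
// Illustration (a sketch, not part of the API): for a node with two data
// outputs, a fully initialized NodeState keys its per-port maps by
// {-1, 0, 1}, where -1 is the control-dependency port:
//
//   NodeState state = ...;  // populated by VirtualScheduler
//   // state.output_properties.size() == 2
//   // state.outputs.size() == 3  (ports -1, 0, 1)
//   // state.num_outputs_executed and state.time_no_references use the
//   // same keys.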

struct DeviceState {
  // Nodes executed on this device in execution order.
  std::vector<const NodeDef*> nodes_executed;

  struct NodePairHash {
   public:
    std::size_t operator()(
        const std::pair<const NodeDef*, int>& element) const {
      return std::hash<const NodeDef*>()(element.first);
    }
  };

  // Nodes currently allocated in memory: a set of NodeDef* and port_num pairs,
  // so that we can track which output of the node is in memory.
  std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
      nodes_in_memory;

  // Nodes allocated in memory persistently: e.g., Variables.
  std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
      persistent_nodes;

  // Snapshot of nodes_in_memory when memory usage is at its peak.
  // Like nodes_in_memory, it is a set of NodeDef* and port_num pairs.
  std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
      mem_usage_snapshot_at_peak;

  Costs device_costs;
  std::map<string, Costs> op_to_cost;  // Per-op cost.

  int64 memory_usage;      // Current temporary memory usage.
  int64 max_memory_usage;  // Max temporary memory usage.

  DeviceState() {
    device_costs = Costs::ZeroCosts();
    device_costs.num_ops_total = 0;
    memory_usage = 0;
    max_memory_usage = 0;
  }

  Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
};

// ReadyNodeManager (abstract class):
// Keeps ready nodes and picks the best one to be scheduled.
class ReadyNodeManager {
 public:
  ReadyNodeManager() {}
  virtual ~ReadyNodeManager() {}
  virtual void Init(
      const std::unordered_map<const NodeDef*, NodeState>* node_state) {}
  virtual void AddNode(const NodeDef* node) = 0;
  virtual const NodeDef* GetCurrNode() = 0;
  virtual void RemoveCurrNode() = 0;
  virtual bool Empty() const = 0;
};
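
// A minimal usage sketch of the interface (not part of the API): the scheduler
// adds nodes as they become ready and drains the manager in whatever order the
// concrete implementation defines. `manager`, `node_state_map`, and the nodes
// are assumed to come from the caller:
//
//   ReadyNodeManager* manager = ...;     // e.g., one of the classes below
//   manager->Init(&node_state_map);      // map<const NodeDef*, NodeState>
//   manager->AddNode(node_a);            // nodes whose inputs are ready
//   manager->AddNode(node_b);
//   while (!manager->Empty()) {
//     const NodeDef* node = manager->GetCurrNode();
//     // ... simulate execution of `node`, possibly AddNode() its now-ready
//     // consumers ...
//     manager->RemoveCurrNode();
//   }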

class FIFOManager : public ReadyNodeManager {
 public:
  FIFOManager() : ReadyNodeManager() {}
  ~FIFOManager() override {}
  void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
      override {}
  void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
  const NodeDef* GetCurrNode() override {
    CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
    return nodes_.front();
  }
  void RemoveCurrNode() override { nodes_.pop_front(); }
  bool Empty() const override { return nodes_.empty(); }

 private:
  std::list<const NodeDef*> nodes_;
};

// The LIFOManager schedules nodes by returning the last one added to the
// scheduler. A node is executed and then its ready outputs are added to the
// scheduler, so the LIFOManager returns a node's ready outputs right after
// that node's execution.
class LIFOManager : public ReadyNodeManager {
 public:
  LIFOManager() : ReadyNodeManager() {}
  ~LIFOManager() override {}
  void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
      override {}
  void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
  const NodeDef* GetCurrNode() override;
  void RemoveCurrNode() override;
  bool Empty() const override { return nodes_.empty(); }

 private:
  std::list<const NodeDef*> nodes_;
  // Keep track of the current node being executed by saving its position.
  // Necessary because nodes may be added to the end of the list while a node
  // is executing, and we want to remove the correct node (the one that is
  // executing) rather than the newly added ones.
  std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end();
};

// FirstReadyManager picks the node with the minimum time_ready value.
// Behavior is unspecified if more than one node has the minimum time_ready
// value (it depends on the C++ STL push_heap and pop_heap implementations).
class FirstReadyManager : public ReadyNodeManager {
 public:
  FirstReadyManager();
  void Init(
      const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
  ~FirstReadyManager() override {}
  void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); }
  const NodeDef* GetCurrNode() override;
  void RemoveCurrNode() override;
  bool Empty() const override;

 private:
  // Moves all the nodes in waiting_queue_ to nodes_.
  void DrainWaitingQueue();

  // nodes_ is the main queue, on which we maintain a heap; its front is the
  // current node.
  std::vector<const NodeDef*> nodes_;
  // Newly added nodes go to waiting_queue_. That way, GetCurrNode(), which
  // returns the front of nodes_, always returns the same node, even if a
  // newly added node has a smaller time_ready than the current node.
  std::vector<const NodeDef*> waiting_queue_;
  // Comparator functor for the heap; the STL heap is a max-heap, so we use a
  // "greater than" functor to keep the node with the smallest time_ready at
  // the front of the heap.
  std::function<bool(const NodeDef*, const NodeDef*)> greater_;

  // NodeState map from VirtualScheduler, used to look up time_ready of ready
  // nodes. Not owned by FirstReadyManager.
  const std::unordered_map<const NodeDef*, NodeState>* node_state_;
};

// CompositeNodeManager combines several ReadyNodeManagers: a per-device LIFO
// for normal ops (neither _Send nor _Recv) and FirstReadyManagers for _Send
// ops and for _Recv ops; it then picks, by FirstReady, among the ops chosen by
// each internal NodeManager. The objective is to maximize producer-consumer
// locality within a device, while processing nodes across devices, including
// _Send and _Recv, fairly in terms of their time_ready.
class CompositeNodeManager : public ReadyNodeManager {
 public:
  CompositeNodeManager();
  ~CompositeNodeManager() override {}

  void Init(
      const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
  void AddNode(const NodeDef* node) override;
  const NodeDef* GetCurrNode() override;
  void RemoveCurrNode() override;
  bool Empty() const override;

 private:
  // Internal ready node managers:
  // LIFO for normal ops, to maximize producer-consumer locality.
  // One LIFO per device.
  std::unordered_map<string, LIFOManager> ops_lifo_map_;
  // FirstReady for send and recv. Handling send and recv separately ensures
  // that they do not block previously ready ops under the LIFO schedule.
  FirstReadyManager send_manager_;
  FirstReadyManager recv_manager_;

  // NodeState map from VirtualScheduler, used to look up time_ready of ready
  // nodes. Not owned by CompositeNodeManager.
  const std::unordered_map<const NodeDef*, NodeState>* node_state_;

  // Cached current node. Reset to nullptr by RemoveCurrNode().
  const NodeDef* curr_node_;
};

// Constructs a ready node manager from the given string.
std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
    const string& ready_node_manager);
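
// A minimal usage sketch (not part of the API): the set of accepted manager
// names is defined in the corresponding .cc file; "FirstReady" is assumed
// here purely for illustration.
//
//   std::unique_ptr<ReadyNodeManager> manager =
//       ReadyNodeManagerFactory("FirstReady");
//   manager->Init(&node_state_map);  // node_state_map owned by the scheduler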

// The virtual scheduler emulates the execution of the nodes in a graph,
// taking dependencies, device placement, etc. into account.
class VirtualScheduler {
 public:
  // Does not take ownership of cluster or ready_nodes.
  VirtualScheduler(const bool use_static_shapes,
                   const bool use_aggressive_shape_inference, Cluster* cluster,
                   ReadyNodeManager* ready_nodes);
  // Initializes the scheduler for the specific grappler item.
  // Should be called immediately after the c'tor or when the scheduler will be
  // reused for a new grappler item. All internal states of the scheduler
  // related to the previous grappler item will be reset/cleared.
  //
  // This function should be called at least once after the scheduler is
  // constructed. An uninitialized or failed-to-initialize scheduler will cause
  // undefined behavior.
  Status Init(const GrapplerItem* item);

  OpContext GetCurrNode() const;

  // Returns true if there is any node left to be scheduled.
  bool MarkCurrNodeExecuted(const Costs& node_costs);

  // Prints out a summary of the execution (timing, memory usage, etc.).
  Costs Summary() const;
  // Like the above, but writes detailed stats to RunMetadata.
  // If metadata is nullptr, just calls and returns Summary().
  Costs Summary(RunMetadata* metadata);
  // Generates RunMetadata's step_stats and partition_graphs fields from the
  // results of the virtual execution of the graph.
  void GenerateRunMetadata(RunMetadata* metadata);

  // Returns per-device peak memory usage.
  const std::unordered_map<string, int64> GetPeakMemoryUsage() const;

  const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
    return &device_;
  }
  const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
    return &node_map_;
  }

  void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }

 private:
  // Constants.
  const string kAttrInputSrc = "input_source_";
  const string kAttrSrcDevice = "send_device";
  const string kAttrDstDevice = "recv_device";
  const string kAttrTensorName = "tensor_name";
  const string kChannelDevice = "Channel";

  // Methods called from Init(). They fail if initialized_ is already set.
  void MaybeUpdateInputOutput(const NodeDef* node);
  NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
  std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
      const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
      const string& input_name);
  string DeviceName(const NodeDef* node) const;
  string SanitizedDeviceName(const NodeDef* node) const;
  string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;

  // Helper methods.
  Costs& FindOrCreateZero(const string& op_name,
                          std::map<string, Costs>* op_cost);
  float Round2(const float x) const;
  bool IsPersistentNode(const NodeDef* node) const;
  void AddOutputNodesToReadyQueue(const NodeDef* node,
                                  const Costs::Duration& curr_time);

  // Scheduler states:
  ReadyNodeManager* ready_nodes_;  // Not owned.
  std::unordered_map<const NodeDef*, NodeState> node_map_;
  std::unordered_map<string, DeviceState> device_;

  // Pool of NodeDefs for the SendRecv and Identity ops created.
  std::vector<std::unique_ptr<NodeDef>> additional_nodes_;

  // Stats:
  // Op counts, keyed by op name together with input shapes.
  // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]"
  std::map<string, int> op_counts_;
  // Individual op costs, keyed by op name together with input shapes.
  // The integer field is the execution time in microseconds; the boolean
  // field indicates whether the cost is accurate.
  std::map<string, std::pair<int, bool>> op_costs_;

  Costs graph_costs_;                   // Graph cost.
  std::map<string, Costs> op_to_cost_;  // Per-op cost.

  // Auxiliary data structures for constructing NodeState and DeviceState.
  std::unique_ptr<GraphProperties> graph_properties_;  // Initialized in Init().
  Cluster* cluster_;                                   // Not owned.

  const GrapplerItem* grappler_item_;  // Not owned.
  bool use_static_shapes_;
  bool initialized_;
  bool track_mem_usage_snapshot_;
  const bool use_aggressive_shape_inference_;

  VirtualPlacer placer_;  // Owned.
};
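
// A minimal usage sketch (not part of the API). `cluster`, `item`, and the
// per-node cost estimate are assumed to be provided by the caller; the loop
// below shows how Init(), GetCurrNode(), and MarkCurrNodeExecuted() are meant
// to be chained:
//
//   FirstReadyManager ready_nodes;
//   VirtualScheduler scheduler(/*use_static_shapes=*/true,
//                              /*use_aggressive_shape_inference=*/true,
//                              cluster, &ready_nodes);
//   TF_CHECK_OK(scheduler.Init(&item));
//   bool more_nodes = true;
//   while (more_nodes) {
//     OpContext op_context = scheduler.GetCurrNode();
//     Costs node_costs = EstimateNodeCosts(op_context);  // hypothetical
//     more_nodes = scheduler.MarkCurrNodeExecuted(node_costs);
//   }
//   Costs total_costs = scheduler.Summary();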

}  // namespace grappler
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_