1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 18 19 #include <list> 20 #include <memory> 21 #include <unordered_map> 22 #include <unordered_set> 23 24 #include "tensorflow/core/framework/node_def.pb.h" 25 #include "tensorflow/core/framework/step_stats.pb.h" 26 #include "tensorflow/core/grappler/costs/cost_estimator.h" 27 #include "tensorflow/core/grappler/costs/graph_properties.h" 28 #include "tensorflow/core/grappler/costs/op_context.h" 29 #include "tensorflow/core/grappler/costs/virtual_placer.h" 30 #include "tensorflow/core/grappler/grappler_item.h" 31 32 namespace tensorflow { 33 namespace grappler { 34 35 struct NodeState { 36 // A node (i.e., an op) takes a set of input:port pairs and produces 37 // a set of output ports. 38 39 // Cross references to input and output nodes from graphdef. 40 std::vector<std::pair<const NodeDef*, int>> inputs; // Input, port pairs. 41 // List of output nodes (a list of nodes that takes this output port as input) 42 // keyed by port_num. Note that port_num -1 is used for control dependency. 43 std::unordered_map<int, std::vector<const NodeDef*>> outputs; 44 45 // Info from GraphProperties. 46 std::vector<OpInfo::TensorProperties> input_properties; 47 std::vector<OpInfo::TensorProperties> output_properties; 48 49 // Canonical device name used within VirtualScheduler. 50 string device_name; 51 52 // States updated as scheduling nodes. 53 int num_inputs_ready; 54 std::unordered_map<int, int> num_outputs_executed; 55 Costs::Duration time_ready; 56 Costs::Duration time_scheduled; 57 Costs::Duration time_finished; 58 // Time that all the consumers are executed (hence, no need to keep this 59 // output in memory), keyed by port_num. 60 std::unordered_map<int, Costs::Duration> time_no_references; 61 62 // Note that a node may have multiple output ports. The length of outputs, 63 // num_outputs_executed, and time_no_references should be 64 // identical when a NodeState is fully initialized. 65 // They should be 1 + output_properties.size() as we add [-1] for control 66 // dependency. 67 68 // Node will be ready to be executed at time_ready, scheduled at 69 // time_scheduled, and finishes execution at time_finished. 70 // Each output port uses up memory space from time_scheduled to its 71 // time_no_references. 72 73 NodeState() { 74 num_inputs_ready = 0; 75 time_ready = Costs::Duration::max(); 76 time_scheduled = Costs::Duration::max(); 77 time_finished = Costs::Duration::max(); 78 // Note that num_outputs_executed and time_no_references are not initialized 79 // here, since we don't know the size (i.e., # outputs for this node). 80 } 81 }; 82 83 struct DeviceState { 84 // Nodes executed on this device in execution order. 85 std::vector<const NodeDef*> nodes_executed; 86 87 struct NodePairHash { 88 public: 89 const std::size_t operator()( 90 const std::pair<const NodeDef*, int>& element) const { 91 return std::hash<const NodeDef*>()(element.first); 92 } 93 }; 94 95 // Nodes currently allocated in memory: set of NodeDef* and port_num pairs 96 // so that we can track which output of the node is in memory. 97 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 98 nodes_in_memory; 99 100 // Nodes allocated in memory persistently: e.g., Variables. 101 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 102 persistent_nodes; 103 104 // Snapshot of nodes_in_memory, when memory usage is at peak. 105 // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs. 106 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 107 mem_usage_snapshot_at_peak; 108 109 Costs device_costs; 110 std::map<string, Costs> op_to_cost; // Per-op cost. 111 std::map<string, int64> op_to_memory; // Per-op memory usage at peak usage. 112 int64 memory_usage; 113 int64 max_memory_usage; 114 115 DeviceState() { 116 device_costs = Costs::ZeroCosts(); 117 memory_usage = 0; 118 max_memory_usage = 0; 119 } 120 121 Costs::Duration GetCurrTime() const { return device_costs.execution_time; } 122 }; 123 124 // ReadyNodeManager (abstract class): 125 // Keeps ready nodes and picks the best one to be scheduled. 126 class ReadyNodeManager { 127 public: 128 ReadyNodeManager() {} 129 virtual ~ReadyNodeManager() {} 130 virtual void Init( 131 const std::unordered_map<const NodeDef*, NodeState>* node_state) {} 132 virtual void AddNode(const NodeDef* node) = 0; 133 virtual const NodeDef* GetCurrNode() = 0; 134 virtual void RemoveCurrNode() = 0; 135 virtual bool Empty() const = 0; 136 }; 137 138 class FIFOManager : public ReadyNodeManager { 139 public: 140 FIFOManager() : ReadyNodeManager() {} 141 ~FIFOManager() override {} 142 void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state) 143 override {} 144 void AddNode(const NodeDef* node) override { nodes_.push_back(node); } 145 const NodeDef* GetCurrNode() override { 146 CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; 147 return nodes_.front(); 148 } 149 void RemoveCurrNode() override { nodes_.pop_front(); } 150 bool Empty() const override { return nodes_.empty(); } 151 152 private: 153 std::list<const NodeDef*> nodes_; 154 }; 155 156 // The LIFOManager schedules nodes by returning the last one added to the 157 // scheduler. A node is executed and then its ready outputs are newly added to 158 // the scheduler, so the LIFOManager will return outputs to a node following 159 // that node's execution. 160 class LIFOManager : public ReadyNodeManager { 161 public: 162 LIFOManager() : ReadyNodeManager() {} 163 ~LIFOManager() override {} 164 void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state) 165 override {} 166 void AddNode(const NodeDef* node) override { nodes_.push_back(node); } 167 const NodeDef* GetCurrNode() override; 168 void RemoveCurrNode() override; 169 bool Empty() const override { return nodes_.empty(); } 170 171 private: 172 std::list<const NodeDef*> nodes_; 173 // Keep track of the current node being executed by saving its position. 174 // Necessary because nodes may be added to the end of the list while a node is 175 // executing, and we want to remove the correct node (the one that is 176 // executing) rather than the new ones being added. 177 std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end(); 178 }; 179 180 // FirstReadyManager picks a node with the minimum time_ready value. 181 // Behavior is unknown if there are more than one nodes with the minimum 182 // time_ready value (it depends on C++ STL push_heap and pop_heap). 183 class FirstReadyManager : public ReadyNodeManager { 184 public: 185 FirstReadyManager(); 186 void Init( 187 const std::unordered_map<const NodeDef*, NodeState>* node_state) override; 188 ~FirstReadyManager() override {} 189 void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); } 190 const NodeDef* GetCurrNode() override; 191 void RemoveCurrNode() override; 192 bool Empty() const override; 193 194 private: 195 // Move all the nodes in the waiting_queue_ to nodes_. 196 void DrainWaitingQueue(); 197 198 // nodes_ is the main queue, where we construct heap, and the front is the 199 // current node. 200 std::vector<const NodeDef*> nodes_; 201 // Newly added nodes are added to waiting_queue_. That way, GetCurrNode(), 202 // wihch returns the front of the nodes_, always returns the same node, 203 // even if any of new nodes has time_ready smaller than the current node's. 204 std::vector<const NodeDef*> waiting_queue_; 205 // Comparator functor for heap; stl heap is max heap, so we use "greater than" 206 // functor for keeping the smallest time_ready node at the front of heap. 207 std::function<bool(const NodeDef*, const NodeDef*)> greater_; 208 209 // NodeState structure from VirtualScheduler to get time_ready of ready nodes. 210 // Not owned by FirstReadyManager. 211 const std::unordered_map<const NodeDef*, NodeState>* node_state_; 212 }; 213 214 // CompositeNodeManager has a few other NodeManagers: per-device LIFO for normal 215 // ops (neither _Send nor _Recv) and FirstyReadyManagers for _Send ops and _Recv 216 // ops, and then it chooses FirstReady among the ops chosen from each 217 // internal NodeManagers. The objective is to maximize producer-consumer 218 // locality within device, while processing nodes across devices, including 219 // _Send and _Recv, fairly, in terms of their time_ready. 220 class CompositeNodeManager : public ReadyNodeManager { 221 public: 222 CompositeNodeManager(); 223 ~CompositeNodeManager() override {} 224 225 void Init( 226 const std::unordered_map<const NodeDef*, NodeState>* node_state) override; 227 void AddNode(const NodeDef* node) override; 228 const NodeDef* GetCurrNode() override; 229 void RemoveCurrNode() override; 230 bool Empty() const override; 231 232 private: 233 // Internal ready node managers: 234 // LIFO for normal ops to maximize producer consumer locality. 235 // One LIFO per device. 236 std::unordered_map<string, LIFOManager> ops_lifo_map_; 237 // FirstReady for send and recv. Handle send and recv separately ensures that 238 // send and recv do not block previously read ops with LIFO schedule. 239 FirstReadyManager send_manager_; 240 FirstReadyManager recv_manager_; 241 242 // NodeState structure from VirtualScheduler to get time_ready of ready nodes. 243 // Not owned by FirstReadyManager. 244 const std::unordered_map<const NodeDef*, NodeState>* node_state_; 245 246 // Cached curr node. Set back to nullptr from RemoveCurrNode(). 247 const NodeDef* curr_node_; 248 }; 249 250 // The virtual scheduler emulates execution of nodes in a graph, considering 251 // dependencies, device, etc. 252 class VirtualScheduler { 253 public: 254 VirtualScheduler(const GrapplerItem* grappler_item, 255 const bool use_static_shapes, Cluster* cluster, 256 ReadyNodeManager* ready_nodes); 257 // Initializes NodeState and DeviceState from grappler_item_ and 258 // graph_properties_. 259 Status Init(); 260 261 OpContext GetCurrNode() const; 262 263 // Returns true if there is any node to be scheduled. 264 bool MarkCurrNodeExecuted(const Costs& node_costs); 265 266 // Prints out summary of execution (timing, memory usage, etc.) 267 Costs Summary() const; 268 // Like the above, but writes detailed stats to RunMetadata. 269 // If metadata is nullptr, then just calls and return Summary(). 270 Costs Summary(RunMetadata* metadata); 271 // Methods called from constructor. 272 static ReadyNodeManager* ReadyNodeManagerFactory( 273 const string& ready_node_manager); 274 275 // Return per device peak memory usage. 276 const std::unordered_map<string, int64> GetPeakMemoryUsage() const; 277 278 protected: 279 const std::unordered_map<string, DeviceState>* GetDeviceStates() const { 280 return &device_; 281 } 282 const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const { 283 return &node_map_; 284 } 285 286 // Returns the size of output at port_num (unit: bytes). A special case is 287 // port_num -1, which is for control dependency and assumed to be 4 bytes. 288 int64 CalculateOutputSize( 289 const std::vector<OpInfo::TensorProperties>& output_properties, 290 const int port_num) const; 291 292 private: 293 // Constants. 294 const string kAttrInputSrc = "input_source_"; 295 const string kAttrSrcDevice = "src_device_"; 296 const string kAttrDstDevice = "dst_device_"; 297 const string kChannelDevice = "Channel"; 298 299 // Methods called from Init(). Fails if initialize_ is set. 300 void MaybeUpdateInputOutput(const NodeDef* node); 301 NodeState& GetNodeStateOrCreateIt(const NodeDef* node); 302 std::pair<const NodeDef*, const NodeDef*> CreateSendRecv( 303 const NodeDef* from, const NodeDef* to, const string& input_name); 304 string DeviceName(const NodeDef* node) const; 305 string SanitizedDeviceName(const NodeDef* node) const; 306 string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const; 307 308 // Helper methods. 309 Costs& FindOrCreateZero(const string& op_name, 310 std::map<string, Costs>* op_cost); 311 float Round2(const float x) const; 312 bool IsPersistentNode(const NodeDef* node) const; 313 314 // Scheduler states: 315 ReadyNodeManager* ready_nodes_; // Not owned. 316 std::unordered_map<const NodeDef*, NodeState> node_map_; 317 std::unordered_map<string, DeviceState> device_; 318 319 // Pool of NodeDefs for SendRecv and Identity ops created. 320 std::vector<std::unique_ptr<NodeDef>> additional_nodes_; 321 322 // Stats: 323 std::map<string, int> op_counts_; // Op counts with key with input shape. 324 // Individual op costs (with input shapes). 325 // Boolean field for whether the cost is accurate. 326 std::map<string, std::pair<int, bool>> op_costs_; 327 328 Costs graph_costs_; // Graph cost. 329 std::map<string, Costs> op_to_cost_; // Per-op cost. 330 331 // Auxilliary data structures for constructing NodeState and DeviceState. 332 GraphProperties graph_properties_; 333 Cluster* cluster_; // Not owned. 334 335 const GrapplerItem* grappler_item_; // Not owned. 336 bool use_static_shapes_; 337 bool initialized_; 338 339 VirtualPlacer placer_; // owned. 340 }; 341 342 } // namespace grappler 343 } // end namespace tensorflow 344 345 #endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 346