1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ 18 19 #include <atomic> 20 #include <vector> 21 22 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 23 #include "tensorflow/core/common_runtime/device_set.h" 24 #include "tensorflow/core/common_runtime/graph_execution_state.h" 25 #include "tensorflow/core/common_runtime/stats_publisher_interface.h" 26 #include "tensorflow/core/distributed_runtime/call_options.h" 27 #include "tensorflow/core/distributed_runtime/master_env.h" 28 #include "tensorflow/core/distributed_runtime/message_wrappers.h" 29 #include "tensorflow/core/distributed_runtime/worker_cache.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/platform/types.h" 32 #include "tensorflow/core/protobuf/master.pb.h" 33 #include "tensorflow/core/public/session_options.h" 34 35 namespace tensorflow { 36 37 class Device; 38 struct MasterEnv; 39 40 // A session encapsulates a graph computation (resource allocation, 41 // placement, execution, etc.). 42 class MasterSession : public core::RefCounted { 43 public: 44 // This session encapsulates the graph computation for a graph. 45 // 46 // The session places nodes on devices in "remote_devs" and executes 47 // operations on these devices. 48 // 49 // The caller takes ownership of all remote devices. 50 MasterSession( 51 const SessionOptions& options, const MasterEnv* env, 52 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, 53 std::unique_ptr<WorkerCacheInterface> worker_cache, 54 std::unique_ptr<DeviceSet> device_set, 55 StatsPublisherFactory stats_publisher_factory); 56 57 // Initialize the MasterSession for "def". Must be called before Extend(), 58 // Run(), or Close(). 59 // 60 // After this method returns, `def` will no longer be valid. 61 Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options); 62 63 // Returns the session handle. 64 const string& handle() const { return handle_; } 65 66 // Returns the last access time (the number of micro-seconds since 67 // some fixed point in time) of this session. 68 uint64 last_access_time_usec() const { return last_access_time_usec_.load(); } 69 70 // Attempt to extend the graph according to the given "req". 71 // (See master.proto for details of valid extensions.) 72 // 73 // PRECONDITION: The current version of this session's graph 74 // is "req->current_graph_version". 75 // 76 // POSTCONDITION: The current version of this session's graph 77 // is "resp->new_graph_version". 78 // 79 // Extend() may block the caller thread for a long time. 80 Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp); 81 82 // Setup a partial run call. 83 Status PartialRunSetup(const PartialRunSetupRequest* req, 84 PartialRunSetupResponse* resp); 85 86 // Run one step. 87 Status Run(CallOptions* opts, const RunStepRequestWrapper& req, 88 MutableRunStepResponseWrapper* resp); 89 90 Status ListDevices(ListDevicesResponse* resp) const; 91 92 // Close this session and delete "*this". Returns OK if all known 93 // states are cleanup successfully. 94 // 95 // Close() may block the caller thread for a long time. 96 Status Close(); 97 98 // Close this session and release a reference on "*this". 99 // 100 // Note that, unlike Close(), this method does not block on the 101 // completion of all work. 102 void GarbageCollect(); 103 104 private: 105 SessionOptions session_opts_; 106 107 // Not owned. 108 const MasterEnv* env_; 109 110 // The opaque session handle. 111 const string handle_; 112 113 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_; 114 115 // The optional session-specific worker cluster. 116 // TODO(saeta): Convert to std::optional when available. 117 const std::unique_ptr<WorkerCacheInterface> worker_cache_; 118 // Retrieves either worker_cache_ or the env_->worker_cache as appropriate. 119 WorkerCacheInterface* get_worker_cache() const; 120 121 // The device set used by this session. 122 std::unique_ptr<DeviceSet> devices_; 123 124 StatsPublisherFactory stats_publisher_factory_; 125 126 std::atomic_ulong last_access_time_usec_; 127 128 std::atomic<int64> partial_run_handle_counter_ = {0}; 129 130 mutex mu_; 131 std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_); 132 int64 graph_version_; 133 134 // We keep a map from a signature of a run request to the 135 // ReffedClientGraph the can execute it. We keep up to one old copy 136 // of each ReffedClientGraph around because if it gets deallocated 137 // before a new substitute has been created, Variables can go out of 138 // scope and lose their state. 139 class ReffedClientGraph; 140 typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap; 141 RCGMap run_graphs_ GUARDED_BY(mu_); 142 RCGMap partial_run_graphs_ GUARDED_BY(mu_); 143 144 struct PerStepState { 145 bool collect_costs = false; 146 bool collect_timeline = false; 147 bool collect_rpcs = false; 148 bool collect_partition_graphs = false; 149 bool report_tensor_allocations_upon_oom = false; 150 Microseconds start_micros = Microseconds(0); 151 Microseconds end_micros = Microseconds(0); 152 std::vector<StepStats> step_stats; // per partition 153 StepStats rpc_stats; // for RPC layer 154 CostGraphDef cost_graph; 155 }; 156 157 struct RunState { 158 std::unordered_map<string, bool> pending_inputs; // true if fed 159 std::unordered_map<string, bool> pending_outputs; // true if fetched 160 ReffedClientGraph* rcg = nullptr; 161 uint64 step_id; 162 int64 count = 0; 163 PerStepState pss; 164 std::unique_ptr<ProfileHandler> ph; 165 bool step_started = false; 166 167 RunState(const std::vector<string>& input_names, 168 const std::vector<string>& output_names, ReffedClientGraph* rcg, 169 const uint64 step_id, const int64 count); 170 171 bool PendingDone() const; 172 173 ~RunState(); 174 }; 175 std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_ 176 GUARDED_BY(mu_); 177 178 // Active RunStep calls. 179 condition_variable num_running_is_zero_; 180 int32 num_running_ GUARDED_BY(mu_) = 0; 181 182 bool closed_ GUARDED_BY(mu_) = false; 183 bool garbage_collected_ GUARDED_BY(mu_) = false; 184 185 std::unordered_map<uint64, int64> subgraph_execution_counts_ GUARDED_BY(mu_); 186 187 // We need to ensure that certain nodes added (e.g., send and recv 188 // nodes) are unique across all sub-graphs within this session. 189 int64 next_node_id_ GUARDED_BY(mu_) = 0; 190 191 // Used to cancel running steps on Close(). 192 CancellationManager cancellation_manager_; 193 194 // Private dtor. The client must call Close(). 195 virtual ~MasterSession(); 196 197 // Creates sessions on all workers. 198 // 199 // If this session is operating using the new ClusterSpec propagation behavior 200 // call this method in order to propagate the cluster membership to all 201 // workers. 202 Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def); 203 204 // TODO(b/36574172): Always use Create/DeleteWorkerSession. 205 bool should_delete_worker_sessions_ = false; 206 Status DeleteWorkerSessions(); 207 208 Status StartStep(const BuildGraphOptions& opts, int64* count, 209 ReffedClientGraph** graph, bool is_partial); 210 void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref, 211 RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_); 212 Status DoRunWithLocalExecution(CallOptions* opts, 213 const RunStepRequestWrapper& req, 214 MutableRunStepResponseWrapper* resp); 215 Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, 216 MutableRunStepResponseWrapper* resp); 217 void MarkRunCompletion(); 218 void UpdateLastAccessTime(); 219 220 Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); 221 222 Status CreateDebuggerState( 223 const DebugOptions& debug_options, const RunStepRequestWrapper& req, 224 int64 rcg_execution_count, 225 std::unique_ptr<DebuggerStateInterface>* debugger_state); 226 227 TF_DISALLOW_COPY_AND_ASSIGN(MasterSession); 228 }; 229 230 } // end namespace tensorflow 231 232 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ 233