1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_ 17 #define TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_ 18 19 #include <atomic> 20 #include <memory> 21 #include <string> 22 #include <unordered_map> 23 #include <unordered_set> 24 #include <vector> 25 26 #include "tensorflow/core/common_runtime/costmodel_manager.h" 27 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 28 #include "tensorflow/core/common_runtime/device_mgr.h" 29 #include "tensorflow/core/common_runtime/device_set.h" 30 #include "tensorflow/core/common_runtime/executor.h" 31 #include "tensorflow/core/common_runtime/graph_execution_state.h" 32 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 33 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 34 #include "tensorflow/core/common_runtime/session_factory.h" 35 #include "tensorflow/core/framework/cancellation.h" 36 #include "tensorflow/core/framework/graph.pb.h" 37 #include "tensorflow/core/framework/session_state.h" 38 #include "tensorflow/core/framework/tensor.h" 39 #include "tensorflow/core/lib/core/errors.h" 40 #include "tensorflow/core/lib/core/status.h" 41 #include "tensorflow/core/platform/macros.h" 42 #include "tensorflow/core/platform/mutex.h" 43 #include "tensorflow/core/platform/thread_annotations.h" 44 #include "tensorflow/core/platform/types.h" 45 #include "tensorflow/core/public/session.h" 46 47 namespace tensorflow { 48 49 class CostModel; 50 class DebugGateway; 51 class Device; 52 class DirectSessionFactory; 53 54 class DirectSession : public Session { 55 public: 56 typedef std::function<void(Session*)> CloseCallback; 57 58 // Takes ownership of 'device_mgr'. 59 // 'factory' is used to unregister the DirectSession with 'factory' when its 60 // closed. This ensures that Reset requests from the 'factory' don't get sent 61 // to sessions that are already closed. 62 DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, 63 DirectSessionFactory* factory); 64 ~DirectSession() override; 65 66 typedef std::vector<std::pair<string, Tensor>> NamedTensorList; 67 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap; 68 69 ::tensorflow::Status Create(const GraphDef& graph) override; 70 ::tensorflow::Status Extend(const GraphDef& graph) override; 71 ::tensorflow::Status Run(const NamedTensorList& inputs, 72 const std::vector<string>& output_names, 73 const std::vector<string>& target_nodes, 74 std::vector<Tensor>* outputs) override; 75 76 // NOTE: Experimental and subject to change. 77 ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options, 78 const NamedTensorList& inputs, 79 const std::vector<string>& output_names, 80 const std::vector<string>& target_nodes, 81 std::vector<Tensor>* outputs, 82 RunMetadata* run_metadata) override; 83 84 // NOTE: PRunSetup and PRun are added to support partial execution. This 85 // feature is experimental and subject to change. 86 ::tensorflow::Status PRunSetup(const std::vector<string>& input_names, 87 const std::vector<string>& output_names, 88 const std::vector<string>& target_nodes, 89 string* handle) override; 90 ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs, 91 const std::vector<string>& output_names, 92 std::vector<Tensor>* outputs) override; 93 94 // Reset clears 'containers' from the device_mgr of the DirectSession. 95 // If 'containers' is empty, then Reset clears the default container. 96 ::tensorflow::Status Reset(const std::vector<string>& containers); 97 98 ::tensorflow::Status ListDevices( 99 std::vector<DeviceAttributes>* response) override; 100 ::tensorflow::Status Close() override; 101 ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override { 102 *output = device_mgr_.get(); 103 return ::tensorflow::Status::OK(); 104 } 105 106 void ExportCostModels(CostModelManager::CostModelMap* cost_models) { 107 cost_model_manager_.ExportCostModels(cost_models); 108 } 109 110 private: 111 // We create one executor and its dependent library runtime for 112 // every partition. 113 struct PerPartitionExecutorsAndLib { 114 Graph* graph = nullptr; // not owned. 115 Device* device = nullptr; // not owned. 116 FunctionLibraryRuntime* flib = nullptr; // not owned. 117 std::unique_ptr<Executor> executor; 118 }; 119 120 // An ExecutorsAndKeys is created for a given set of feeds/fetches. 121 // 'step_count' is the number of times this graph is executed. 122 // 'graph' is the entire graph being executed. 'name_to_node' 123 // maps node name to node. We keep 'graph' and 'name_to_node' only in 124 // the case of partial runs. Each item in 'items' is the executor for 125 // a partition of the graph bundled with its dependent library runtime. 126 // 'input_keys' are the rendezvous keys for the feeds and 'output_keys' 127 // are rendezvous keys for the fetches. 128 struct ExecutorsAndKeys { 129 ExecutorsAndKeys() : step_count(0) {} 130 131 std::atomic_int_fast64_t step_count; 132 std::unique_ptr<Graph> graph; 133 NameNodeMap name_to_node; 134 std::vector<PerPartitionExecutorsAndLib> items; 135 std::unordered_map<string, size_t> input_name_to_index; 136 std::unordered_map<string, string> input_name_to_rendezvous_key; 137 std::unordered_map<string, size_t> output_name_to_index; 138 std::unordered_map<string, string> output_name_to_rendezvous_key; 139 140 DataTypeVector input_types; 141 DataTypeVector output_types; 142 }; 143 144 // A FunctionInfo object is created for every unique set of feeds/fetches. 145 // This info could be folded into the ExecutorsAndKeys object but we would 146 // like to maintain a deletion order in which the OpKernels (owned by the 147 // executor) should be destroyed first, followed by the resources in the 148 // device and then followed by the function stuff. 149 // TODO(rohanj): Consolidate function library definitions so that we can 150 // instantiate only one ProcFLR and lib_def and make this just a member 151 // variable and not a vector. 152 // 'flib_def' is the function library used. 153 // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per 154 // device. 155 struct FunctionInfo { 156 std::unique_ptr<FunctionLibraryDefinition> flib_def; 157 std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr; 158 }; 159 160 // For each live partial execution, the session maintains a RunState. 161 // 'status' is the current status of this partial execution. 'executor_done' 162 // is "notified" when all executors are done. 'pending_inputs' are the set 163 // of pending feeds and 'pending_outputs' are the set of pending fetches. 164 struct RunState { 165 mutex mu_; 166 Status status GUARDED_BY(mu_); 167 IntraProcessRendezvous* rendez = nullptr; 168 std::unique_ptr<StepStatsCollector> collector; 169 Notification executors_done; 170 std::unordered_map<string, bool> pending_inputs; // true if fed 171 std::unordered_map<string, bool> pending_outputs; // true if fetched 172 TensorStore tensor_store; 173 ScopedStepContainer step_container; 174 175 RunState(int64 step_id, const std::vector<Device*>* devices); 176 177 RunState(const std::vector<string>& pending_input_names, 178 const std::vector<string>& pending_output_names, int64 step_id, 179 const std::vector<Device*>* devices); 180 181 // Returns true if all pending inputs and outputs have been completed. 182 bool PendingDone() const; 183 184 ~RunState(); 185 }; 186 187 struct RunStateArgs { 188 RunStateArgs(const DebugOptions& options) : debug_options(options) {} 189 190 bool is_partial_run = false; 191 string handle; 192 std::unique_ptr<Graph> graph; 193 const DebugOptions& debug_options; 194 }; 195 196 // Initializes the base execution state given the 'graph', 197 // if not already initialized. 198 Status MaybeInitializeExecutionState(const GraphDef& graph, 199 bool* out_already_initialized) 200 EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_); 201 202 // Retrieves an already existing set of executors to run 'inputs' and 203 // 'outputs', or creates and caches them for future use. 204 ::tensorflow::Status GetOrCreateExecutors( 205 gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs, 206 gtl::ArraySlice<string> target_nodes, 207 ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args); 208 209 // Creates several graphs given the existing graph_def_ and the 210 // input feeds and fetches, given 'devices'. The graphs share a common 211 // function library 'flib_def'. 212 ::tensorflow::Status CreateGraphs( 213 const BuildGraphOptions& options, 214 std::unordered_map<string, std::unique_ptr<Graph>>* outputs, 215 std::unique_ptr<FunctionLibraryDefinition>* flib_def, 216 RunStateArgs* run_state_args, DataTypeVector* input_types, 217 DataTypeVector* output_types); 218 219 ::tensorflow::Status ExtendLocked(const GraphDef& graph) 220 EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_); 221 222 ::tensorflow::Status ResourceHandleToInputTensor( 223 const Tensor& resource_tensor, Tensor* retrieved_tensor); 224 225 // Feeds more inputs to the executors, triggering further execution. 226 ::tensorflow::Status SendPRunInputs( 227 const std::vector<std::pair<string, Tensor>>& inputs, 228 const ExecutorsAndKeys* executors_and_keys, 229 IntraProcessRendezvous* rendez); 230 231 // Fetches more outputs from the executors. It waits until the output 232 // tensors are computed. 233 ::tensorflow::Status RecvPRunOutputs( 234 const std::vector<string>& output_names, 235 const ExecutorsAndKeys* executors_and_keys, RunState* run_state, 236 std::vector<Tensor>* outputs); 237 238 // Check if the specified fetches can be computed from the feeds 239 // that we have already provided. 240 ::tensorflow::Status CheckFetch( 241 const std::vector<std::pair<string, Tensor>>& feeds, 242 const std::vector<string>& fetches, 243 const ExecutorsAndKeys* executors_and_keys, const RunState* run_state); 244 245 // Use the appropriate WaitForNotification function based on whether 246 // operation_timeout_in_ms is greater than 0. 247 // 248 // If the timeout expires, the `cm->StartCancel()` will be called. 249 ::tensorflow::Status WaitForNotification(Notification* n, 250 int64 timeout_in_ms); 251 void WaitForNotification(RunState* run_state, CancellationManager* cm, 252 int64 timeout_in_ms); 253 254 ::tensorflow::Status CheckNotClosed() { 255 mutex_lock l(closed_lock_); 256 if (closed_) return errors::Cancelled("Session has been closed."); 257 return ::tensorflow::Status::OK(); 258 } 259 260 ::tensorflow::Status CreateDebuggerState( 261 const DebugOptions& debug_options, int64 session_run_index, 262 int64 executor_step_index, const std::vector<string>& input_names, 263 const std::vector<string>& output_names, 264 const std::vector<string>& target_names, 265 std::unique_ptr<DebuggerStateInterface>* debugger_state); 266 267 ::tensorflow::Status DecorateAndPublishGraphForDebug( 268 const DebugOptions& debug_options, Graph* graph, Device* device); 269 270 const SessionOptions options_; 271 272 // Device structures. 273 const std::unique_ptr<const DeviceMgr> device_mgr_; 274 std::vector<Device*> devices_; // not owned 275 DeviceSet device_set_; 276 277 string session_handle_; 278 bool graph_created_ GUARDED_BY(graph_def_lock_) = false; 279 280 mutex graph_def_lock_; 281 GraphDef graph_def_ GUARDED_BY(graph_def_lock_); 282 283 // The thread-pools to use for running ops, with a bool indicating if the pool 284 // is owned. 285 std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_; 286 287 Status init_error_; // Set to an error if construction failed. 288 289 // If true, blocks until device has finished all queued operations in a step. 290 bool sync_on_finish_ = true; 291 // Schedules 'c' for execution on pool. 292 void SchedClosure(thread::ThreadPool* pool, std::function<void()> c); 293 294 std::vector<std::unique_ptr<FunctionInfo>> functions_ 295 GUARDED_BY(executor_lock_); 296 297 mutex executor_lock_; // protects executors_ 298 // Holds mappings from signature to the executors that process 299 // it. The reason for a level of indirection around mapped_type is 300 // to guarantee address stability. 301 // The map value is a shared_ptr since multiple map keys can point to the 302 // same ExecutorsAndKey object. 303 std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_ 304 GUARDED_BY(executor_lock_); 305 306 // Holds mappings from handle to partial run state. 307 std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_ 308 GUARDED_BY(executor_lock_); 309 310 // This holds all the tensors that are currently alive in the session. 311 SessionState session_state_; 312 313 DirectSessionFactory* const factory_; // not owned 314 CancellationManager* cancellation_manager_; 315 316 // Map of placed stateful nodes, i.e. nodes for which is_stateful() 317 // is true, such as "params" and "queue" nodes. Once placed these 318 // nodes can not be moved to a different device. Maps node names to 319 // device names. 320 std::unordered_map<string, string> stateful_placements_ 321 GUARDED_BY(graph_def_lock_); 322 323 // Execution_state; used when placing the entire graph. 324 std::unique_ptr<GraphExecutionState> execution_state_ 325 GUARDED_BY(graph_def_lock_); 326 327 // The function library, before any rewrites or optimizations have been 328 // performed. In particular, CreateGraphs() may need to modify the function 329 // library; it copies and modifies the function library. 330 std::unique_ptr<FunctionLibraryDefinition> flib_def_; 331 332 // true if the Session has been Closed. 333 mutex closed_lock_; 334 bool closed_ GUARDED_BY(closed_lock_) = false; 335 336 // For generating unique names for this session instance. 337 std::atomic<int64> edge_name_counter_ = {0}; 338 std::atomic<int64> handle_name_counter_ = {0}; 339 340 // For generating step ids that are unique across all sessions. 341 static std::atomic_int_fast64_t step_id_counter_; 342 343 // Global timeout for all blocking operations in this session. 344 const int64 operation_timeout_in_ms_ = 0; 345 346 // Manages all the cost models for the graphs executed in this session. 347 CostModelManager cost_model_manager_; 348 349 Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr; 350 351 TF_DISALLOW_COPY_AND_ASSIGN(DirectSession); 352 353 // EXPERIMENTAL: debugger (tfdbg) related 354 friend class DebugGateway; 355 }; 356 357 } // end namespace tensorflow 358 359 #endif // TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_ 360