/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_

#include <functional>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/debug.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

class ExecutorOpts;
class StepStatsCollector;
class RendezvousMgrInterface;
class DeviceMgr;
struct WorkerSession;

// GraphMgr keeps track of a set of graphs that are registered with a
48 // TensorFlow worker. Each registered graph is identified by a handle 49 // that is generated by GraphMgr and returned to the caller. 50 // 51 // After a successful registration, the caller executes a graph using 52 // the graph handle. Each execution is distinguished from others by a 53 // caller generated global unique id "step_id". Multiple executions 54 // can use the same graph concurrently and independently as long as 55 // "step_id" used are different. 56 // 57 // Multiple threads can call GraphMgr methods concurrently. 58 // 59 // E.g., 60 // GraphMgr gmgr(worker_env); 61 // string handle; 62 // TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b }, 63 // &handle)); 64 // GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) }, 65 // { "b", Tensor({3, 4}) } }; 66 // GraphMgr::NamedTensors out = { { "c", Tensor() } }; 67 // TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out)); 68 // EXPECT_EQ(out["c"], Tensor({4, 6})); 69 class GraphMgr { 70 public: 71 explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr); 72 ~GraphMgr(); 73 74 // Registers a graph. Fills in "handle". The registered graph retains a 75 // reference to cluster_flr to do cross process function calls. 76 Status Register(const string& session, const GraphDef& gdef, 77 const GraphOptions& graph_options, 78 const DebugOptions& debug_options, 79 DistributedFunctionLibraryRuntime* cluster_flr, 80 string* handle); 81 82 // Executes one step of a registered graph "handle". 83 // 84 // If "out" is not nullptr, "out" specifies all keys the execution 85 // should receive upon finish. 
86 typedef std::map<string, Tensor> NamedTensors; 87 typedef std::function<void(const Status&)> StatusCallback; 88 void ExecuteAsync(const string& handle, const int64 step_id, 89 WorkerSession* session, const ExecutorOpts& opts, 90 StepStatsCollector* collector, 91 MutableRunGraphResponseWrapper* response, 92 CancellationManager* cancellation_manager, 93 const NamedTensors& in, StatusCallback done); 94 95 Status SendInputs(const int64 step_id, const NamedTensors& in); 96 Status RecvOutputs(const int64 step_id, NamedTensors* out); 97 void RecvOutputsAsync(const int64 step_id, NamedTensors* out, 98 StatusCallback done); 99 100 // Deregisters a graph. 101 Status Deregister(const string& handle); 102 103 // Deregister all graphs. 104 Status DeregisterAll(); 105 106 private: 107 typedef GraphMgr ME; 108 109 struct ExecutionUnit { 110 Graph* graph = nullptr; // not owned. 111 Device* device = nullptr; // not owned. 112 Executor* root = nullptr; // not owned. 113 FunctionLibraryRuntime* lib = nullptr; // not owned. 114 // Build the cost model if this value is strictly positive. 115 int64 build_cost_model = 0; 116 }; 117 118 struct Item : public core::RefCounted { 119 // TODO(zhifengc): Keeps a copy of the original graph if the need arises. 120 // TODO(zhifengc): Stats, updated by multiple runs potentially. 121 // TODO(zhifengc): Dup-detection. Ensure step_id only run once. 122 ~Item() override; 123 124 // Session handle. 125 string session; 126 127 // Graph handle. 128 string handle; 129 130 std::unique_ptr<FunctionLibraryDefinition> lib_def; 131 // Owns the FunctionLibraryRuntime objects needed to execute functions, one 132 // per device. 133 std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr; 134 // A graph is partitioned over multiple devices. Each partition 135 // has a root executor which may call into the runtime library. 136 std::vector<ExecutionUnit> units; 137 138 // Used to deregister a cost model when cost model is required in graph 139 // manager. 
140 GraphMgr* graph_mgr; 141 }; 142 143 const WorkerEnv* worker_env_; // Not owned. 144 DeviceMgr* device_mgr_; 145 146 CostModelManager cost_model_manager_; 147 148 // Owned. 149 mutex mu_; 150 int64 next_id_ GUARDED_BY(mu_) = 0; 151 152 // If true, blocks until device has finished all queued operations in a step. 153 bool sync_on_finish_ = true; 154 155 // Table mapping graph handles to registered graphs. 156 // 157 // TODO(zhifengc): If the client does not call Deregister, we'll 158 // lose memory over time. We should implement a timeout-based 159 // mechanism to gc these graphs. 160 std::unordered_map<string, Item*> table_; 161 162 void StartParallelExecutors(const string& handle, int64 step_id, Item* item, 163 Rendezvous* rendezvous, 164 StepStatsCollector* collector, 165 CostGraphDef* cost_graph, 166 CancellationManager* cancellation_manager, 167 StatusCallback done); 168 169 // Don't attempt to process cost models unless explicitly requested for at 170 // least one of the items. 171 bool skip_cost_models_ = true; 172 173 void BuildCostModel(Item* item, StepStatsCollector* collector, 174 CostGraphDef* cost_graph); 175 176 Status InitItem(const string& session, const GraphDef& gdef, 177 const GraphOptions& graph_options, 178 const DebugOptions& debug_options, 179 DistributedFunctionLibraryRuntime* cluster_flr, Item* item); 180 181 Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options, 182 Graph* graph, Device* device); 183 184 TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr); 185 }; 186 187 } // end namespace tensorflow 188 189 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_ 190