/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_

#include <functional>
#include <map>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/debug.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

class ExecutorOpts;
class StepStatsCollector;
class RendezvousMgrInterface;
class DeviceMgr;
struct WorkerSession;

// GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle
// that is generated by GraphMgr and returned to the caller.
//
// After a successful registration, the caller executes a graph using
// the graph handle. Each execution is distinguished from others by a
// caller-generated, globally unique id "step_id". Multiple executions
// can use the same graph concurrently and independently as long as
// the "step_id"s used are different.
//
// Multiple threads can call GraphMgr methods concurrently.
//
// E.g.,
//   GraphMgr gmgr(worker_env);
//   string handle;
//   TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b },
//                             &handle));
//   GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
//                                 { "b", Tensor({3, 4}) } };
//   GraphMgr::NamedTensors out = { { "c", Tensor() } };
//   TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
//   EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr {
 public:
  explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr);
  ~GraphMgr();

  // Registers a graph. Fills in "handle". The registered graph retains a
  // reference to cluster_flr to perform cross-process function calls.
  Status Register(const string& session, const GraphDef& gdef,
                  const GraphOptions& graph_options,
                  const DebugOptions& debug_options,
                  DistributedFunctionLibraryRuntime* cluster_flr,
                  string* handle);
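
  // A minimal registration sketch (assumes "gdef", "graph_options",
  // "debug_options", and "cluster_flr" have been built by the caller; the
  // names are illustrative):
  //
  //   string handle;
  //   TF_CHECK_OK(gmgr.Register("session_0", gdef, graph_options,
  //                             debug_options, cluster_flr, &handle));
  //   // "handle" now identifies the registered graph in later calls.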

  // Executes one step of a registered graph "handle".
  //
  // If "out" is not nullptr, "out" specifies the keys whose tensors the
  // execution should receive when it finishes.
  typedef std::map<string, Tensor> NamedTensors;
  typedef std::function<void(const Status&)> StatusCallback;
  void ExecuteAsync(const string& handle, const int64 step_id,
                    WorkerSession* session, const ExecutorOpts& opts,
                    StepStatsCollector* collector,
                    MutableRunGraphResponseWrapper* response,
                    CancellationManager* cancellation_manager,
                    const NamedTensors& in, StatusCallback done);
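
  // A sketch of driving one asynchronous step and blocking on its completion
  // (assumes "session", "opts", "collector", "response", and "cm" are
  // supplied by the worker service; this is illustrative, not part of the
  // API contract):
  //
  //   Notification n;
  //   Status status;
  //   gmgr.ExecuteAsync(handle, step_id, session, opts, collector, response,
  //                     cm, in, [&n, &status](const Status& s) {
  //                       status = s;
  //                       n.Notify();
  //                     });
  //   n.WaitForNotification();
  //   TF_CHECK_OK(status);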

  Status SendInputs(const int64 step_id, const NamedTensors& in);
  Status RecvOutputs(const int64 step_id, NamedTensors* out);
  void RecvOutputsAsync(const int64 step_id, NamedTensors* out,
                        StatusCallback done);
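
  // For example, a step's inputs can be fed and its outputs fetched out of
  // band (a sketch; "ta" and "tb" stand in for real input tensors):
  //
  //   TF_CHECK_OK(gmgr.SendInputs(step_id, {{"a", ta}, {"b", tb}}));
  //   GraphMgr::NamedTensors out = {{"c", Tensor()}};
  //   TF_CHECK_OK(gmgr.RecvOutputs(step_id, &out));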

  // Deregisters a graph.
  Status Deregister(const string& handle);

  // Deregisters all graphs.
  Status DeregisterAll();
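
  // E.g., once no further steps will use a graph (a sketch):
  //
  //   TF_CHECK_OK(gmgr.Deregister(handle));  // releases one registered graph
  //   TF_CHECK_OK(gmgr.DeregisterAll());     // or releases all of them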

 private:
  typedef GraphMgr ME;

  struct ExecutionUnit {
    Graph* graph = nullptr;                 // not owned.
    Device* device = nullptr;               // not owned.
    Executor* root = nullptr;               // not owned.
    FunctionLibraryRuntime* lib = nullptr;  // not owned.
    // Build the cost model if this value is strictly positive.
    int64 build_cost_model = 0;
  };

  struct Item : public core::RefCounted {
    // TODO(zhifengc): Keep a copy of the original graph if the need arises.
    // TODO(zhifengc): Stats, updated by multiple runs potentially.
    // TODO(zhifengc): Dup-detection. Ensure each step_id runs only once.
    ~Item() override;

    // Session handle.
    string session;

    // Graph handle.
    string handle;

    std::unique_ptr<FunctionLibraryDefinition> lib_def;
    // Owns the FunctionLibraryRuntime objects needed to execute functions, one
    // per device.
    std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
    // A graph is partitioned over multiple devices. Each partition
    // has a root executor which may call into the runtime library.
    std::vector<ExecutionUnit> units;

    // Back-pointer to the owning GraphMgr, used to deregister this item's
    // cost model when cost models are enabled in the graph manager.
    GraphMgr* graph_mgr;
  };

  const WorkerEnv* worker_env_;  // Not owned.
  DeviceMgr* device_mgr_;

  CostModelManager cost_model_manager_;

  // Owned.
  mutex mu_;
  int64 next_id_ GUARDED_BY(mu_) = 0;

  // If true, blocks until the device has finished all queued operations in a
  // step.
  bool sync_on_finish_ = true;

  // Table mapping graph handles to registered graphs.
  //
  // TODO(zhifengc): If the client does not call Deregister, we'll
  // leak memory over time. We should implement a timeout-based
  // mechanism to GC these graphs.
  std::unordered_map<string, Item*> table_;

  void StartParallelExecutors(const string& handle, int64 step_id, Item* item,
                              Rendezvous* rendezvous,
                              StepStatsCollector* collector,
                              CostGraphDef* cost_graph,
                              CancellationManager* cancellation_manager,
                              StatusCallback done);

  // Don't attempt to process cost models unless explicitly requested for at
  // least one of the items.
  bool skip_cost_models_ = true;

  void BuildCostModel(Item* item, StepStatsCollector* collector,
                      CostGraphDef* cost_graph);

  Status InitItem(const string& session, const GraphDef& gdef,
                  const GraphOptions& graph_options,
                  const DebugOptions& debug_options,
                  DistributedFunctionLibraryRuntime* cluster_flr, Item* item);

  Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
                                         Graph* graph, Device* device);

  TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_