/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_

#include <atomic>
#include <vector>

#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

class Device;
struct MasterEnv;

// A session encapsulates a graph computation (resource allocation,
// placement, execution, etc.).
class MasterSession : public core::RefCounted {
 public:
  // This session encapsulates the computation for a single graph.
  //
  // The session places nodes on devices in "remote_devs" and executes
  // operations on these devices.
  //
  // The caller takes ownership of all remote devices.
  MasterSession(
      const SessionOptions& options, const MasterEnv* env,
      std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
      std::unique_ptr<WorkerCacheInterface> worker_cache,
      std::unique_ptr<DeviceSet> device_set,
      StatsPublisherFactory stats_publisher_factory);

  // Initialize the MasterSession for "def".  Must be called before Extend(),
  // Run(), or Close().
  //
  // After this method returns, `def` will no longer be valid.
  Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options);
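
  // Rough usage sketch (hedged): the construction of the SessionOptions,
  // MasterEnv, device list, worker cache, and GraphDef is assumed to be
  // handled by the surrounding Master implementation, and names such as
  // `graph_def` and `worker_cache_options` below are placeholders.
  //
  //   MasterSession* session = new MasterSession(
  //       options, env, std::move(remote_devs), std::move(worker_cache),
  //       std::move(device_set), stats_publisher_factory);
  //   TF_RETURN_IF_ERROR(session->Create(&graph_def, worker_cache_options));
  //   // ... Extend()/PartialRunSetup()/Run() calls ...
  //   TF_RETURN_IF_ERROR(session->Close());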

  // Returns the session handle.
  const string& handle() const { return handle_; }

  // Returns the last access time (the number of microseconds since
  // some fixed point in time) of this session.
  uint64 last_access_time_usec() const { return last_access_time_usec_.load(); }

  // Attempt to extend the graph according to the given "req".
  // (See master.proto for details of valid extensions.)
  //
  // PRECONDITION: The current version of this session's graph
  //   is "req->current_graph_version".
  //
  // POSTCONDITION: The current version of this session's graph
  //   is "resp->new_graph_version".
  //
  // Extend() may block the caller thread for a long time.
  Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);
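
  // Hedged sketch of the Extend() version handshake above. Field names follow
  // master.proto; `session` and `last_known_version` are caller-side
  // placeholders, and `nodes_to_add` is a GraphDef holding only the new nodes.
  //
  //   ExtendSessionRequest req;
  //   req.set_session_handle(session->handle());
  //   req.set_current_graph_version(last_known_version);
  //   *req.mutable_graph_def() = nodes_to_add;
  //   ExtendSessionResponse resp;
  //   TF_RETURN_IF_ERROR(session->Extend(&req, &resp));
  //   last_known_version = resp.new_graph_version();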

  // Sets up a partial run call.
  Status PartialRunSetup(const PartialRunSetupRequest* req,
                         PartialRunSetupResponse* resp);

  // Run one step.
  Status Run(CallOptions* opts, const RunStepRequestWrapper& req,
             MutableRunStepResponseWrapper* resp);
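
  // Sketch only: assumes a RunStepRequest proto `proto_req` whose feeds and
  // fetches are already populated; the wrapper types below come from
  // message_wrappers.h.
  //
  //   ProtoRunStepRequest wrapped_req(&proto_req);
  //   RunStepResponse proto_resp;
  //   NonOwnedProtoRunStepResponse wrapped_resp(&proto_resp);
  //   CallOptions call_opts;
  //   Status s = session->Run(&call_opts, wrapped_req, &wrapped_resp);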

  Status ListDevices(ListDevicesResponse* resp) const;

  // Close this session and delete "*this". Returns OK if all known
  // states are cleaned up successfully.
  //
  // Close() may block the caller thread for a long time.
  Status Close();

  // Close this session and release a reference on "*this".
  //
  // Note that, unlike Close(), this method does not block on the
  // completion of all work.
  void GarbageCollect();

 private:
  SessionOptions session_opts_;

  // Not owned.
  const MasterEnv* env_;

  // The opaque session handle.
  const string handle_;

  std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;

  // The optional session-specific worker cluster.
  // TODO(saeta): Convert to std::optional when available.
  const std::unique_ptr<WorkerCacheInterface> worker_cache_;
  // Retrieves either worker_cache_ or env_->worker_cache, as appropriate.
  WorkerCacheInterface* get_worker_cache() const;

  // The device set used by this session.
  std::unique_ptr<DeviceSet> devices_;

  StatsPublisherFactory stats_publisher_factory_;

  std::atomic_ulong last_access_time_usec_;

  std::atomic<int64> partial_run_handle_counter_ = {0};

  mutex mu_;
  std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_);
  int64 graph_version_;

  // We keep a map from a signature of a run request to the
  // ReffedClientGraph that can execute it.  We keep up to one old copy
  // of each ReffedClientGraph around because if it gets deallocated
  // before a new substitute has been created, Variables can go out of
  // scope and lose their state.
  class ReffedClientGraph;
  typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
  RCGMap run_graphs_ GUARDED_BY(mu_);
  RCGMap partial_run_graphs_ GUARDED_BY(mu_);

  struct PerStepState {
    bool collect_costs = false;
    bool collect_timeline = false;
    bool collect_rpcs = false;
    bool collect_partition_graphs = false;
    bool report_tensor_allocations_upon_oom = false;
    Microseconds start_micros = Microseconds(0);
    Microseconds end_micros = Microseconds(0);
    std::vector<StepStats> step_stats;  // per partition
    StepStats rpc_stats;                // for RPC layer
    CostGraphDef cost_graph;
  };

  struct RunState {
    std::unordered_map<string, bool> pending_inputs;   // true if fed
    std::unordered_map<string, bool> pending_outputs;  // true if fetched
    ReffedClientGraph* rcg = nullptr;
    uint64 step_id;
    int64 count = 0;
    PerStepState pss;
    std::unique_ptr<ProfileHandler> ph;
    bool step_started = false;

    RunState(const std::vector<string>& input_names,
             const std::vector<string>& output_names, ReffedClientGraph* rcg,
             const uint64 step_id, const int64 count);

    bool PendingDone() const;

    ~RunState();
  };
  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
      GUARDED_BY(mu_);

  // Active RunStep calls.
  condition_variable num_running_is_zero_;
  int32 num_running_ GUARDED_BY(mu_) = 0;

  bool closed_ GUARDED_BY(mu_) = false;
  bool garbage_collected_ GUARDED_BY(mu_) = false;

  std::unordered_map<uint64, int64> subgraph_execution_counts_ GUARDED_BY(mu_);

  // We need to ensure that certain nodes added (e.g., send and recv
  // nodes) are unique across all sub-graphs within this session.
  int64 next_node_id_ GUARDED_BY(mu_) = 0;

  // Used to cancel running steps on Close().
  CancellationManager cancellation_manager_;

  // Private dtor. The client must call Close().
  virtual ~MasterSession();

  // Creates sessions on all workers.
  //
  // If this session is operating using the new ClusterSpec propagation
  // behavior, call this method to propagate the cluster membership to all
  // workers.
  Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);

  // TODO(b/36574172): Always use Create/DeleteWorkerSession.
  bool should_delete_worker_sessions_ = false;
  Status DeleteWorkerSessions();

  Status StartStep(const BuildGraphOptions& opts, int64* count,
                   ReffedClientGraph** graph, bool is_partial);
  void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                      RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status DoRunWithLocalExecution(CallOptions* opts,
                                 const RunStepRequestWrapper& req,
                                 MutableRunStepResponseWrapper* resp);
  Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
                      MutableRunStepResponseWrapper* resp);
  void MarkRunCompletion();
  void UpdateLastAccessTime();

  Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);

  Status CreateDebuggerState(
      const DebugOptions& debug_options, const RunStepRequestWrapper& req,
      int64 rcg_execution_count,
      std::unique_ptr<DebuggerStateInterface>* debugger_state);

  TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_