Home | History | Annotate | Download | only in common_runtime
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
     17 #define TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
     18 
     19 #include <atomic>
     20 #include <memory>
     21 #include <string>
     22 #include <unordered_map>
     23 #include <unordered_set>
     24 #include <vector>
     25 
     26 #include "tensorflow/core/common_runtime/costmodel_manager.h"
     27 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
     28 #include "tensorflow/core/common_runtime/device_mgr.h"
     29 #include "tensorflow/core/common_runtime/device_set.h"
     30 #include "tensorflow/core/common_runtime/executor.h"
     31 #include "tensorflow/core/common_runtime/graph_execution_state.h"
     32 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
     33 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
     34 #include "tensorflow/core/common_runtime/session_factory.h"
     35 #include "tensorflow/core/framework/cancellation.h"
     36 #include "tensorflow/core/framework/graph.pb.h"
     37 #include "tensorflow/core/framework/session_state.h"
     38 #include "tensorflow/core/framework/tensor.h"
     39 #include "tensorflow/core/lib/core/errors.h"
     40 #include "tensorflow/core/lib/core/status.h"
     41 #include "tensorflow/core/platform/macros.h"
     42 #include "tensorflow/core/platform/mutex.h"
     43 #include "tensorflow/core/platform/thread_annotations.h"
     44 #include "tensorflow/core/platform/types.h"
     45 #include "tensorflow/core/public/session.h"
     46 
     47 namespace tensorflow {
     48 
     49 class CostModel;
     50 class DebugGateway;
     51 class Device;
     52 class DirectSessionFactory;
     53 
     54 class DirectSession : public Session {
     55  public:
     56   typedef std::function<void(Session*)> CloseCallback;
     57 
     58   // Takes ownership of 'device_mgr'.
     59   // 'factory' is used to unregister the DirectSession with 'factory' when its
     60   // closed. This ensures that Reset requests from the 'factory' don't get sent
     61   // to sessions that are already closed.
     62   DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr,
     63                 DirectSessionFactory* factory);
     64   ~DirectSession() override;
     65 
     66   typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
     67   typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;
     68 
     69   ::tensorflow::Status Create(const GraphDef& graph) override;
     70   ::tensorflow::Status Extend(const GraphDef& graph) override;
     71   ::tensorflow::Status Run(const NamedTensorList& inputs,
     72                            const std::vector<string>& output_names,
     73                            const std::vector<string>& target_nodes,
     74                            std::vector<Tensor>* outputs) override;
     75 
     76   // NOTE: Experimental and subject to change.
     77   ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
     78                            const NamedTensorList& inputs,
     79                            const std::vector<string>& output_names,
     80                            const std::vector<string>& target_nodes,
     81                            std::vector<Tensor>* outputs,
     82                            RunMetadata* run_metadata) override;
     83 
     84   // NOTE: PRunSetup and PRun are added to support partial execution. This
     85   // feature is experimental and subject to change.
     86   ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
     87                                  const std::vector<string>& output_names,
     88                                  const std::vector<string>& target_nodes,
     89                                  string* handle) override;
     90   ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs,
     91                             const std::vector<string>& output_names,
     92                             std::vector<Tensor>* outputs) override;
     93 
     94   // Reset clears 'containers' from the device_mgr of the DirectSession.
     95   // If 'containers' is empty, then Reset clears the default container.
     96   ::tensorflow::Status Reset(const std::vector<string>& containers);
     97 
     98   ::tensorflow::Status ListDevices(
     99       std::vector<DeviceAttributes>* response) override;
    100   ::tensorflow::Status Close() override;
    101   ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
    102     *output = device_mgr_.get();
    103     return ::tensorflow::Status::OK();
    104   }
    105 
    106   void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
    107     cost_model_manager_.ExportCostModels(cost_models);
    108   }
    109 
    110  private:
    111   // We create one executor and its dependent library runtime for
    112   // every partition.
    113   struct PerPartitionExecutorsAndLib {
    114     Graph* graph = nullptr;                  // not owned.
    115     Device* device = nullptr;                // not owned.
    116     FunctionLibraryRuntime* flib = nullptr;  // not owned.
    117     std::unique_ptr<Executor> executor;
    118   };
    119 
    120   // An ExecutorsAndKeys is created for a given set of feeds/fetches.
    121   // 'step_count' is the number of times this graph is executed.
    122   // 'graph' is the entire graph being executed. 'name_to_node'
    123   // maps node name to node. We keep 'graph' and 'name_to_node' only in
    124   // the case of partial runs. Each item in 'items' is the executor for
    125   // a partition of the graph bundled with its dependent library runtime.
    126   // 'input_keys' are the rendezvous keys for the feeds and 'output_keys'
    127   // are rendezvous keys for the fetches.
    128   struct ExecutorsAndKeys {
    129     ExecutorsAndKeys() : step_count(0) {}
    130 
    131     std::atomic_int_fast64_t step_count;
    132     std::unique_ptr<Graph> graph;
    133     NameNodeMap name_to_node;
    134     std::vector<PerPartitionExecutorsAndLib> items;
    135     std::unordered_map<string, size_t> input_name_to_index;
    136     std::unordered_map<string, string> input_name_to_rendezvous_key;
    137     std::unordered_map<string, size_t> output_name_to_index;
    138     std::unordered_map<string, string> output_name_to_rendezvous_key;
    139 
    140     DataTypeVector input_types;
    141     DataTypeVector output_types;
    142   };
    143 
    144   // A FunctionInfo object is created for every unique set of feeds/fetches.
    145   // This info could be folded into the ExecutorsAndKeys object but we would
    146   // like to maintain a deletion order in which the OpKernels (owned by the
    147   // executor) should be destroyed first, followed by the resources in the
    148   // device and then followed by the function stuff.
    149   // TODO(rohanj): Consolidate function library definitions so that we can
    150   // instantiate only one ProcFLR and lib_def and make this just a member
    151   // variable and not a vector.
    152   // 'flib_def' is the function library used.
    153   // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
    154   // device.
    155   struct FunctionInfo {
    156     std::unique_ptr<FunctionLibraryDefinition> flib_def;
    157     std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
    158   };
    159 
    160   // For each live partial execution, the session maintains a RunState.
    161   // 'status' is the current status of this partial execution. 'executor_done'
    162   // is "notified" when all executors are done. 'pending_inputs' are the set
    163   // of pending feeds and 'pending_outputs' are the set of pending fetches.
    164   struct RunState {
    165     mutex mu_;
    166     Status status GUARDED_BY(mu_);
    167     IntraProcessRendezvous* rendez = nullptr;
    168     std::unique_ptr<StepStatsCollector> collector;
    169     Notification executors_done;
    170     std::unordered_map<string, bool> pending_inputs;   // true if fed
    171     std::unordered_map<string, bool> pending_outputs;  // true if fetched
    172     TensorStore tensor_store;
    173     ScopedStepContainer step_container;
    174 
    175     RunState(int64 step_id, const std::vector<Device*>* devices);
    176 
    177     RunState(const std::vector<string>& pending_input_names,
    178              const std::vector<string>& pending_output_names, int64 step_id,
    179              const std::vector<Device*>* devices);
    180 
    181     // Returns true if all pending inputs and outputs have been completed.
    182     bool PendingDone() const;
    183 
    184     ~RunState();
    185   };
    186 
    187   struct RunStateArgs {
    188     RunStateArgs(const DebugOptions& options) : debug_options(options) {}
    189 
    190     bool is_partial_run = false;
    191     string handle;
    192     std::unique_ptr<Graph> graph;
    193     const DebugOptions& debug_options;
    194   };
    195 
    196   // Initializes the base execution state given the 'graph',
    197   // if not already initialized.
    198   Status MaybeInitializeExecutionState(const GraphDef& graph,
    199                                        bool* out_already_initialized)
    200       EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
    201 
    202   // Retrieves an already existing set of executors to run 'inputs' and
    203   // 'outputs', or creates and caches them for future use.
    204   ::tensorflow::Status GetOrCreateExecutors(
    205       gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    206       gtl::ArraySlice<string> target_nodes,
    207       ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
    208 
    209   // Creates several graphs given the existing graph_def_ and the
    210   // input feeds and fetches, given 'devices'. The graphs share a common
    211   // function library 'flib_def'.
    212   ::tensorflow::Status CreateGraphs(
    213       const BuildGraphOptions& options,
    214       std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
    215       std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    216       RunStateArgs* run_state_args, DataTypeVector* input_types,
    217       DataTypeVector* output_types);
    218 
    219   ::tensorflow::Status ExtendLocked(const GraphDef& graph)
    220       EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
    221 
    222   ::tensorflow::Status ResourceHandleToInputTensor(
    223       const Tensor& resource_tensor, Tensor* retrieved_tensor);
    224 
    225   // Feeds more inputs to the executors, triggering further execution.
    226   ::tensorflow::Status SendPRunInputs(
    227       const std::vector<std::pair<string, Tensor>>& inputs,
    228       const ExecutorsAndKeys* executors_and_keys,
    229       IntraProcessRendezvous* rendez);
    230 
    231   // Fetches more outputs from the executors. It waits until the output
    232   // tensors are computed.
    233   ::tensorflow::Status RecvPRunOutputs(
    234       const std::vector<string>& output_names,
    235       const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
    236       std::vector<Tensor>* outputs);
    237 
    238   // Check if the specified fetches can be computed from the feeds
    239   // that we have already provided.
    240   ::tensorflow::Status CheckFetch(
    241       const std::vector<std::pair<string, Tensor>>& feeds,
    242       const std::vector<string>& fetches,
    243       const ExecutorsAndKeys* executors_and_keys, const RunState* run_state);
    244 
    245   // Use the appropriate WaitForNotification function based on whether
    246   // operation_timeout_in_ms is greater than 0.
    247   //
    248   // If the timeout expires, the `cm->StartCancel()` will be called.
    249   ::tensorflow::Status WaitForNotification(Notification* n,
    250                                            int64 timeout_in_ms);
    251   void WaitForNotification(RunState* run_state, CancellationManager* cm,
    252                            int64 timeout_in_ms);
    253 
    254   ::tensorflow::Status CheckNotClosed() {
    255     mutex_lock l(closed_lock_);
    256     if (closed_) return errors::Cancelled("Session has been closed.");
    257     return ::tensorflow::Status::OK();
    258   }
    259 
    260   ::tensorflow::Status CreateDebuggerState(
    261       const DebugOptions& debug_options, int64 session_run_index,
    262       int64 executor_step_index, const std::vector<string>& input_names,
    263       const std::vector<string>& output_names,
    264       const std::vector<string>& target_names,
    265       std::unique_ptr<DebuggerStateInterface>* debugger_state);
    266 
    267   ::tensorflow::Status DecorateAndPublishGraphForDebug(
    268       const DebugOptions& debug_options, Graph* graph, Device* device);
    269 
    270   const SessionOptions options_;
    271 
    272   // Device structures.
    273   const std::unique_ptr<const DeviceMgr> device_mgr_;
    274   std::vector<Device*> devices_;  // not owned
    275   DeviceSet device_set_;
    276 
    277   string session_handle_;
    278   bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
    279 
    280   mutex graph_def_lock_;
    281   GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
    282 
    283   // The thread-pools to use for running ops, with a bool indicating if the pool
    284   // is owned.
    285   std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_;
    286 
    287   Status init_error_;  // Set to an error if construction failed.
    288 
    289   // If true, blocks until device has finished all queued operations in a step.
    290   bool sync_on_finish_ = true;
    291   // Schedules 'c' for execution on pool.
    292   void SchedClosure(thread::ThreadPool* pool, std::function<void()> c);
    293 
    294   std::vector<std::unique_ptr<FunctionInfo>> functions_
    295       GUARDED_BY(executor_lock_);
    296 
    297   mutex executor_lock_;  // protects executors_
    298   // Holds mappings from signature to the executors that process
    299   // it. The reason for a level of indirection around mapped_type is
    300   // to guarantee address stability.
    301   // The map value is a shared_ptr since multiple map keys can point to the
    302   // same ExecutorsAndKey object.
    303   std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
    304       GUARDED_BY(executor_lock_);
    305 
    306   // Holds mappings from handle to partial run state.
    307   std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
    308       GUARDED_BY(executor_lock_);
    309 
    310   // This holds all the tensors that are currently alive in the session.
    311   SessionState session_state_;
    312 
    313   DirectSessionFactory* const factory_;  // not owned
    314   CancellationManager* cancellation_manager_;
    315 
    316   // Map of placed stateful nodes, i.e. nodes for which is_stateful()
    317   // is true, such as "params" and "queue" nodes.  Once placed these
    318   // nodes can not be moved to a different device.  Maps node names to
    319   // device names.
    320   std::unordered_map<string, string> stateful_placements_
    321       GUARDED_BY(graph_def_lock_);
    322 
    323   // Execution_state; used when placing the entire graph.
    324   std::unique_ptr<GraphExecutionState> execution_state_
    325       GUARDED_BY(graph_def_lock_);
    326 
    327   // The function library, before any rewrites or optimizations have been
    328   // performed. In particular, CreateGraphs() may need to modify the function
    329   // library; it copies and modifies the function library.
    330   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
    331 
    332   // true if the Session has been Closed.
    333   mutex closed_lock_;
    334   bool closed_ GUARDED_BY(closed_lock_) = false;
    335 
    336   // For generating unique names for this session instance.
    337   std::atomic<int64> edge_name_counter_ = {0};
    338   std::atomic<int64> handle_name_counter_ = {0};
    339 
    340   // For generating step ids that are unique across all sessions.
    341   static std::atomic_int_fast64_t step_id_counter_;
    342 
    343   // Global timeout for all blocking operations in this session.
    344   const int64 operation_timeout_in_ms_ = 0;
    345 
    346   // Manages all the cost models for the graphs executed in this session.
    347   CostModelManager cost_model_manager_;
    348 
    349   Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
    350 
    351   TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
    352 
    353   // EXPERIMENTAL: debugger (tfdbg) related
    354   friend class DebugGateway;
    355 };
    356 
    357 }  // end namespace tensorflow
    358 
    359 #endif  // TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
    360