/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/platform/env.h"
#ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#endif
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Note: there's a copy of this enum in eager/c_api.h. The two should be kept
// in sync.
enum ContextDevicePlacementPolicy {
  // Running operations with input tensors on the wrong device will fail.
  DEVICE_PLACEMENT_EXPLICIT = 0,
  // Copy the tensor to the right device, but log a warning.
  DEVICE_PLACEMENT_WARN = 1,
  // Silently copy the tensor, which has a performance cost since the
  // operation will be blocked until the copy completes. This is the default
  // policy.
  DEVICE_PLACEMENT_SILENT = 2,
  // Silently copy int32 tensors, but not tensors of other dtypes.
  DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
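
// Usage sketch (illustrative only, not part of this header's API): a caller
// that prefers a warning over a hard failure for mis-placed inputs could pick
// DEVICE_PLACEMENT_WARN, either as the context-wide default at construction
// time or per thread via EagerContext::SetThreadLocalDevicePlacementPolicy():
//
//   ctx->SetThreadLocalDevicePlacementPolicy(DEVICE_PLACEMENT_WARN);
//   // ... run ops whose inputs may live on another device ...
//   ctx->SetThreadLocalDevicePlacementPolicy(DEVICE_PLACEMENT_SILENT);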

class RunMetadataListener {
 public:
  virtual ~RunMetadataListener() {}
  virtual void BeforeClearRunMetadata() = 0;
};
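
// Sketch of a hypothetical listener (any concrete subclass lives outside this
// header): BeforeClearRunMetadata() gives the listener a chance to snapshot
// the accumulated RunMetadata before the context clears it.
//
//   class SnapshottingListener : public RunMetadataListener {
//    public:
//     void BeforeClearRunMetadata() override {
//       // Copy the context's RunMetadataProto() aside before it is reset.
//     }
//   };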

class EagerContext {
 public:
  // TODO: remove this constructor once we migrate all callers to the next one.
  EagerContext(const SessionOptions& opts,
               ContextDevicePlacementPolicy default_policy, bool async,
               std::unique_ptr<const DeviceMgr> device_mgr,
               Rendezvous* rendezvous);

  EagerContext(const SessionOptions& opts,
               ContextDevicePlacementPolicy default_policy, bool async,
               const DeviceMgr* device_mgr, bool device_mgr_owned,
               Rendezvous* rendezvous);
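
  // Construction sketch (illustrative; assumes the caller already built a
  // DeviceMgr, and uses the in-process rendezvous from rendezvous_mgr.h):
  //
  //   SessionOptions opts;
  //   std::unique_ptr<const DeviceMgr> devices = ...;  // caller-provided
  //   Rendezvous* r = new IntraProcessRendezvous(devices.get());
  //   EagerContext ctx(opts, DEVICE_PLACEMENT_SILENT, /*async=*/false,
  //                    std::move(devices), r);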

  ~EagerContext();

  // Returns the function library runtime for the given device.
  FunctionLibraryRuntime* func_lib(Device* d) const {
    return pflr_->GetFLR(d->name());
  }

  ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }

  // True if running in asynchronous mode.
  bool Async() const;

  EagerExecutor* Executor() { return &executor_; }

  std::function<void(std::function<void()>)>* runner() { return &runner_; }

  // Sets whether this thread should run in synchronous or asynchronous mode.
  Status SetAsyncForThread(bool async);

  // TODO(apassos) make this return a constant reference
  gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() {
    return &devices_map_;
  }

  // TODO(apassos) make this return a constant reference
  std::vector<Device*>* devices() { return &devices_; }
  const std::vector<DeviceType>& prioritized_device_type_list() {
    return prioritized_device_type_list_;
  }

  // Clears the kernel caches.
  Status ClearCaches();

  // Sets the device placement policy for the current thread.
  void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy);

  // Returns the device placement policy for the current thread.
  ContextDevicePlacementPolicy GetDevicePlacementPolicy();

  Status AsyncWait() { return executor_.WaitForAllPendingNodes(); }

  Status GetStatus() { return executor_.status(); }

  void ClearAsyncError() { executor_.ClearError(); }
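
  // Async error-handling sketch (illustrative): in async mode a failed op
  // surfaces through the executor on a later call, so callers typically drain
  // pending nodes, inspect the status, and clear the error before continuing.
  //
  //   Status s = ctx->AsyncWait();   // Drain all pending async nodes.
  //   if (!s.ok()) {
  //     LOG(ERROR) << ctx->GetStatus();  // First error seen by the executor.
  //     ctx->ClearAsyncError();          // Allow new nodes to be enqueued.
  //   }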

  bool FindFunctionByName(const string& name);

  Status FindFunctionOpData(const string& name,
                            const tensorflow::OpRegistrationData** op_data);

  const FunctionDef* FindFunctionDef(const string& name);

  Status FindDeviceByName(const string& name, Device** result);

  Device* HostCPU() const { return devices_[0]; }

  GraphCollector* GetGraphCollector() { return &graph_collector_; }

  uint64 NextId() { return executor_.NextId(); }

  void ExecutorAdd(EagerNode* node) { executor_.Add(node); }

  Status AddFunctionDef(const FunctionDef& fdef);

  KernelAndDevice* GetCachedKernel(Fprint128 cache_key);

  void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
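
  // Kernel-cache sketch (illustrative): callers fingerprint the op, e.g. with
  // Fingerprint128() from platform/fingerprint.h, probe the cache, and build
  // a KernelAndDevice only on a miss. The context takes ownership of cached
  // kernels and deletes them when the caches are cleared.
  //
  //   Fprint128 key = Fingerprint128(op_cache_key_string);
  //   KernelAndDevice* kernel = ctx->GetCachedKernel(key);
  //   if (kernel == nullptr) {
  //     kernel = /* build and initialize a KernelAndDevice */;
  //     ctx->AddKernelToCache(key, kernel);
  //   }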

  bool LogDevicePlacement() const { return log_device_placement_; }
  bool LogMemory() const { return log_memory_; }

  Rendezvous* GetRendezvous() const { return rendezvous_; }
  CollectiveExecutorMgrInterface* collective_executor_mgr() {
    return (collective_executor_mgr_ != nullptr)
               ? collective_executor_mgr_.get()
               : unowned_collective_executor_mgr_;
  }
  std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
    return std::unique_ptr<CollectiveExecutor::Handle>(
        new CollectiveExecutor::Handle(
            collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
  }
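
  // Collective-executor sketch (illustrative): the returned Handle releases
  // its reference on the underlying CollectiveExecutor when it goes out of
  // scope, so scoped use is enough to manage the refcount.
  //
  //   auto handle = ctx->GetCollectiveExecutorHandle();
  //   CollectiveExecutor* ce = handle->get();
  //   // ... use ce; the ref is released when `handle` is destroyed ...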

  const tensorflow::DeviceMgr* local_device_mgr() const {
    return (local_device_manager_ != nullptr) ? local_device_manager_.get()
                                              : local_unowned_device_manager_;
  }
  const tensorflow::DeviceMgr* remote_device_mgr() const {
    return remote_device_manager_.get();
  }

  // TODO(apassos) remove the need for this
  void ReleaseDeviceMgr() { local_device_manager_.release(); }

  // TODO(apassos) clean up RunMetadata storage.
  mutex* MetadataMu() LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
  bool ShouldStoreStepStats() LOCKS_EXCLUDED(metadata_mu_);
  void SetShouldStoreStepStats(bool value);
  bool ShouldStoreGraphs() LOCKS_EXCLUDED(metadata_mu_);
  void SetShouldStoreGraphs(bool value);
  RunMetadata* RunMetadataProto() { return &run_metadata_; }
  void ClearRunMetadata() EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);

  Status RegisterRunMetadataListener(RunMetadataListener* listener)
      LOCKS_EXCLUDED(metadata_mu_);
  void ClearRunMetadataListener() LOCKS_EXCLUDED(metadata_mu_);
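
  // Locking sketch (illustrative): ClearRunMetadata() requires metadata_mu_,
  // so callers pair it with MetadataMu(); a registered listener is notified
  // via BeforeClearRunMetadata() before the proto is cleared.
  //
  //   {
  //     mutex_lock l(*ctx->MetadataMu());
  //     out->MergeFrom(*ctx->RunMetadataProto());  // Snapshot the stats.
  //     ctx->ClearRunMetadata();
  //   }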

  void StartStep();
  void EndStep();
  ScopedStepContainer* StepContainer();
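
  // Step-container sketch (illustrative): StartStep()/EndStep() bracket a
  // step so per-step resources land in the container returned by
  // StepContainer(), which is typically non-null only while a step is active.
  //
  //   ctx->StartStep();
  //   ScopedStepContainer* sc = ctx->StepContainer();
  //   // ... run ops that create per-step resources ...
  //   ctx->EndStep();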

  FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }

#ifndef __ANDROID__
  Status GetClientAndContextID(Device* device, eager::EagerClient** client,
                               uint64* context_id);

  // TODO(nareshmodi): Encapsulate remote state into a separate
  // class/struct.
  //
  // Enables the eager context to communicate with remote devices.
  //
  // - server: A ServerInterface that exports the tensorflow.WorkerService.
  // Note that this class expects the server to have already been started.
  // - remote_eager_workers: A cache from which we can get "EagerClient"s to
  // communicate with remote eager services.
  // - remote_device_mgr: A DeviceMgr* which contains all remote devices
  // (should contain no local devices).
  // - remote_contexts: A map from task name to remote context ID.
  Status InitializeRemote(
      std::unique_ptr<ServerInterface> server,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DeviceMgr> remote_device_manager,
      const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
      DeviceMgr* local_device_mgr, int keep_alive_secs);

  bool HasActiveRemoteContext(uint64 context_id) {
    return active_remote_contexts_.find(context_id) !=
           active_remote_contexts_.end();
  }

  Status StoreCollectiveOpsServer(
      std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr,
      CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
#endif

  // If true, then tensors should be shipped across processes via the
  // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
  // instead (which in turn use WorkerService.RecvTensor RPCs).
  bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
  bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }

  tensorflow::Env* TFEnv() const { return env_; }

  // All child threads are reset() when the EagerContext is destroyed.
  void AddChildThread(std::unique_ptr<Thread> thread);
 private:
  void InitDeviceMapAndAsync();
  Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);

  const ContextDevicePlacementPolicy policy_;

  // Note: we cannot use C++11 thread_local here, as C++11 has no notion of a
  // variable that is local to both a thread and an object.
  mutex policy_map_mu_;
  std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
      thread_local_policies_ GUARDED_BY(policy_map_mu_);

  // Only one of the below is set.
  std::unique_ptr<const DeviceMgr> local_device_manager_;
  const DeviceMgr* local_unowned_device_manager_;
  std::unique_ptr<DeviceMgr> remote_device_manager_;

  // Devices owned by the device manager above.
  std::vector<Device*> devices_;
  std::vector<DeviceType> prioritized_device_type_list_;
  // Indexes the devices above by name; the devices in this map are not owned.
  gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
  Rendezvous* rendezvous_;

  mutex functions_mu_;
  FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){
      OpRegistry::Global(), {}};

  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // One FunctionLibraryRuntime per device.
  // func_libs[i] is the FunctionLibraryRuntime corresponding to
  // session->devices[i].
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

  std::function<void(std::function<void()>)> runner_;

  mutex cache_mu_;
  std::unordered_map<Fprint128, KernelAndDevice*, Fprint128Hasher> kernel_cache_
      GUARDED_BY(cache_mu_);

  // Whether we should compute RunMetadata.
  std::atomic<bool> should_store_step_stats_{false};
  std::atomic<bool> should_store_graphs_{false};
  mutex metadata_mu_;
  RunMetadata run_metadata_ GUARDED_BY(metadata_mu_);
  RunMetadataListener* metadata_listener_ GUARDED_BY(metadata_mu_) = nullptr;
  GraphCollector graph_collector_;
  const bool log_device_placement_;
  // EagerExecutor for async execution.
  EagerExecutor executor_;

  // Information related to step containers.
  std::atomic<int> num_active_steps_;
  std::unique_ptr<ScopedStepContainer> step_container_ GUARDED_BY(metadata_mu_);

  // True if the default value for execution mode is async. Note that this
  // value can be overridden per thread via `thread_local_async_`.
  const bool async_default_;
  mutable mutex async_map_mu_;
  std::unordered_map<std::thread::id, bool> thread_local_async_
      GUARDED_BY(async_map_mu_);

  const bool log_memory_;

  Env* const env_;

  std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;
  CollectiveExecutorMgrInterface* unowned_collective_executor_mgr_ = nullptr;

#ifndef __ANDROID__
  void CloseRemoteContexts();

  // server_ is not marked const (even though, conceptually, it should be)
  // because we release it when the context is destroyed.
  std::unique_ptr<ServerInterface> server_;
  std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;

  mutex remote_state_mu_;

  gtl::FlatMap<string, uint64> remote_contexts_;
  gtl::FlatSet<uint64> active_remote_contexts_;
  gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
      device_to_client_cache_;

  int keep_alive_secs_ GUARDED_BY(remote_state_mu_);
  std::atomic<int> sleep_for_secs_;

  std::unique_ptr<Thread> keep_alive_thread_;
  mutex keep_alive_thread_shutdown_mu_;
  condition_variable keep_alive_thread_cv_;
  bool shutting_down_ GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;
#endif

  bool use_send_tensor_rpc_;
  const bool pin_small_ops_to_cpu_;
  std::vector<std::unique_ptr<tensorflow::Thread>> child_threads_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_