/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/context.h"

#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/lib/core/errors.h"
#ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#endif
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace {

// Returns the value of the boolean environment variable `env_var_name`, or
// `default_val` if the variable is unset or cannot be parsed.
bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
  bool val;
  if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
    return val;
  }
  return default_val;
}

}  // namespace

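// The two constructors below differ only in ownership of `device_mgr`: the
// first takes ownership of the DeviceMgr and delegates to the second, which
// can also borrow an unowned DeviceMgr (e.g. one owned by a server).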
EagerContext::EagerContext(const SessionOptions& opts,
                           ContextDevicePlacementPolicy default_policy,
                           bool async,
                           std::unique_ptr<const DeviceMgr> device_mgr,
                           Rendezvous* rendezvous)
    : EagerContext(opts, default_policy, async, device_mgr.release(),
                   /*device_mgr_owned=*/true, rendezvous) {}

EagerContext::EagerContext(const SessionOptions& opts,
                           ContextDevicePlacementPolicy default_policy,
                           bool async, const DeviceMgr* device_mgr,
                           bool device_mgr_owned, Rendezvous* rendezvous)
    : policy_(default_policy),
      devices_(device_mgr->ListDevices()),
      rendezvous_(rendezvous),
      thread_pool_(NewThreadPoolFromSessionOptions(opts)),
      pflr_(new ProcessFunctionLibraryRuntime(
          device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_,
          opts.config.graph_options().optimizer_options(), thread_pool_.get())),
      log_device_placement_(opts.config.log_device_placement()),
      num_active_steps_(0),
      async_default_(async),
      log_memory_(LogMemory::IsEnabled()),
      env_(opts.env),
      use_send_tensor_rpc_(false),
      pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
          "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
  if (device_mgr_owned) {
    local_device_manager_.reset(device_mgr);
    local_unowned_device_manager_ = nullptr;
  } else {
    local_unowned_device_manager_ = device_mgr;
  }
  InitDeviceMapAndAsync();
  runner_ = [this](std::function<void()> closure) {
    this->thread_pool_->Schedule(std::move(closure));
  };

  std::unique_ptr<DeviceResolverInterface> drl(
      new DeviceResolverLocal(local_device_mgr()));
  std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
      opts.config, local_device_mgr(), drl.get(),
      "/job:localhost/replica:0/task:0"));
  collective_executor_mgr_.reset(new CollectiveExecutorMgr(
      opts.config, local_device_mgr(), std::move(drl), std::move(cprl)));
}

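// Populates the name -> Device map from the local (and, if present, remote)
// device managers and enables async execution if it is the default. Remote
// devices never shadow a local device with the same name.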
void EagerContext::InitDeviceMapAndAsync() {
  if (async_default_) {
    executor_.EnableAsync();
  }

  for (auto* device : devices_) {
    devices_map_[device->name()] = device;
  }

  if (remote_device_manager_ != nullptr) {
    for (auto* device : remote_device_manager_->ListDevices()) {
      if (devices_map_.find(device->name()) == devices_map_.end()) {
        devices_map_[device->name()] = device;
        devices_.push_back(device);
      }
    }
  }

  DeviceSet ds;
  for (Device* d : devices_) {
    ds.AddDevice(d);
  }
  prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
}

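// Async mode is tracked per thread: Async() returns the calling thread's
// override if one was set via SetAsyncForThread(), falling back to the
// context-wide default otherwise.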
bool EagerContext::Async() const {
  mutex_lock l(async_map_mu_);
  return gtl::FindWithDefault(thread_local_async_, std::this_thread::get_id(),
                              async_default_);
}

Status EagerContext::SetAsyncForThread(bool async) {
  {
    tensorflow::mutex_lock l(async_map_mu_);
    thread_local_async_[std::this_thread::get_id()] = async;
  }
  if (async) {
    executor_.EnableAsync();
  } else {
    // TODO(agarwal): Currently we add a wait here to handle cases where a
    // sync op has a control dependency on an async op, and the latter has not
    // executed yet. This wait can be removed by storing all the control
    // inputs and waiting for them when executing ops.
    return executor_.WaitForAllPendingNodes();
  }
  return Status::OK();
}

Status EagerContext::ClearCaches() {
  // The executor stores pointers to kernels, so we need to make sure that no
  // async eager ops are still executing. We lock the cache during this time as
  // well.
  mutex_lock ml(cache_mu_);
  TF_RETURN_IF_ERROR(executor_.WaitForAllPendingNodes());
  gtl::STLDeleteValues(&kernel_cache_);

  return Status::OK();
}

void EagerContext::SetThreadLocalDevicePlacementPolicy(
    ContextDevicePlacementPolicy policy) {
  mutex_lock ml(policy_map_mu_);
  thread_local_policies_[std::this_thread::get_id()] = policy;
}

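// Like the async setting, the device placement policy can be overridden per
// thread; the context-wide default applies to threads without an override.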
ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() {
  mutex_lock ml(policy_map_mu_);
  auto policy_map_it = thread_local_policies_.find(std::this_thread::get_id());
  if (policy_map_it != thread_local_policies_.end()) {
    return policy_map_it->second;
  }
  return policy_;
}

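// Issues CloseContext RPCs to all remote workers in parallel and blocks until
// every callback has fired; failures are logged but otherwise ignored.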
#ifndef __ANDROID__
void EagerContext::CloseRemoteContexts() {
  // Close all remote contexts.
  std::vector<eager::CloseContextRequest> requests(remote_contexts_.size());
  std::vector<eager::CloseContextResponse> responses(remote_contexts_.size());
  BlockingCounter counter(static_cast<int>(remote_contexts_.size()));

  int i = 0;
  for (const auto& worker_and_context_id : remote_contexts_) {
    auto* client =
        remote_eager_workers_->GetClient(worker_and_context_id.first);

    requests[i].set_context_id(worker_and_context_id.second);
    client->CloseContextAsync(
        &requests[i], &responses[i],
        [&worker_and_context_id, &counter](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Unable to close remote context with ID "
                       << worker_and_context_id.second
                       << " for worker: " << worker_and_context_id.first
                       << " due to " << s.error_message();
          }
          counter.DecrementCount();
        });
    i++;
  }

  counter.Wait();
}
#endif

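// Tears the context down in dependency order: shut down the keep-alive thread
// and close remote contexts first, then drain the executor and caches before
// releasing the rendezvous and joining child threads.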
EagerContext::~EagerContext() {
#ifndef __ANDROID__
  if (server_) {
    // TODO(nareshmodi): Fix this.
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }

  {
    mutex_lock l(keep_alive_thread_shutdown_mu_);
    shutting_down_ = true;
    keep_alive_thread_cv_.notify_all();
  }
  keep_alive_thread_.reset();

  CloseRemoteContexts();
#endif

  executor_.WaitForAllPendingNodes().IgnoreError();
  ClearCaches().IgnoreError();
  rendezvous_->Unref();

  for (auto& thread : child_threads_) {
    thread.reset();
  }
}

void EagerContext::AddChildThread(std::unique_ptr<Thread> thread) {
  child_threads_.push_back(std::move(thread));
}

bool EagerContext::FindFunctionByName(const string& name) {
  mutex_lock l(functions_mu_);
  return func_lib_def_.Find(name) != nullptr;
}

Status EagerContext::FindFunctionOpData(
    const string& name, const tensorflow::OpRegistrationData** op_data) {
  mutex_lock l(functions_mu_);
  return func_lib_def_.LookUp(name, op_data);
}

const FunctionDef* EagerContext::FindFunctionDef(const string& name) {
  mutex_lock l(functions_mu_);
  return func_lib_def_.Find(name);
}

Status EagerContext::FindDeviceByName(const string& name, Device** result) {
  auto it = devices_map_.find(name);
  if (it == devices_map_.end()) {
    return errors::InvalidArgument("Unknown device: ", name);
  }
  *result = it->second;
  return Status::OK();
}

void EagerContext::ClearRunMetadata() {
  if (metadata_listener_ != nullptr) {
    metadata_listener_->BeforeClearRunMetadata();
  }
  run_metadata_.Clear();
}

Status EagerContext::RegisterRunMetadataListener(
    RunMetadataListener* listener) {
  mutex_lock l(metadata_mu_);
  if (metadata_listener_ != nullptr) {
    return Status(error::Code::INVALID_ARGUMENT,
                  "Cannot run two eager profilers at the same time");
  }
  metadata_listener_ = listener;
  return Status::OK();
}

void EagerContext::ClearRunMetadataListener() {
  mutex_lock l(metadata_mu_);
  metadata_listener_ = nullptr;
}

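// Steps are reference-counted: the first StartStep() creates a step container
// whose cleanup hook clears per-device resources, and the container is torn
// down once the matching number of EndStep() calls brings the count to zero.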
void EagerContext::StartStep() {
  mutex_lock ml(metadata_mu_);
  num_active_steps_++;
  if (step_container_ == nullptr) {
    step_container_.reset(
        new ScopedStepContainer(0, [this](const string& name) {
          for (Device* device : devices_) {
            device->resource_manager()->Cleanup(name).IgnoreError();
          }
        }));
  }
}

void EagerContext::EndStep() {
  mutex_lock ml(metadata_mu_);
  num_active_steps_--;
  if (num_active_steps_ == 0) {
    step_container_.reset();
  }
}

ScopedStepContainer* EagerContext::StepContainer() {
  if (num_active_steps_.load() == 0) {
    return nullptr;
  }
  mutex_lock ml(metadata_mu_);
  return step_container_.get();
}

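// Registers `fdef` with every remote worker in parallel, then surfaces the
// first error (if any). A no-op when there are no remote devices or on
// Android builds, which exclude the distributed runtime.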
Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
  if (remote_device_manager_ == nullptr) return Status::OK();
#ifndef __ANDROID__
  BlockingCounter blocking_counter(static_cast<int>(remote_contexts_.size()));

  std::vector<eager::RegisterFunctionRequest> requests(remote_contexts_.size());
  std::vector<eager::RegisterFunctionResponse> responses(
      remote_contexts_.size());
  std::vector<Status> statuses(remote_contexts_.size());

  int i = 0;
  for (const auto& target_and_context_id : remote_contexts_) {
    requests[i].set_context_id(target_and_context_id.second);
    *requests[i].mutable_function_def() = fdef;

    auto* eager_client =
        remote_eager_workers_->GetClient(target_and_context_id.first);

    eager_client->RegisterFunctionAsync(
        &requests[i], &responses[i],
        [i, &statuses, &blocking_counter](const Status& status) {
          statuses[i] = status;
          blocking_counter.DecrementCount();
        });

    i++;
  }
  blocking_counter.Wait();

  for (size_t i = 0; i < remote_contexts_.size(); i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }
#endif
  return Status::OK();
}

Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
  mutex_lock l(functions_mu_);
  TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));

  return MaybeRegisterFunctionRemotely(fdef);
}

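// Instantiated kernels are cached by an Fprint128 fingerprint of the op and
// its attributes; lookups take a shared lock, insertions an exclusive one.
// The cache owns the kernels and frees them in ClearCaches().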
KernelAndDevice* EagerContext::GetCachedKernel(Fprint128 cache_key) {
  tf_shared_lock l(cache_mu_);
  return gtl::FindPtrOrNull(kernel_cache_, cache_key);
}

void EagerContext::AddKernelToCache(Fprint128 cache_key,
                                    KernelAndDevice* kernel) {
  mutex_lock ml(cache_mu_);
  gtl::InsertOrUpdate(&kernel_cache_, cache_key, kernel);
}

bool EagerContext::ShouldStoreGraphs() {
  mutex_lock ml(metadata_mu_);
  return should_store_graphs_.load() || metadata_listener_ != nullptr;
}

bool EagerContext::ShouldStoreStepStats() {
  mutex_lock ml(metadata_mu_);
  return should_store_step_stats_.load() || metadata_listener_ != nullptr;
}

void EagerContext::SetShouldStoreGraphs(bool value) {
  mutex_lock ml(metadata_mu_);
  should_store_graphs_.store(value);
  if (!value || metadata_listener_ != nullptr) {
    run_metadata_.Clear();
  }
}

void EagerContext::SetShouldStoreStepStats(bool value) {
  mutex_lock ml(metadata_mu_);
  should_store_step_stats_.store(value);
  if (!value || metadata_listener_ != nullptr) {
    run_metadata_.Clear();
  }
}

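// Extracts the "/job:x/replica:y/task:z" prefix from a fully qualified device
// name, e.g. "/job:worker/replica:0/task:1/device:GPU:0".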
namespace {
Status GetTaskName(Device* d, string* task_name) {
  string ignored;
  if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
    return errors::InvalidArgument("Unable to parse device name: ", d->name());
  }

  return Status::OK();
}
}  // namespace

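// Resolves the eager client and remote context ID for the task that owns
// `device`, memoizing the result per device.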
#ifndef __ANDROID__
Status EagerContext::GetClientAndContextID(Device* device,
                                           eager::EagerClient** client,
                                           uint64* context_id) {
  auto it = device_to_client_cache_.find(device);
  if (it != device_to_client_cache_.end()) {
    *client = it->second.first;
    *context_id = it->second.second;
    // Cache hit: skip the task-name lookup below.
    return Status::OK();
  }
  string device_task_name;
  TF_RETURN_IF_ERROR(GetTaskName(device, &device_task_name));

  *client = remote_eager_workers_->GetClient(device_task_name);

  if (*client == nullptr) {
    return errors::InvalidArgument(
        "Unable to find eager client corresponding to device ", device->name());
  }

  auto context_iterator = remote_contexts_.find(device_task_name);
  if (context_iterator == remote_contexts_.end()) {
    return errors::Internal("Unable to find a context for handle on task: ",
                            device_task_name, ". This should not be possible");
  }
  *context_id = context_iterator->second;

  device_to_client_cache_.insert({device, {*client, *context_id}});

  return Status::OK();
}

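// Replaces the context's device manager and collective executor manager with
// ones owned by the given collective-ops server, rebuilding the device map,
// kernel cache, and function library runtime around the new devices.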
Status EagerContext::StoreCollectiveOpsServer(
    std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr,
    CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
  collective_executor_mgr_.reset(nullptr);
  unowned_collective_executor_mgr_ = rpc_collective_executor_mgr;

  local_device_manager_.reset(nullptr);
  local_unowned_device_manager_ = device_mgr;

  devices_ = local_unowned_device_manager_->ListDevices();
  devices_map_.clear();

  InitDeviceMapAndAsync();
  TF_RETURN_IF_ERROR(ClearCaches());

  pflr_.reset(new ProcessFunctionLibraryRuntime(
      local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
      {}, thread_pool_.get()));

  // Memory leak!
  if (server_ != nullptr) {
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }
  server_ = std::move(server);

  return Status::OK();
}

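// Rewires the context for a remote cluster: swaps in the new server, eager
// worker clients, device managers, and rendezvous, then starts a single
// background thread that pings each remote context roughly every
// keep_alive_secs / 2 seconds so the remote contexts stay alive on idle
// workers.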
Status EagerContext::InitializeRemote(
    std::unique_ptr<ServerInterface> server,
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    std::unique_ptr<DeviceMgr> remote_device_manager,
    const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
    DeviceMgr* local_device_mgr, int keep_alive_secs) {
  mutex_lock l(remote_state_mu_);

  if (!remote_contexts_.empty()) {
    CloseRemoteContexts();
  }
  remote_contexts_ = remote_contexts;

  use_send_tensor_rpc_ =
      ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false);

  local_unowned_device_manager_ = local_device_mgr;
  local_device_manager_ = nullptr;
  pflr_.reset(new ProcessFunctionLibraryRuntime(
      local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
      {}, thread_pool_.get()));

  devices_ = local_unowned_device_manager_->ListDevices();
  devices_map_.clear();

  if (rendezvous_ != nullptr) rendezvous_->Unref();
  rendezvous_ = r;

  // Memory leak!
  if (server_ != nullptr) {
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }

  server_ = std::move(server);
  remote_eager_workers_ = std::move(remote_eager_workers);

  active_remote_contexts_.clear();
  for (const auto& remote_context : remote_contexts_) {
    active_remote_contexts_.insert(remote_context.second);
  }

  device_to_client_cache_.clear();
  remote_device_manager_ = std::move(remote_device_manager);

  InitDeviceMapAndAsync();

  TF_RETURN_IF_ERROR(ClearCaches());

  keep_alive_secs_ = keep_alive_secs;

  sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);

  // Only schedule a single closure.
  if (keep_alive_thread_ == nullptr) {
    keep_alive_thread_.reset(
        env_->StartThread({}, "EagerKeepAliveThread", [this]() {
          while (true) {
            {
              mutex_lock l(keep_alive_thread_shutdown_mu_);
              keep_alive_thread_cv_.wait_for(
                  l, std::chrono::seconds(sleep_for_secs_));

              if (shutting_down_) {
                return;
              }
            }
            {
              mutex_lock l(remote_state_mu_);
              if (keep_alive_secs_ > 0) {
                for (const auto& worker_and_context_id : remote_contexts_) {
                  auto* client = remote_eager_workers_->GetClient(
                      worker_and_context_id.first);

                  // The request and response are owned by the callback and
                  // deleted once the RPC completes.
                  eager::KeepAliveRequest* request =
                      new eager::KeepAliveRequest;
                  eager::KeepAliveResponse* response =
                      new eager::KeepAliveResponse;

                  request->set_context_id(worker_and_context_id.second);
                  client->KeepAliveAsync(
                      request, response,
                      [request, response](const Status& s) {
                        delete request;
                        delete response;
                      });
                }
              }
            }
          }
        }));
  }
  return Status::OK();
}
#endif

}  // namespace tensorflow