1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/common_runtime/eager/context.h" 17 18 #include "tensorflow/core/common_runtime/collective_executor_mgr.h" 19 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" 20 #include "tensorflow/core/common_runtime/device_resolver_local.h" 21 #include "tensorflow/core/common_runtime/device_set.h" 22 #include "tensorflow/core/common_runtime/process_util.h" 23 #include "tensorflow/core/lib/core/errors.h" 24 #ifndef __ANDROID__ 25 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" 26 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" 27 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" 28 #endif 29 #include "tensorflow/core/framework/resource_mgr.h" 30 #include "tensorflow/core/lib/core/blocking_counter.h" 31 #include "tensorflow/core/util/env_var.h" 32 33 namespace tensorflow { 34 namespace { 35 36 bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) { 37 bool val; 38 if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) { 39 return val; 40 } 41 return default_val; 42 } 43 44 } // namespace 45 46 EagerContext::EagerContext(const SessionOptions& opts, 47 ContextDevicePlacementPolicy default_policy, 48 bool async, 49 std::unique_ptr<const DeviceMgr> device_mgr, 50 Rendezvous* rendezvous) 51 : EagerContext(opts, default_policy, async, device_mgr.release(), 52 /*device_mgr_owned*/ true, rendezvous) {} 53 54 EagerContext::EagerContext(const SessionOptions& opts, 55 ContextDevicePlacementPolicy default_policy, 56 bool async, const DeviceMgr* device_mgr, 57 bool device_mgr_owned, Rendezvous* rendezvous) 58 : policy_(default_policy), 59 devices_(device_mgr->ListDevices()), 60 rendezvous_(rendezvous), 61 thread_pool_(NewThreadPoolFromSessionOptions(opts)), 62 pflr_(new ProcessFunctionLibraryRuntime( 63 device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, 64 opts.config.graph_options().optimizer_options(), thread_pool_.get())), 65 log_device_placement_(opts.config.log_device_placement()), 66 num_active_steps_(0), 67 async_default_(async), 68 log_memory_(LogMemory::IsEnabled()), 69 env_(opts.env), 70 use_send_tensor_rpc_(false), 71 pin_small_ops_to_cpu_(ReadBoolFromEnvVar( 72 "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) { 73 if (device_mgr_owned) { 74 local_device_manager_.reset(device_mgr); 75 local_unowned_device_manager_ = nullptr; 76 } else { 77 local_unowned_device_manager_ = device_mgr; 78 } 79 InitDeviceMapAndAsync(); 80 runner_ = [this](std::function<void()> closure) { 81 this->thread_pool_->Schedule(std::move(closure)); 82 }; 83 84 std::unique_ptr<DeviceResolverInterface> drl( 85 new DeviceResolverLocal(local_device_mgr())); 86 std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal( 87 opts.config, local_device_mgr(), drl.get(), 88 "/job:localhost/replica:0/task:0")); 89 collective_executor_mgr_.reset(new CollectiveExecutorMgr( 90 opts.config, local_device_mgr(), std::move(drl), std::move(cprl))); 91 } 92 93 void EagerContext::InitDeviceMapAndAsync() { 94 if (async_default_) { 95 executor_.EnableAsync(); 96 } 97 98 for (auto* device : devices_) { 99 devices_map_[device->name()] = device; 100 } 101 102 if (remote_device_manager_ != nullptr) { 103 for (auto* device : remote_device_manager_->ListDevices()) { 104 if (devices_map_.find(device->name()) == devices_map_.end()) { 105 devices_map_[device->name()] = device; 106 devices_.push_back(device); 107 } 108 } 109 } 110 111 DeviceSet ds; 112 for (Device* d : devices_) { 113 ds.AddDevice(d); 114 } 115 prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList(); 116 } 117 118 bool EagerContext::Async() const { 119 mutex_lock l(async_map_mu_); 120 return gtl::FindWithDefault(thread_local_async_, std::this_thread::get_id(), 121 async_default_); 122 } 123 124 Status EagerContext::SetAsyncForThread(bool async) { 125 { 126 tensorflow::mutex_lock l(async_map_mu_); 127 thread_local_async_[std::this_thread::get_id()] = async; 128 } 129 if (async) { 130 executor_.EnableAsync(); 131 } else { 132 // TODO(agarwal): Currently we add a wait here to handle cases where a 133 // sync op has a control dependency on an async op, and the latter has not 134 // executed yet. This wait can be removed by storing all the control 135 // inputs and waiting for them when executing ops. 136 return executor_.WaitForAllPendingNodes(); 137 } 138 return Status::OK(); 139 } 140 141 Status EagerContext::ClearCaches() { 142 // The executor stores pointers to kernels, so we need to make sure that no 143 // async eager ops are still executing. We lock the cache during this time as 144 // well. 145 mutex_lock ml(cache_mu_); 146 TF_RETURN_IF_ERROR(executor_.WaitForAllPendingNodes()); 147 gtl::STLDeleteValues(&kernel_cache_); 148 149 return Status::OK(); 150 } 151 152 void EagerContext::SetThreadLocalDevicePlacementPolicy( 153 ContextDevicePlacementPolicy policy) { 154 mutex_lock ml(policy_map_mu_); 155 thread_local_policies_[std::this_thread::get_id()] = policy; 156 } 157 158 ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() { 159 mutex_lock ml(policy_map_mu_); 160 auto policy_map_it = thread_local_policies_.find(std::this_thread::get_id()); 161 if (policy_map_it != thread_local_policies_.end()) { 162 return policy_map_it->second; 163 } 164 return policy_; 165 } 166 167 #ifndef __ANDROID__ 168 void EagerContext::CloseRemoteContexts() { 169 // Close all remote contexts. 170 std::vector<eager::CloseContextRequest> requests(remote_contexts_.size()); 171 std::vector<eager::CloseContextResponse> responses(remote_contexts_.size()); 172 BlockingCounter counter(static_cast<int>(remote_contexts_.size())); 173 174 int i = 0; 175 for (const auto& worker_and_context_id : remote_contexts_) { 176 auto* client = 177 remote_eager_workers_->GetClient(worker_and_context_id.first); 178 179 requests[i].set_context_id(worker_and_context_id.second); 180 client->CloseContextAsync( 181 &requests[i], &responses[i], 182 [&worker_and_context_id, &counter](const Status& s) { 183 if (!s.ok()) { 184 LOG(ERROR) << "Unable to close remote context with ID " 185 << worker_and_context_id.second 186 << " for worker: " << worker_and_context_id.first 187 << " due to " << s.error_message(); 188 } 189 counter.DecrementCount(); 190 }); 191 i++; 192 } 193 194 counter.Wait(); 195 } 196 #endif 197 198 EagerContext::~EagerContext() { 199 #ifndef __ANDROID__ 200 if (server_) { 201 // TODO(nareshmodi): Fix this. 202 LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. " 203 "Servers don't support clean shutdown."; 204 server_.release(); 205 } 206 207 { 208 mutex_lock l(keep_alive_thread_shutdown_mu_); 209 shutting_down_ = true; 210 keep_alive_thread_cv_.notify_all(); 211 } 212 keep_alive_thread_.reset(); 213 214 CloseRemoteContexts(); 215 #endif 216 217 executor_.WaitForAllPendingNodes().IgnoreError(); 218 ClearCaches().IgnoreError(); 219 rendezvous_->Unref(); 220 221 for (auto& thread : child_threads_) { 222 thread.reset(); 223 } 224 } 225 226 void EagerContext::AddChildThread(std::unique_ptr<Thread> thread) { 227 child_threads_.push_back(std::move(thread)); 228 } 229 230 bool EagerContext::FindFunctionByName(const string& name) { 231 mutex_lock l(functions_mu_); 232 return func_lib_def_.Find(name) != nullptr; 233 } 234 235 Status EagerContext::FindFunctionOpData( 236 const string& name, const tensorflow::OpRegistrationData** op_data) { 237 mutex_lock l(functions_mu_); 238 return func_lib_def_.LookUp(name, op_data); 239 } 240 241 const FunctionDef* EagerContext::FindFunctionDef(const string& name) { 242 mutex_lock l(functions_mu_); 243 return func_lib_def_.Find(name); 244 } 245 246 Status EagerContext::FindDeviceByName(const string& name, Device** result) { 247 auto it = devices_map_.find(name); 248 if (it == devices_map_.end()) { 249 return errors::InvalidArgument(name, " unknown device."); 250 } 251 *result = it->second; 252 return Status::OK(); 253 } 254 255 void EagerContext::ClearRunMetadata() { 256 if (metadata_listener_ != nullptr) { 257 metadata_listener_->BeforeClearRunMetadata(); 258 } 259 run_metadata_.Clear(); 260 } 261 262 Status EagerContext::RegisterRunMetadataListener( 263 RunMetadataListener* listener) { 264 mutex_lock l(metadata_mu_); 265 if (metadata_listener_ != nullptr) { 266 return Status(error::Code::INVALID_ARGUMENT, 267 "Cannot run two eager profiler at the same time"); 268 } 269 metadata_listener_ = listener; 270 return Status::OK(); 271 } 272 273 void EagerContext::ClearRunMetadataListener() { 274 mutex_lock l(metadata_mu_); 275 metadata_listener_ = nullptr; 276 } 277 278 void EagerContext::StartStep() { 279 mutex_lock ml(metadata_mu_); 280 num_active_steps_++; 281 if (step_container_ == nullptr) { 282 step_container_.reset( 283 new ScopedStepContainer(0, [this](const string& name) { 284 for (Device* device : devices_) { 285 device->resource_manager()->Cleanup(name).IgnoreError(); 286 } 287 })); 288 } 289 } 290 291 void EagerContext::EndStep() { 292 mutex_lock ml(metadata_mu_); 293 num_active_steps_--; 294 if (num_active_steps_ == 0) { 295 step_container_.reset(); 296 } 297 } 298 299 ScopedStepContainer* EagerContext::StepContainer() { 300 if (num_active_steps_.load() == 0) { 301 return nullptr; 302 } 303 mutex_lock ml(metadata_mu_); 304 return step_container_.get(); 305 } 306 307 Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { 308 if (remote_device_manager_ == nullptr) return Status::OK(); 309 #ifndef __ANDROID__ 310 BlockingCounter blocking_counter(static_cast<int>(remote_contexts_.size())); 311 312 std::vector<eager::RegisterFunctionRequest> requests(remote_contexts_.size()); 313 std::vector<eager::RegisterFunctionResponse> responses( 314 remote_contexts_.size()); 315 std::vector<Status> statuses(remote_contexts_.size()); 316 317 int i = 0; 318 for (const auto& target_and_context_id : remote_contexts_) { 319 requests[i].set_context_id(target_and_context_id.second); 320 *requests[i].mutable_function_def() = fdef; 321 322 auto* eager_client = 323 remote_eager_workers_->GetClient(target_and_context_id.first); 324 325 eager_client->RegisterFunctionAsync( 326 &requests[i], &responses[i], 327 [i, &statuses, &blocking_counter](const Status& status) { 328 statuses[i] = status; 329 blocking_counter.DecrementCount(); 330 }); 331 332 i++; 333 } 334 blocking_counter.Wait(); 335 336 for (int i = 0; i < remote_contexts_.size(); i++) { 337 TF_RETURN_IF_ERROR(statuses[i]); 338 } 339 #endif 340 return Status::OK(); 341 } 342 343 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) { 344 mutex_lock l(functions_mu_); 345 TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef)); 346 347 return MaybeRegisterFunctionRemotely(fdef); 348 } 349 350 KernelAndDevice* EagerContext::GetCachedKernel(Fprint128 cache_key) { 351 tf_shared_lock l(cache_mu_); 352 return gtl::FindPtrOrNull(kernel_cache_, cache_key); 353 } 354 355 void EagerContext::AddKernelToCache(Fprint128 cache_key, 356 KernelAndDevice* kernel) { 357 mutex_lock ml(cache_mu_); 358 gtl::InsertOrUpdate(&kernel_cache_, cache_key, kernel); 359 } 360 361 bool EagerContext::ShouldStoreGraphs() { 362 mutex_lock ml(metadata_mu_); 363 return should_store_graphs_.load() || metadata_listener_ != nullptr; 364 } 365 366 bool EagerContext::ShouldStoreStepStats() { 367 mutex_lock ml(metadata_mu_); 368 return should_store_step_stats_.load() || metadata_listener_ != nullptr; 369 } 370 371 void EagerContext::SetShouldStoreGraphs(bool value) { 372 mutex_lock ml(metadata_mu_); 373 should_store_graphs_.store(value); 374 if (!value || metadata_listener_ != nullptr) { 375 run_metadata_.Clear(); 376 } 377 } 378 379 void EagerContext::SetShouldStoreStepStats(bool value) { 380 mutex_lock ml(metadata_mu_); 381 should_store_step_stats_.store(value); 382 if (!value || metadata_listener_ != nullptr) { 383 run_metadata_.Clear(); 384 } 385 } 386 387 namespace { 388 Status GetTaskName(Device* d, string* task_name) { 389 string ignored; 390 if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) { 391 return errors::InvalidArgument("Unable to parse device name: ", d->name()); 392 } 393 394 return Status::OK(); 395 } 396 } // namespace 397 398 #ifndef __ANDROID__ 399 Status EagerContext::GetClientAndContextID(Device* device, 400 eager::EagerClient** client, 401 uint64* context_id) { 402 auto it = device_to_client_cache_.find(device); 403 if (it != device_to_client_cache_.end()) { 404 *client = it->second.first; 405 *context_id = it->second.second; 406 } 407 string device_task_name; 408 TF_RETURN_IF_ERROR(GetTaskName(device, &device_task_name)); 409 410 *client = remote_eager_workers_->GetClient(device_task_name); 411 412 if (*client == nullptr) { 413 return errors::InvalidArgument( 414 "Unable to find eager client corresponding to device ", device->name()); 415 } 416 417 auto context_iterator = remote_contexts_.find(device_task_name); 418 if (context_iterator == remote_contexts_.end()) { 419 return errors::Internal("Unable to find a context for handle on task: ", 420 device_task_name, ". This should not be possible"); 421 } 422 *context_id = context_iterator->second; 423 424 device_to_client_cache_.insert({device, {*client, *context_id}}); 425 426 return Status::OK(); 427 } 428 429 Status EagerContext::StoreCollectiveOpsServer( 430 std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr, 431 CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) { 432 collective_executor_mgr_.reset(nullptr); 433 unowned_collective_executor_mgr_ = rpc_collective_executor_mgr; 434 435 local_device_manager_.reset(nullptr); 436 local_unowned_device_manager_ = device_mgr; 437 438 devices_ = local_unowned_device_manager_->ListDevices(); 439 devices_map_.clear(); 440 441 InitDeviceMapAndAsync(); 442 TF_RETURN_IF_ERROR(ClearCaches()); 443 444 pflr_.reset(new ProcessFunctionLibraryRuntime( 445 local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_, 446 {}, thread_pool_.get())); 447 448 // Memory leak! 449 if (server_ != nullptr) { 450 LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. " 451 "Servers don't support clean shutdown."; 452 server_.release(); 453 } 454 server_ = std::move(server); 455 456 return Status::OK(); 457 } 458 459 Status EagerContext::InitializeRemote( 460 std::unique_ptr<ServerInterface> server, 461 std::unique_ptr<eager::EagerClientCache> remote_eager_workers, 462 std::unique_ptr<DeviceMgr> remote_device_manager, 463 const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r, 464 DeviceMgr* local_device_mgr, int keep_alive_secs) { 465 mutex_lock l(remote_state_mu_); 466 467 if (!remote_contexts_.empty()) { 468 CloseRemoteContexts(); 469 } 470 remote_contexts_ = remote_contexts; 471 472 use_send_tensor_rpc_ = 473 ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false); 474 475 local_unowned_device_manager_ = local_device_mgr; 476 local_device_manager_ = nullptr; 477 pflr_.reset(new ProcessFunctionLibraryRuntime( 478 local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_, 479 {}, thread_pool_.get())); 480 481 devices_ = local_unowned_device_manager_->ListDevices(); 482 devices_map_.clear(); 483 484 if (rendezvous_ != nullptr) rendezvous_->Unref(); 485 rendezvous_ = r; 486 487 // Memory leak! 488 if (server_ != nullptr) { 489 LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. " 490 "Servers don't support clean shutdown."; 491 server_.release(); 492 } 493 494 server_ = std::move(server); 495 remote_eager_workers_ = std::move(remote_eager_workers); 496 497 active_remote_contexts_.clear(); 498 for (const auto& remote_context : remote_contexts_) { 499 active_remote_contexts_.insert(remote_context.second); 500 } 501 502 device_to_client_cache_.clear(); 503 remote_device_manager_ = std::move(remote_device_manager); 504 505 InitDeviceMapAndAsync(); 506 507 TF_RETURN_IF_ERROR(ClearCaches()); 508 509 keep_alive_secs_ = keep_alive_secs; 510 511 sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2); 512 513 // Only schedule a single closure. 514 if (keep_alive_thread_ == nullptr) { 515 keep_alive_thread_.reset( 516 env_->StartThread({}, "EagerKeepAliveThread", [this]() { 517 while (true) { 518 { 519 { 520 mutex_lock l(keep_alive_thread_shutdown_mu_); 521 keep_alive_thread_cv_.wait_for( 522 l, std::chrono::seconds(sleep_for_secs_)); 523 524 if (shutting_down_) { 525 return; 526 } 527 } 528 { 529 mutex_lock l(remote_state_mu_); 530 if (keep_alive_secs_ > 0) { 531 { 532 for (const auto& worker_and_context_id : remote_contexts_) { 533 auto* client = remote_eager_workers_->GetClient( 534 worker_and_context_id.first); 535 536 eager::KeepAliveRequest* request = 537 new eager::KeepAliveRequest; 538 eager::KeepAliveResponse* response = 539 new eager::KeepAliveResponse; 540 541 request->set_context_id(worker_and_context_id.second); 542 client->KeepAliveAsync( 543 request, response, 544 [request, response](const Status& s) { 545 delete request; 546 delete response; 547 }); 548 } 549 } 550 } 551 } 552 } 553 } 554 })); 555 } 556 return Status::OK(); 557 } 558 #endif 559 560 } // namespace tensorflow 561