1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/common_runtime/direct_session.h" 17 18 #include <atomic> 19 #include <string> 20 #include <vector> 21 22 #include "tensorflow/core/common_runtime/constant_folding.h" 23 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 24 #include "tensorflow/core/common_runtime/device_factory.h" 25 #include "tensorflow/core/common_runtime/executor.h" 26 #include "tensorflow/core/common_runtime/function.h" 27 #include "tensorflow/core/common_runtime/graph_optimizer.h" 28 #include "tensorflow/core/common_runtime/memory_types.h" 29 #include "tensorflow/core/common_runtime/optimization_registry.h" 30 #include "tensorflow/core/common_runtime/step_stats_collector.h" 31 #include "tensorflow/core/framework/function.h" 32 #include "tensorflow/core/framework/graph.pb_text.h" 33 #include "tensorflow/core/framework/graph.pb.h" 34 #include "tensorflow/core/framework/graph_def_util.h" 35 #include "tensorflow/core/framework/log_memory.h" 36 #include "tensorflow/core/framework/node_def.pb.h" 37 #include "tensorflow/core/framework/tensor.h" 38 #include "tensorflow/core/framework/versions.pb.h" 39 #include "tensorflow/core/graph/algorithm.h" 40 #include "tensorflow/core/graph/graph.h" 41 #include "tensorflow/core/graph/graph_constructor.h" 42 #include "tensorflow/core/graph/graph_partition.h" 43 #include "tensorflow/core/graph/subgraph.h" 44 #include "tensorflow/core/graph/tensor_id.h" 45 #include "tensorflow/core/lib/core/errors.h" 46 #include "tensorflow/core/lib/core/notification.h" 47 #include "tensorflow/core/lib/core/refcount.h" 48 #include "tensorflow/core/lib/core/status.h" 49 #include "tensorflow/core/lib/core/threadpool.h" 50 #include "tensorflow/core/lib/gtl/array_slice.h" 51 #include "tensorflow/core/lib/gtl/stl_util.h" 52 #include "tensorflow/core/lib/monitoring/counter.h" 53 #include "tensorflow/core/lib/strings/numbers.h" 54 #include "tensorflow/core/lib/strings/str_util.h" 55 #include "tensorflow/core/lib/strings/strcat.h" 56 #include "tensorflow/core/platform/cpu_info.h" 57 #include "tensorflow/core/platform/device_tracer.h" 58 #include "tensorflow/core/platform/logging.h" 59 #include "tensorflow/core/platform/mutex.h" 60 #include "tensorflow/core/platform/types.h" 61 #include "tensorflow/core/util/device_name_utils.h" 62 #include "tensorflow/core/util/env_var.h" 63 64 namespace tensorflow { 65 66 namespace { 67 68 auto* direct_session_runs = monitoring::Counter<0>::New( 69 "/tensorflow/core/direct_session_runs", 70 "The number of times DirectSession::Run() has been called."); 71 72 int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) { 73 const int32 t = options.config.inter_op_parallelism_threads(); 74 if (t != 0) return t; 75 // Default to using the number of cores available in the process. 76 return port::NumSchedulableCPUs(); 77 } 78 79 thread::ThreadPool* NewThreadPoolFromSessionOptions( 80 const SessionOptions& options) { 81 const int32 num_threads = NumInterOpThreadsFromSessionOptions(options); 82 VLOG(1) << "Direct session inter op parallelism threads: " << num_threads; 83 return new thread::ThreadPool(options.env, "Compute", num_threads); 84 } 85 86 Status NewThreadPoolFromThreadPoolOptions( 87 const SessionOptions& options, 88 const ThreadPoolOptionProto& thread_pool_options, int pool_number, 89 thread::ThreadPool** pool, bool* owned) { 90 int32 num_threads = thread_pool_options.num_threads(); 91 if (num_threads == 0) { 92 num_threads = NumInterOpThreadsFromSessionOptions(options); 93 } 94 const string& name = thread_pool_options.global_name(); 95 if (name.empty()) { 96 // Session-local threadpool. 97 VLOG(1) << "Direct session inter op parallelism threads for pool " 98 << pool_number << ": " << num_threads; 99 *pool = new thread::ThreadPool( 100 options.env, strings::StrCat("Compute", pool_number), num_threads); 101 *owned = true; 102 return Status::OK(); 103 } 104 105 // Global, named threadpool. 106 typedef std::pair<int32, thread::ThreadPool*> MapValue; 107 static std::map<string, MapValue>* global_pool_map = 108 new std::map<string, MapValue>; 109 static mutex* mu = new mutex(); 110 mutex_lock l(*mu); 111 MapValue* mvalue = &(*global_pool_map)[name]; 112 if (mvalue->second == nullptr) { 113 mvalue->first = thread_pool_options.num_threads(); 114 mvalue->second = new thread::ThreadPool( 115 options.env, strings::StrCat("Compute", pool_number), num_threads); 116 } else { 117 if (mvalue->first != thread_pool_options.num_threads()) { 118 return errors::InvalidArgument( 119 "Pool ", name, 120 " configured previously with num_threads=", mvalue->first, 121 "; cannot re-configure with num_threads=", 122 thread_pool_options.num_threads()); 123 } 124 } 125 *owned = false; 126 *pool = mvalue->second; 127 return Status::OK(); 128 } 129 130 thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) { 131 static thread::ThreadPool* const thread_pool = 132 NewThreadPoolFromSessionOptions(options); 133 return thread_pool; 134 } 135 136 // TODO(vrv): Figure out how to unify the many different functions 137 // that generate RendezvousKey, since many of them have to be 138 // consistent with each other. 139 string GetRendezvousKey(const string& tensor_name, 140 const DeviceAttributes& device_info, 141 const FrameAndIter& frame_iter) { 142 return strings::StrCat(device_info.name(), ";", 143 strings::FpToString(device_info.incarnation()), ";", 144 device_info.name(), ";", tensor_name, ";", 145 frame_iter.frame_id, ":", frame_iter.iter_id); 146 } 147 148 } // namespace 149 150 class DirectSessionFactory : public SessionFactory { 151 public: 152 DirectSessionFactory() {} 153 154 bool AcceptsOptions(const SessionOptions& options) override { 155 return options.target.empty(); 156 } 157 158 Session* NewSession(const SessionOptions& options) override { 159 // Must do this before the CPU allocator is created. 160 if (options.config.graph_options().build_cost_model() > 0) { 161 EnableCPUAllocatorFullStats(true); 162 } 163 std::vector<Device*> devices; 164 const Status s = DeviceFactory::AddDevices( 165 options, "/job:localhost/replica:0/task:0", &devices); 166 if (!s.ok()) { 167 LOG(ERROR) << s; 168 return nullptr; 169 } 170 171 DirectSession* session = 172 new DirectSession(options, new DeviceMgr(devices), this); 173 { 174 mutex_lock l(sessions_lock_); 175 sessions_.push_back(session); 176 } 177 return session; 178 } 179 180 Status Reset(const SessionOptions& options, 181 const std::vector<string>& containers) override { 182 std::vector<DirectSession*> sessions_to_reset; 183 { 184 mutex_lock l(sessions_lock_); 185 // We create a copy to ensure that we don't have a deadlock when 186 // session->Close calls the DirectSessionFactory.Deregister, which 187 // acquires sessions_lock_. 188 std::swap(sessions_to_reset, sessions_); 189 } 190 Status s; 191 for (auto session : sessions_to_reset) { 192 s.Update(session->Reset(containers)); 193 } 194 // TODO(suharshs): Change the Reset behavior of all SessionFactories so that 195 // it doesn't close the sessions? 196 for (auto session : sessions_to_reset) { 197 s.Update(session->Close()); 198 } 199 return s; 200 } 201 202 void Deregister(const DirectSession* session) { 203 mutex_lock l(sessions_lock_); 204 sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session), 205 sessions_.end()); 206 } 207 208 private: 209 mutex sessions_lock_; 210 std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_); 211 }; 212 213 class DirectSessionRegistrar { 214 public: 215 DirectSessionRegistrar() { 216 SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory()); 217 } 218 }; 219 static DirectSessionRegistrar registrar; 220 221 std::atomic_int_fast64_t DirectSession::step_id_counter_(1); 222 223 // NOTE: On Android with a single device, there is never 224 // a risk of an OpKernel blocking indefinitely: 225 // 226 // 1) No operations do I/O that depends on other simultaneous kernels, 227 // 228 // 2) Recv nodes always complete immediately: The inputs are sent into 229 // the local rendezvous before we start the executor, so the 230 // corresponding recvs will not block. 231 // 232 // Based on these assumptions, we can use the same thread pool for 233 // both "non-blocking" and "blocking" OpKernels on Android. 234 // 235 // This may change down the road when we add support for multiple 236 // devices that run concurrently, in which case we will need to 237 // revisit this decision. 238 void DirectSession::SchedClosure(thread::ThreadPool* pool, 239 std::function<void()> c) { 240 // TODO(sanjay): Get rid of __ANDROID__ path 241 #ifdef __ANDROID__ 242 // On Android, there is no implementation of ThreadPool that takes 243 // std::function, only Closure, which we cannot easily convert. 244 // 245 // Instead, we just run the function in-line, which is currently 246 // safe given the reasoning above. 247 c(); 248 #else 249 pool->Schedule(std::move(c)); 250 #endif // __ANDROID__ 251 } 252 253 DirectSession::DirectSession(const SessionOptions& options, 254 const DeviceMgr* device_mgr, 255 DirectSessionFactory* const factory) 256 : options_(options), 257 device_mgr_(device_mgr), 258 factory_(factory), 259 cancellation_manager_(new CancellationManager()), 260 operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) { 261 const int thread_pool_size = 262 options_.config.session_inter_op_thread_pool_size(); 263 if (thread_pool_size > 0) { 264 for (int i = 0; i < thread_pool_size; ++i) { 265 thread::ThreadPool* pool = nullptr; 266 bool owned = false; 267 init_error_.Update(NewThreadPoolFromThreadPoolOptions( 268 options_, options_.config.session_inter_op_thread_pool(i), i, &pool, 269 &owned)); 270 thread_pools_.emplace_back(pool, owned); 271 } 272 } else if (options_.config.use_per_session_threads()) { 273 thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_), 274 true /* owned */); 275 } else { 276 thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */); 277 } 278 // The default value of sync_on_finish will be flipped soon and this 279 // environment variable will be removed as well. 280 const Status status = 281 ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_); 282 if (!status.ok()) { 283 LOG(ERROR) << status.error_message(); 284 } 285 // NOTE(mrry): We do not need to use a unique string for the session 286 // handle, because DirectSession owns its devices. This may change 287 // in future versions. 288 session_handle_ = "direct"; 289 int devices_added = 0; 290 if (options.config.log_device_placement()) { 291 const string mapping_str = device_mgr_->DeviceMappingString(); 292 if (mapping_str.empty()) { 293 printf("Device mapping: no known devices.\n"); 294 } else { 295 printf("Device mapping:\n%s", mapping_str.c_str()); 296 } 297 LOG(INFO) << "Device mapping:\n" << mapping_str; 298 } 299 for (auto d : device_mgr_->ListDevices()) { 300 devices_.push_back(d); 301 device_set_.AddDevice(d); 302 d->op_segment()->AddHold(session_handle_); 303 304 // The first device added is special: it is the 'client device' (a 305 // CPU device) from which we feed and fetch Tensors. 306 if (devices_added == 0) { 307 device_set_.set_client_device(d); 308 } 309 ++devices_added; 310 } 311 } 312 313 DirectSession::~DirectSession() { 314 if (!closed_) Close().IgnoreError(); 315 for (auto& it : partial_runs_) { 316 it.second.reset(nullptr); 317 } 318 for (auto& it : executors_) { 319 it.second.reset(); 320 } 321 for (auto d : device_mgr_->ListDevices()) { 322 d->op_segment()->RemoveHold(session_handle_); 323 } 324 for (auto d : device_mgr_->ListDevices()) { 325 d->ClearResourceMgr(); 326 } 327 functions_.clear(); 328 delete cancellation_manager_; 329 for (const auto& p_and_owned : thread_pools_) { 330 if (p_and_owned.second) delete p_and_owned.first; 331 } 332 333 execution_state_.reset(nullptr); 334 flib_def_.reset(nullptr); 335 } 336 337 Status DirectSession::MaybeInitializeExecutionState( 338 const GraphDef& graph, bool* out_already_initialized) { 339 // If already initialized, do nothing. 340 if (flib_def_ && execution_state_) { 341 *out_already_initialized = true; 342 return Status::OK(); 343 } 344 // Set up the per-session execution state. 345 // NOTE(mrry): The function library created here will be used for 346 // all subsequent extensions of the graph. 347 flib_def_.reset( 348 new FunctionLibraryDefinition(OpRegistry::Global(), graph.library())); 349 GraphExecutionStateOptions options; 350 options.device_set = &device_set_; 351 options.session_options = &options_; 352 // TODO(mrry,suharshs): We explicitly copy `graph` so that 353 // `MakeForBaseGraph()` can take ownership of its 354 // contents. Previously this happened implicitly in calls to the 355 // `GraphExecutionState`. Other sessions call 356 // `MakeForBaseGraph` in such a way that we can destructively read 357 // the passed-in `GraphDef`. In principle we could do the same here, 358 // with a wider refactoring; we might revise the direct session so 359 // that it copies the graph fewer times. 360 GraphDef temp(graph); 361 TF_RETURN_IF_ERROR( 362 GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_)); 363 graph_created_ = true; 364 *out_already_initialized = false; 365 return Status::OK(); 366 } 367 368 Status DirectSession::Create(const GraphDef& graph) { 369 TF_RETURN_IF_ERROR(init_error_); 370 if (graph.node_size() > 0) { 371 mutex_lock l(graph_def_lock_); 372 if (graph_created_) { 373 return errors::AlreadyExists( 374 "A Graph has already been created for this session."); 375 } 376 return ExtendLocked(graph); 377 } 378 return Status::OK(); 379 } 380 381 Status DirectSession::Extend(const GraphDef& graph) { 382 TF_RETURN_IF_ERROR(CheckNotClosed()); 383 mutex_lock l(graph_def_lock_); 384 return ExtendLocked(graph); 385 } 386 387 Status DirectSession::ExtendLocked(const GraphDef& graph) { 388 bool already_initialized; 389 // If this is the first call, we can initialize the execution state 390 // with `graph` and do not need to call `Extend()`. 391 TF_RETURN_IF_ERROR( 392 MaybeInitializeExecutionState(graph, &already_initialized)); 393 if (already_initialized) { 394 TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library())); 395 std::unique_ptr<GraphExecutionState> state; 396 TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state)); 397 execution_state_.swap(state); 398 } 399 return Status::OK(); 400 } 401 402 Status DirectSession::Run(const NamedTensorList& inputs, 403 const std::vector<string>& output_names, 404 const std::vector<string>& target_nodes, 405 std::vector<Tensor>* outputs) { 406 RunMetadata run_metadata; 407 return Run(RunOptions(), inputs, output_names, target_nodes, outputs, 408 &run_metadata); 409 } 410 411 Status DirectSession::CreateDebuggerState( 412 const DebugOptions& debug_options, int64 session_run_index, 413 int64 executor_step_index, const std::vector<string>& input_names, 414 const std::vector<string>& output_names, 415 const std::vector<string>& target_names, 416 std::unique_ptr<DebuggerStateInterface>* debugger_state) { 417 TF_RETURN_IF_ERROR( 418 DebuggerStateRegistry::CreateState(debug_options, debugger_state)); 419 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata( 420 debug_options.global_step(), session_run_index, executor_step_index, 421 input_names, output_names, target_names)); 422 return Status::OK(); 423 } 424 425 Status DirectSession::DecorateAndPublishGraphForDebug( 426 const DebugOptions& debug_options, Graph* graph, Device* device) { 427 std::unique_ptr<DebugGraphDecoratorInterface> decorator; 428 TF_RETURN_IF_ERROR( 429 DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator)); 430 431 TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device)); 432 TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name())); 433 return Status::OK(); 434 } 435 436 Status DirectSession::Run(const RunOptions& run_options, 437 const NamedTensorList& inputs, 438 const std::vector<string>& output_names, 439 const std::vector<string>& target_nodes, 440 std::vector<Tensor>* outputs, 441 RunMetadata* run_metadata) { 442 TF_RETURN_IF_ERROR(CheckNotClosed()); 443 direct_session_runs->GetCell()->IncrementBy(1); 444 { 445 mutex_lock l(graph_def_lock_); 446 if (!graph_created_) { 447 return errors::InvalidArgument( 448 "Session was not created with a graph before Run()!"); 449 } 450 } 451 452 // Extract the inputs names for this run of the session. 453 std::vector<string> input_tensor_names; 454 input_tensor_names.reserve(inputs.size()); 455 for (const auto& it : inputs) { 456 input_tensor_names.push_back(it.first); 457 } 458 459 if (run_options.inter_op_thread_pool() < 0 || 460 run_options.inter_op_thread_pool() >= thread_pools_.size()) { 461 return errors::InvalidArgument("Invalid inter_op_thread_pool: ", 462 run_options.inter_op_thread_pool()); 463 } 464 thread::ThreadPool* pool = 465 thread_pools_[run_options.inter_op_thread_pool()].first; 466 467 // Check if we already have an executor for these arguments. 468 ExecutorsAndKeys* executors_and_keys; 469 RunStateArgs run_state_args(run_options.debug_options()); 470 471 Executor::Args args; 472 args.step_id = step_id_counter_.fetch_add(1); 473 474 TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names, 475 target_nodes, &executors_and_keys, 476 &run_state_args)); 477 const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1); 478 479 std::unique_ptr<DebuggerStateInterface> debugger_state; 480 if (!run_options.debug_options().debug_tensor_watch_opts().empty()) { 481 TF_RETURN_IF_ERROR(CreateDebuggerState( 482 run_options.debug_options(), args.step_id, executor_step_count, 483 input_tensor_names, output_names, target_nodes, &debugger_state)); 484 } 485 486 // Configure a call frame for the step, which we use to feed and 487 // fetch values to and from the executors. 488 FunctionCallFrame call_frame(executors_and_keys->input_types, 489 executors_and_keys->output_types); 490 gtl::InlinedVector<Tensor, 4> feed_args(inputs.size()); 491 for (const auto& it : inputs) { 492 if (it.second.dtype() == DT_RESOURCE) { 493 Tensor tensor_from_handle; 494 TF_RETURN_IF_ERROR( 495 ResourceHandleToInputTensor(it.second, &tensor_from_handle)); 496 feed_args[executors_and_keys->input_name_to_index[it.first]] = 497 tensor_from_handle; 498 } else { 499 feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second; 500 } 501 } 502 const Status s = call_frame.SetArgs(feed_args); 503 if (errors::IsInternal(s)) { 504 return errors::InvalidArgument(s.error_message()); 505 } else if (!s.ok()) { 506 return s; 507 } 508 509 // Create a run state and start execution. 510 RunState run_state(args.step_id, &devices_); 511 run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); 512 CancellationManager step_cancellation_manager; 513 args.call_frame = &call_frame; 514 515 // Start parallel Executors. 516 const size_t num_executors = executors_and_keys->items.size(); 517 ExecutorBarrier* barrier = new ExecutorBarrier( 518 num_executors, run_state.rendez, [&run_state](const Status& ret) { 519 { 520 mutex_lock l(run_state.mu_); 521 run_state.status.Update(ret); 522 } 523 run_state.executors_done.Notify(); 524 }); 525 526 args.rendezvous = run_state.rendez; 527 args.cancellation_manager = &step_cancellation_manager; 528 529 args.session_state = &session_state_; 530 args.tensor_store = &run_state.tensor_store; 531 args.step_container = &run_state.step_container; 532 if (LogMemory::IsEnabled()) { 533 LogMemory::RecordStep(args.step_id, run_state_args.handle); 534 } 535 args.sync_on_finish = sync_on_finish_; 536 537 const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE); 538 539 bool update_cost_model = false; 540 if (options_.config.graph_options().build_cost_model() > 0) { 541 const int64 build_cost_model_every = 542 options_.config.graph_options().build_cost_model(); 543 const int64 build_cost_model_after = 544 options_.config.graph_options().build_cost_model_after(); 545 int64 measure_step_count = executor_step_count - build_cost_model_after; 546 if (measure_step_count >= 0) { 547 update_cost_model = 548 ((measure_step_count + 1) % build_cost_model_every == 0); 549 } 550 } 551 if (do_trace || update_cost_model || 552 run_options.report_tensor_allocations_upon_oom()) { 553 run_state.collector.reset( 554 new StepStatsCollector(run_metadata->mutable_step_stats())); 555 args.stats_collector = run_state.collector.get(); 556 } 557 558 std::unique_ptr<DeviceTracer> tracer; 559 if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) { 560 tracer = CreateDeviceTracer(); 561 // tracer may be NULL on platforms without accelerators. 562 if (tracer) { 563 Status s = tracer->Start(); 564 if (!s.ok()) { 565 run_state.executors_done.Notify(); 566 delete barrier; 567 return s; 568 } 569 } 570 } 571 572 // Register this step with session's cancellation manager, so that 573 // `Session::Close()` will cancel the step. 574 const CancellationToken cancellation_token = 575 cancellation_manager_->get_cancellation_token(); 576 const bool already_cancelled = !cancellation_manager_->RegisterCallback( 577 cancellation_token, [&step_cancellation_manager]() { 578 step_cancellation_manager.StartCancel(); 579 }); 580 if (already_cancelled) { 581 // NOTE(mrry): If we don't explicitly notify 582 // `run_state.executors_done`, the RunState destructor would 583 // block on this notification. 584 run_state.executors_done.Notify(); 585 delete barrier; 586 return errors::Cancelled("Run call was cancelled"); 587 } 588 589 Executor::Args::Runner default_runner = [this, 590 pool](Executor::Args::Closure c) { 591 SchedClosure(pool, std::move(c)); 592 }; 593 for (const auto& item : executors_and_keys->items) { 594 // TODO(zhengxq): support partial run. 595 // TODO(zhengxq): if the device picks its own threadpool, we need to assign 596 // less threads to the main compute pool by default. 597 thread::ThreadPool* device_thread_pool = 598 item.device->tensorflow_device_thread_pool(); 599 if (!device_thread_pool) { 600 args.runner = default_runner; 601 } else { 602 args.runner = [this, device_thread_pool](Executor::Args::Closure c) { 603 SchedClosure(device_thread_pool, std::move(c)); 604 }; 605 } 606 item.executor->RunAsync(args, barrier->Get()); 607 } 608 609 WaitForNotification(&run_state, &step_cancellation_manager, 610 run_options.timeout_in_ms() > 0 611 ? run_options.timeout_in_ms() 612 : operation_timeout_in_ms_); 613 614 if (!cancellation_manager_->DeregisterCallback(cancellation_token)) { 615 // The step has been cancelled: make sure we don't attempt to receive the 616 // outputs as this would make it block forever. 617 mutex_lock l(run_state.mu_); 618 run_state.status.Update(errors::Cancelled("Run call was cancelled")); 619 } 620 621 if (tracer) { 622 TF_RETURN_IF_ERROR(tracer->Stop()); 623 TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector)); 624 } 625 626 { 627 mutex_lock l(run_state.mu_); 628 TF_RETURN_IF_ERROR(run_state.status); 629 } 630 631 // Receive outputs. 632 if (outputs) { 633 std::vector<Tensor> sorted_outputs; 634 const Status s = call_frame.ConsumeRetvals(&sorted_outputs); 635 if (errors::IsInternal(s)) { 636 return errors::InvalidArgument(s.error_message()); 637 } else if (!s.ok()) { 638 return s; 639 } 640 const bool unique_outputs = 641 output_names.size() == executors_and_keys->output_name_to_index.size(); 642 // first_indices[i] = j implies that j is the smallest value for which 643 // output_names[i] == output_names[j]. 644 std::vector<int> first_indices; 645 if (!unique_outputs) { 646 first_indices.resize(output_names.size()); 647 for (int i = 0; i < output_names.size(); ++i) { 648 for (int j = 0; j <= i; ++j) { 649 if (output_names[i] == output_names[j]) { 650 first_indices[i] = j; 651 break; 652 } 653 } 654 } 655 } 656 outputs->clear(); 657 outputs->reserve(sorted_outputs.size()); 658 for (int i = 0; i < output_names.size(); ++i) { 659 const string& output_name = output_names[i]; 660 if (first_indices.empty() || first_indices[i] == i) { 661 outputs->emplace_back( 662 std::move(sorted_outputs[executors_and_keys 663 ->output_name_to_index[output_name]])); 664 } else { 665 outputs->push_back((*outputs)[first_indices[i]]); 666 } 667 } 668 } 669 670 // Save the output tensors of this run we choose to keep. 671 TF_RETURN_IF_ERROR( 672 run_state.tensor_store.SaveTensors(output_names, &session_state_)); 673 if (args.stats_collector) { 674 args.stats_collector->Finalize(); 675 } 676 677 // Build and return the cost model as instructed. 678 mutex_lock l(executor_lock_); 679 if (update_cost_model) { 680 // Build the cost model 681 std::unordered_map<string, const Graph*> device_to_graph; 682 for (const PerPartitionExecutorsAndLib& partition : 683 executors_and_keys->items) { 684 const Graph* graph = partition.graph; 685 const string device = partition.flib->device()->name(); 686 device_to_graph[device] = graph; 687 } 688 args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph); 689 690 // annotate stats onto cost graph. 691 CostGraphDef* cost_graph = run_metadata->mutable_cost_graph(); 692 for (const auto& item : executors_and_keys->items) { 693 TF_RETURN_IF_ERROR( 694 cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph)); 695 } 696 } 697 698 // If requested via RunOptions, output the partition graphs. 699 if (run_options.output_partition_graphs()) { 700 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs = 701 run_metadata->mutable_partition_graphs(); 702 for (const PerPartitionExecutorsAndLib& exec_and_lib : 703 executors_and_keys->items) { 704 GraphDef* partition_graph_def = partition_graph_defs->Add(); 705 exec_and_lib.graph->ToGraphDef(partition_graph_def); 706 } 707 } 708 709 return Status::OK(); 710 } 711 712 Status DirectSession::PRunSetup(const std::vector<string>& input_names, 713 const std::vector<string>& output_names, 714 const std::vector<string>& target_nodes, 715 string* handle) { 716 TF_RETURN_IF_ERROR(CheckNotClosed()); 717 { 718 mutex_lock l(graph_def_lock_); 719 if (!graph_created_) { 720 return errors::InvalidArgument( 721 "Session was not created with a graph before PRunSetup()!"); 722 } 723 } 724 725 // RunOptions is not available in PRunSetup, so use thread pool 0. 726 thread::ThreadPool* pool = thread_pools_[0].first; 727 728 // Check if we already have an executor for these arguments. 729 ExecutorsAndKeys* executors_and_keys; 730 // TODO(cais): TFDBG support for partial runs. 731 DebugOptions debug_options; 732 RunStateArgs run_state_args(debug_options); 733 run_state_args.is_partial_run = true; 734 TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names, 735 target_nodes, &executors_and_keys, 736 &run_state_args)); 737 738 // Create the run state and save it for future PRun calls. 739 Executor::Args args; 740 args.step_id = step_id_counter_.fetch_add(1); 741 RunState* run_state = 742 new RunState(input_names, output_names, args.step_id, &devices_); 743 run_state->rendez = new IntraProcessRendezvous(device_mgr_.get()); 744 { 745 mutex_lock l(executor_lock_); 746 if (!partial_runs_ 747 .emplace(run_state_args.handle, 748 std::unique_ptr<RunState>(run_state)) 749 .second) { 750 return errors::Internal("The handle '", run_state_args.handle, 751 "' created for this partial run is not unique."); 752 } 753 } 754 755 // Start parallel Executors. 756 const size_t num_executors = executors_and_keys->items.size(); 757 ExecutorBarrier* barrier = new ExecutorBarrier( 758 num_executors, run_state->rendez, [run_state](const Status& ret) { 759 if (!ret.ok()) { 760 mutex_lock l(run_state->mu_); 761 run_state->status.Update(ret); 762 } 763 run_state->executors_done.Notify(); 764 }); 765 766 args.rendezvous = run_state->rendez; 767 args.cancellation_manager = cancellation_manager_; 768 args.runner = [this, pool](Executor::Args::Closure c) { 769 SchedClosure(pool, std::move(c)); 770 }; 771 args.session_state = &session_state_; 772 args.tensor_store = &run_state->tensor_store; 773 args.step_container = &run_state->step_container; 774 if (LogMemory::IsEnabled()) { 775 LogMemory::RecordStep(args.step_id, run_state_args.handle); 776 } 777 args.sync_on_finish = sync_on_finish_; 778 779 if (options_.config.graph_options().build_cost_model()) { 780 run_state->collector.reset(new StepStatsCollector(nullptr)); 781 args.stats_collector = run_state->collector.get(); 782 } 783 784 for (auto& item : executors_and_keys->items) { 785 item.executor->RunAsync(args, barrier->Get()); 786 } 787 788 *handle = run_state_args.handle; 789 return Status::OK(); 790 } 791 792 Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, 793 const std::vector<string>& output_names, 794 std::vector<Tensor>* outputs) { 795 TF_RETURN_IF_ERROR(CheckNotClosed()); 796 std::vector<string> parts = str_util::Split(handle, ';'); 797 const string& key = parts[0]; 798 // Get the executors for this partial run. 799 ExecutorsAndKeys* executors_and_keys; 800 RunState* run_state; 801 { 802 mutex_lock l(executor_lock_); // could use reader lock 803 auto exc_it = executors_.find(key); 804 if (exc_it == executors_.end()) { 805 return errors::InvalidArgument( 806 "Must run 'setup' before performing partial runs!"); 807 } 808 executors_and_keys = exc_it->second.get(); 809 810 auto prun_it = partial_runs_.find(handle); 811 if (prun_it == partial_runs_.end()) { 812 return errors::InvalidArgument( 813 "Must run 'setup' before performing partial runs!"); 814 } 815 run_state = prun_it->second.get(); 816 817 // Make sure that this is a new set of feeds that are still pending. 818 for (const auto& input : inputs) { 819 auto it = run_state->pending_inputs.find(input.first); 820 if (it == run_state->pending_inputs.end()) { 821 return errors::InvalidArgument( 822 "The feed ", input.first, 823 " was not specified in partial_run_setup."); 824 } else if (it->second) { 825 return errors::InvalidArgument("The feed ", input.first, 826 " has already been fed."); 827 } 828 } 829 // Check that this is a new set of fetches that are still pending. 830 for (const auto& output : output_names) { 831 auto it = run_state->pending_outputs.find(output); 832 if (it == run_state->pending_outputs.end()) { 833 return errors::InvalidArgument( 834 "The fetch ", output, " was not specified in partial_run_setup."); 835 } else if (it->second) { 836 return errors::InvalidArgument("The fetch ", output, 837 " has already been fetched."); 838 } 839 } 840 } 841 842 // Check that this new set of fetches can be computed from all the 843 // feeds we have supplied. 844 TF_RETURN_IF_ERROR( 845 CheckFetch(inputs, output_names, executors_and_keys, run_state)); 846 847 // Send inputs. 848 Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez); 849 850 // Receive outputs. 851 if (s.ok()) { 852 s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs); 853 } 854 855 // Save the output tensors of this run we choose to keep. 856 if (s.ok()) { 857 s = run_state->tensor_store.SaveTensors(output_names, &session_state_); 858 } 859 860 { 861 mutex_lock l(executor_lock_); 862 // Delete the run state if there is an error or all fetches are done. 863 bool done = true; 864 if (s.ok()) { 865 { 866 mutex_lock l(run_state->mu_); 867 if (!run_state->status.ok()) { 868 LOG(WARNING) << "An error unrelated to this prun has been detected. " 869 << run_state->status; 870 } 871 } 872 for (const auto& input : inputs) { 873 auto it = run_state->pending_inputs.find(input.first); 874 it->second = true; 875 } 876 for (const auto& name : output_names) { 877 auto it = run_state->pending_outputs.find(name); 878 it->second = true; 879 } 880 done = run_state->PendingDone(); 881 } 882 if (done) { 883 WaitForNotification(run_state, cancellation_manager_, 884 operation_timeout_in_ms_); 885 partial_runs_.erase(handle); 886 } 887 } 888 889 return s; 890 } 891 892 Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor, 893 Tensor* retrieved_tensor) { 894 if (resource_tensor.dtype() != DT_RESOURCE) { 895 return errors::InvalidArgument(strings::StrCat( 896 "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ", 897 resource_tensor.dtype())); 898 } 899 900 const ResourceHandle& resource_handle = 901 resource_tensor.scalar<ResourceHandle>()(); 902 903 if (resource_handle.container() == 904 SessionState::kTensorHandleResourceTypeName) { 905 return session_state_.GetTensor(resource_handle.name(), retrieved_tensor); 906 } else { 907 return errors::InvalidArgument(strings::StrCat( 908 "Invalid resource type hash code: ", resource_handle.hash_code(), 909 "(name: ", resource_handle.name(), 910 " type: ", resource_handle.maybe_type_name(), 911 "). Perhaps a resource tensor was being provided as a feed? That is " 912 "not currently allowed. Please file an issue at " 913 "https://github.com/tensorflow/tensorflow/issues/new, ideally with a " 914 "short code snippet that leads to this error message.")); 915 } 916 } 917 918 Status DirectSession::SendPRunInputs(const NamedTensorList& inputs, 919 const ExecutorsAndKeys* executors_and_keys, 920 IntraProcessRendezvous* rendez) { 921 Status s; 922 Rendezvous::ParsedKey parsed; 923 // Insert the input tensors into the local rendezvous by their 924 // rendezvous key. 925 for (const auto& input : inputs) { 926 auto it = 927 executors_and_keys->input_name_to_rendezvous_key.find(input.first); 928 if (it == executors_and_keys->input_name_to_rendezvous_key.end()) { 929 return errors::Internal("'", input.first, "' is not a pre-defined feed."); 930 } 931 const string& input_key = it->second; 932 933 s = Rendezvous::ParseKey(input_key, &parsed); 934 if (!s.ok()) { 935 rendez->StartAbort(s); 936 return s; 937 } 938 939 if (input.second.dtype() == DT_RESOURCE) { 940 Tensor tensor_from_handle; 941 s = ResourceHandleToInputTensor(input.second, &tensor_from_handle); 942 if (s.ok()) { 943 s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false); 944 } 945 } else { 946 s = rendez->Send(parsed, Rendezvous::Args(), input.second, false); 947 } 948 949 if (!s.ok()) { 950 rendez->StartAbort(s); 951 return s; 952 } 953 } 954 return Status::OK(); 955 } 956 957 Status DirectSession::RecvPRunOutputs( 958 const std::vector<string>& output_names, 959 const ExecutorsAndKeys* executors_and_keys, RunState* run_state, 960 std::vector<Tensor>* outputs) { 961 Status s; 962 if (!output_names.empty()) { 963 outputs->resize(output_names.size()); 964 } 965 966 Rendezvous::ParsedKey parsed; 967 // Get the outputs from the rendezvous 968 for (size_t output_offset = 0; output_offset < output_names.size(); 969 ++output_offset) { 970 const string& output_name = output_names[output_offset]; 971 auto it = 972 executors_and_keys->output_name_to_rendezvous_key.find(output_name); 973 if (it == executors_and_keys->output_name_to_rendezvous_key.end()) { 974 return errors::Internal("'", output_name, 975 "' is not a pre-defined fetch."); 976 } 977 const string& output_key = it->second; 978 Tensor output_tensor; 979 bool is_dead; 980 IntraProcessRendezvous* rendez = run_state->rendez; 981 982 s = Rendezvous::ParseKey(output_key, &parsed); 983 if (s.ok()) { 984 // Fetch data from the Rendezvous. 985 s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead, 986 operation_timeout_in_ms_); 987 if (is_dead && s.ok()) { 988 s = errors::InvalidArgument("The tensor returned for ", output_name, 989 " was not valid."); 990 } 991 } 992 if (!s.ok()) { 993 rendez->StartAbort(s); 994 outputs->clear(); 995 return s; 996 } 997 998 (*outputs)[output_offset] = output_tensor; 999 } 1000 return Status::OK(); 1001 } 1002 1003 Status DirectSession::CheckFetch(const NamedTensorList& feeds, 1004 const std::vector<string>& fetches, 1005 const ExecutorsAndKeys* executors_and_keys, 1006 const RunState* run_state) { 1007 const Graph* graph = executors_and_keys->graph.get(); 1008 const NameNodeMap* name_to_node = &executors_and_keys->name_to_node; 1009 1010 // Build the set of pending feeds that we haven't seen. 1011 std::unordered_set<TensorId, TensorId::Hasher> pending_feeds; 1012 { 1013 mutex_lock l(executor_lock_); 1014 for (const auto& input : run_state->pending_inputs) { 1015 // Skip if the feed has already been fed. 1016 if (input.second) continue; 1017 TensorId id(ParseTensorName(input.first)); 1018 auto it = name_to_node->find(id.first); 1019 if (it == name_to_node->end()) { 1020 return errors::NotFound("Feed ", input.first, ": not found"); 1021 } 1022 pending_feeds.insert(id); 1023 } 1024 } 1025 for (const auto& it : feeds) { 1026 TensorId id(ParseTensorName(it.first)); 1027 pending_feeds.erase(id); 1028 } 1029 1030 // Initialize the stack with the fetch nodes. 1031 std::vector<const Node*> stack; 1032 for (const string& fetch : fetches) { 1033 TensorId id(ParseTensorName(fetch)); 1034 auto it = name_to_node->find(id.first); 1035 if (it == name_to_node->end()) { 1036 return errors::NotFound("Fetch ", fetch, ": not found"); 1037 } 1038 stack.push_back(it->second); 1039 } 1040 1041 // Any tensor needed for fetches can't be in pending_feeds. 1042 std::vector<bool> visited(graph->num_node_ids(), false); 1043 while (!stack.empty()) { 1044 const Node* n = stack.back(); 1045 stack.pop_back(); 1046 1047 for (const Edge* in_edge : n->in_edges()) { 1048 const Node* in_node = in_edge->src(); 1049 if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) { 1050 return errors::InvalidArgument("Fetch ", in_node->name(), ":", 1051 in_edge->src_output(), 1052 " can't be computed from the feeds" 1053 " that have been fed so far."); 1054 } 1055 if (!visited[in_node->id()]) { 1056 visited[in_node->id()] = true; 1057 stack.push_back(in_node); 1058 } 1059 } 1060 } 1061 return Status::OK(); 1062 } 1063 1064 Status DirectSession::GetOrCreateExecutors( 1065 gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs, 1066 gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys, 1067 RunStateArgs* run_state_args) { 1068 int64 handle_name_counter_value = -1; 1069 if (LogMemory::IsEnabled() || run_state_args->is_partial_run) { 1070 handle_name_counter_value = handle_name_counter_.fetch_add(1); 1071 } 1072 1073 string debug_tensor_watches_summary; 1074 if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) { 1075 debug_tensor_watches_summary = SummarizeDebugTensorWatches( 1076 run_state_args->debug_options.debug_tensor_watch_opts()); 1077 } 1078 1079 // Fast lookup path, no sorting. 1080 const string key = strings::StrCat( 1081 str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/", 1082 str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run, 1083 "/", debug_tensor_watches_summary); 1084 // Set the handle, if it's needed to log memory or for partial run. 1085 if (handle_name_counter_value >= 0) { 1086 run_state_args->handle = 1087 strings::StrCat(key, ";", handle_name_counter_value); 1088 } 1089 1090 // See if we already have the executors for this run. 1091 { 1092 mutex_lock l(executor_lock_); // could use reader lock 1093 auto it = executors_.find(key); 1094 if (it != executors_.end()) { 1095 *executors_and_keys = it->second.get(); 1096 return Status::OK(); 1097 } 1098 } 1099 1100 // Slow lookup path, the unsorted key missed the cache. 1101 // Sort the inputs and outputs, and look up with the sorted key in case an 1102 // earlier call used a different order of inputs and outputs. 1103 // 1104 // We could consider some other signature instead of sorting that 1105 // preserves the same property to avoid the sort in the future. 1106 std::vector<string> inputs_sorted(inputs.begin(), inputs.end()); 1107 std::sort(inputs_sorted.begin(), inputs_sorted.end()); 1108 std::vector<string> outputs_sorted(outputs.begin(), outputs.end()); 1109 std::sort(outputs_sorted.begin(), outputs_sorted.end()); 1110 std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end()); 1111 std::sort(tn_sorted.begin(), tn_sorted.end()); 1112 1113 const string sorted_key = strings::StrCat( 1114 str_util::Join(inputs_sorted, ","), "->", 1115 str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","), 1116 "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary); 1117 // Set the handle, if its needed to log memory or for partial run. 1118 if (handle_name_counter_value >= 0) { 1119 run_state_args->handle = 1120 strings::StrCat(sorted_key, ";", handle_name_counter_value); 1121 } 1122 1123 // See if we already have the executors for this run. 1124 { 1125 mutex_lock l(executor_lock_); 1126 auto it = executors_.find(sorted_key); 1127 if (it != executors_.end()) { 1128 *executors_and_keys = it->second.get(); 1129 // Insert this under the original key. 1130 executors_.emplace(key, it->second); 1131 return Status::OK(); 1132 } 1133 } 1134 1135 // Nothing found, so create the executors and store in the cache. 1136 BuildGraphOptions options; 1137 options.feed_endpoints = inputs_sorted; 1138 options.fetch_endpoints = outputs_sorted; 1139 options.target_nodes = tn_sorted; 1140 options.use_function_convention = !run_state_args->is_partial_run; 1141 if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) { 1142 options.debug_options = run_state_args->debug_options; 1143 } 1144 1145 std::unique_ptr<FunctionInfo> func_info(new FunctionInfo); 1146 std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys); 1147 1148 // The executor_lock_ is intentionally released while executor is 1149 // being created. 1150 std::unordered_map<string, std::unique_ptr<Graph>> graphs; 1151 TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def, 1152 run_state_args, &ek->input_types, 1153 &ek->output_types)); 1154 1155 if (run_state_args->is_partial_run) { 1156 ek->graph = std::move(run_state_args->graph); 1157 std::unordered_set<StringPiece, StringPieceHasher> names; 1158 for (const string& input : inputs) { 1159 TensorId id(ParseTensorName(input)); 1160 names.emplace(id.first); 1161 } 1162 for (const string& output : outputs) { 1163 TensorId id(ParseTensorName(output)); 1164 names.emplace(id.first); 1165 } 1166 for (Node* n : ek->graph->nodes()) { 1167 if (names.count(n->name()) > 0) { 1168 ek->name_to_node.insert({n->name(), n}); 1169 } 1170 } 1171 } 1172 ek->items.reserve(graphs.size()); 1173 const auto& optimizer_opts = 1174 options_.config.graph_options().optimizer_options(); 1175 1176 int graph_def_version; 1177 { 1178 mutex_lock l(graph_def_lock_); 1179 graph_def_version = 1180 execution_state_->original_graph_def().versions().producer(); 1181 } 1182 func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime( 1183 device_mgr_.get(), options_.env, graph_def_version, 1184 func_info->flib_def.get(), optimizer_opts)); 1185 1186 GraphOptimizer optimizer(optimizer_opts); 1187 for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) { 1188 const string& partition_name = iter->first; 1189 std::unique_ptr<Graph>& partition_graph = iter->second; 1190 1191 Device* device; 1192 TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device)); 1193 1194 ek->items.resize(ek->items.size() + 1); 1195 auto* item = &(ek->items.back()); 1196 auto lib = func_info->proc_flr->GetFLR(partition_name); 1197 if (lib == nullptr) { 1198 return errors::Internal("Could not find device: ", partition_name); 1199 } 1200 item->flib = lib; 1201 1202 LocalExecutorParams params; 1203 params.device = device; 1204 params.function_library = lib; 1205 auto opseg = device->op_segment(); 1206 params.create_kernel = [this, lib, opseg](const NodeDef& ndef, 1207 OpKernel** kernel) { 1208 // We do not share the kernel via the OpSegment if the node is 1209 // stateless, or a function. 1210 // NOTE(mrry): We must not share function kernels (implemented 1211 // using `CallOp`) between subgraphs, because `CallOp::handle_` 1212 // is tied to a particular subgraph. Even if the function itself 1213 // is stateful, the `CallOp` that invokes it is not. 1214 if (!lib->IsStateful(ndef.op()) || 1215 lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) { 1216 return lib->CreateKernel(ndef, kernel); 1217 } 1218 auto create_fn = [lib, &ndef](OpKernel** kernel) { 1219 return lib->CreateKernel(ndef, kernel); 1220 }; 1221 // Kernels created for subgraph nodes need to be cached. On 1222 // cache miss, create_fn() is invoked to create a kernel based 1223 // on the function library here + global op registry. 1224 return opseg->FindOrCreate(session_handle_, ndef.name(), kernel, 1225 create_fn); 1226 }; 1227 params.delete_kernel = [lib](OpKernel* kernel) { 1228 // If the node is stateful, opseg owns it. Otherwise, delete it. 1229 if (kernel && !lib->IsStateful(kernel->type_string())) { 1230 delete kernel; 1231 } 1232 }; 1233 params.node_outputs_cb = node_outputs_callback_; 1234 1235 optimizer.Optimize(lib, options_.env, device, &iter->second, 1236 /*shape_map=*/nullptr); 1237 1238 // EXPERIMENTAL: tfdbg inserts debug nodes in the graph. 1239 if (!options.debug_options.debug_tensor_watch_opts().empty()) { 1240 TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug( 1241 options.debug_options, partition_graph.get(), params.device)); 1242 } 1243 1244 TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()), 1245 device->name(), 1246 partition_graph.get())); 1247 // NewLocalExecutor takes ownership of partition_graph. 1248 item->graph = partition_graph.get(); 1249 item->executor = nullptr; 1250 item->device = device; 1251 Executor* executor; 1252 TF_RETURN_IF_ERROR( 1253 NewLocalExecutor(params, std::move(partition_graph), &executor)); 1254 item->executor.reset(executor); 1255 } 1256 1257 // Cache the mapping from input/output names to graph elements to 1258 // avoid recomputing it every time. 1259 if (!run_state_args->is_partial_run) { 1260 // For regular `Run()`, we use the function calling convention, and so 1261 // maintain a mapping from input/output names to 1262 // argument/return-value ordinal index. 1263 for (size_t i = 0; i < inputs_sorted.size(); ++i) { 1264 const string& input = inputs_sorted[i]; 1265 ek->input_name_to_index[input] = i; 1266 } 1267 for (size_t i = 0; i < outputs_sorted.size(); ++i) { 1268 const string& output = outputs_sorted[i]; 1269 ek->output_name_to_index[output] = i; 1270 } 1271 } else { 1272 // For `PRun()`, we use the rendezvous calling convention, and so 1273 // maintain a mapping from input/output names to rendezvous keys. 1274 // 1275 // We always use the first device as the device name portion of the 1276 // key, even if we're feeding another graph. 1277 for (size_t i = 0; i < inputs_sorted.size(); ++i) { 1278 const string& input = inputs_sorted[i]; 1279 ek->input_name_to_rendezvous_key[input] = GetRendezvousKey( 1280 input, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); 1281 } 1282 for (size_t i = 0; i < outputs_sorted.size(); ++i) { 1283 const string& output = outputs_sorted[i]; 1284 ek->output_name_to_rendezvous_key[output] = 1285 GetRendezvousKey(output, device_set_.client_device()->attributes(), 1286 FrameAndIter(0, 0)); 1287 } 1288 } 1289 1290 // Reacquire the lock, try to insert into the map. 1291 mutex_lock l(executor_lock_); 1292 functions_.push_back(std::move(func_info)); 1293 1294 // Another thread may have created the entry before us, in which case we will 1295 // reuse the already created one. 1296 auto insert_result = executors_.emplace(sorted_key, ek); 1297 // Insert the value under the original key, so the fast path lookup will work 1298 // if the user uses the same order of inputs, outputs, and targets again. 1299 executors_.emplace(key, insert_result.first->second); 1300 *executors_and_keys = insert_result.first->second.get(); 1301 1302 return Status::OK(); 1303 } 1304 1305 Status DirectSession::CreateGraphs( 1306 const BuildGraphOptions& subgraph_options, 1307 std::unordered_map<string, std::unique_ptr<Graph>>* outputs, 1308 std::unique_ptr<FunctionLibraryDefinition>* flib_def, 1309 RunStateArgs* run_state_args, DataTypeVector* input_types, 1310 DataTypeVector* output_types) { 1311 mutex_lock l(graph_def_lock_); 1312 std::unique_ptr<ClientGraph> client_graph; 1313 1314 std::unique_ptr<GraphExecutionState> temp_exec_state_holder; 1315 GraphExecutionState* execution_state = nullptr; 1316 if (options_.config.graph_options().place_pruned_graph()) { 1317 // Because we are placing pruned graphs, we need to create a 1318 // new GraphExecutionState for every new unseen graph, 1319 // and then place it. 1320 GraphExecutionStateOptions prune_options; 1321 prune_options.device_set = &device_set_; 1322 prune_options.session_options = &options_; 1323 prune_options.stateful_placements = stateful_placements_; 1324 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph( 1325 execution_state_->original_graph_def().library(), prune_options, 1326 execution_state_->original_graph_def(), subgraph_options, 1327 &temp_exec_state_holder, &client_graph)); 1328 execution_state = temp_exec_state_holder.get(); 1329 } else { 1330 execution_state = execution_state_.get(); 1331 TF_RETURN_IF_ERROR( 1332 execution_state->BuildGraph(subgraph_options, &client_graph)); 1333 } 1334 1335 if (subgraph_options.feed_endpoints.size() != 1336 client_graph->feed_types.size()) { 1337 return errors::Internal( 1338 "Graph pruning failed: requested number of feed endpoints = ", 1339 subgraph_options.feed_endpoints.size(), 1340 " versus number of pruned feed endpoints = ", 1341 client_graph->feed_types.size()); 1342 } 1343 if (subgraph_options.fetch_endpoints.size() != 1344 client_graph->fetch_types.size()) { 1345 return errors::Internal( 1346 "Graph pruning failed: requested number of fetch endpoints = ", 1347 subgraph_options.fetch_endpoints.size(), 1348 " versus number of pruned fetch endpoints = ", 1349 client_graph->fetch_types.size()); 1350 } 1351 1352 auto current_stateful_placements = execution_state->GetStatefulPlacements(); 1353 // Update our current state based on the execution_state's 1354 // placements. If there are any mismatches for a node, 1355 // we should fail, as this should never happen. 1356 for (auto placement_pair : current_stateful_placements) { 1357 const string& node_name = placement_pair.first; 1358 const string& placement = placement_pair.second; 1359 auto iter = stateful_placements_.find(node_name); 1360 if (iter == stateful_placements_.end()) { 1361 stateful_placements_.insert(std::make_pair(node_name, placement)); 1362 } else if (iter->second != placement) { 1363 return errors::Internal( 1364 "Stateful placement mismatch. " 1365 "Current assignment of ", 1366 node_name, " to ", iter->second, " does not match ", placement); 1367 } 1368 } 1369 1370 stateful_placements_ = execution_state->GetStatefulPlacements(); 1371 1372 // Remember the graph in run state if this is a partial run. 1373 if (run_state_args->is_partial_run) { 1374 run_state_args->graph.reset(new Graph(flib_def_.get())); 1375 CopyGraph(*execution_state->full_graph(), run_state_args->graph.get()); 1376 } 1377 1378 // Partition the graph across devices. 1379 PartitionOptions popts; 1380 popts.node_to_loc = [](const Node* node) { 1381 return node->assigned_device_name(); 1382 }; 1383 popts.new_name = [this](const string& prefix) { 1384 return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1)); 1385 }; 1386 popts.get_incarnation = [](const string& name) { 1387 // The direct session does not have changing incarnation numbers. 1388 // Just return '1'. 1389 return 1; 1390 }; 1391 popts.flib_def = &client_graph->graph.flib_def(); 1392 popts.control_flow_added = false; 1393 1394 std::unordered_map<string, GraphDef> partitions; 1395 TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions)); 1396 1397 std::vector<string> device_names; 1398 for (auto device : devices_) { 1399 // Extract the LocalName from the device. 1400 device_names.push_back(DeviceNameUtils::LocalName(device->name())); 1401 } 1402 1403 // Check for valid partitions. 1404 for (const auto& partition : partitions) { 1405 const string local_partition_name = 1406 DeviceNameUtils::LocalName(partition.first); 1407 if (std::count(device_names.begin(), device_names.end(), 1408 local_partition_name) == 0) { 1409 return errors::InvalidArgument( 1410 "Creating a partition for ", local_partition_name, 1411 " which doesn't exist in the list of available devices. Available " 1412 "devices: ", 1413 str_util::Join(device_names, ",")); 1414 } 1415 } 1416 1417 for (const auto& partition : partitions) { 1418 std::unique_ptr<Graph> device_graph( 1419 new Graph(client_graph->flib_def.get())); 1420 GraphConstructorOptions device_opts; 1421 // There are internal operations (e.g., send/recv) that we now allow. 1422 device_opts.allow_internal_ops = true; 1423 device_opts.expect_device_spec = true; 1424 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second, 1425 device_graph.get())); 1426 outputs->emplace(partition.first, std::move(device_graph)); 1427 } 1428 1429 GraphOptimizationPassOptions optimization_options; 1430 optimization_options.session_options = &options_; 1431 optimization_options.flib_def = client_graph->flib_def.get(); 1432 optimization_options.partition_graphs = outputs; 1433 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( 1434 OptimizationPassRegistry::POST_PARTITIONING, optimization_options)); 1435 1436 Status s; 1437 for (auto& partition : *outputs) { 1438 const string& partition_name = partition.first; 1439 std::unique_ptr<Graph>* graph = &partition.second; 1440 1441 VLOG(2) << "Created " << DebugString(graph->get()) << " for " 1442 << partition_name; 1443 1444 // Give the device an opportunity to rewrite its subgraph. 1445 Device* d; 1446 s = device_mgr_->LookupDevice(partition_name, &d); 1447 if (!s.ok()) break; 1448 s = d->MaybeRewriteGraph(graph); 1449 if (!s.ok()) { 1450 break; 1451 } 1452 } 1453 *flib_def = std::move(client_graph->flib_def); 1454 std::swap(*input_types, client_graph->feed_types); 1455 std::swap(*output_types, client_graph->fetch_types); 1456 return s; 1457 } 1458 1459 ::tensorflow::Status DirectSession::ListDevices( 1460 std::vector<DeviceAttributes>* response) { 1461 response->clear(); 1462 response->reserve(devices_.size()); 1463 for (Device* d : devices_) { 1464 const DeviceAttributes& attrs = d->attributes(); 1465 response->emplace_back(attrs); 1466 } 1467 return ::tensorflow::Status::OK(); 1468 } 1469 1470 ::tensorflow::Status DirectSession::Reset( 1471 const std::vector<string>& containers) { 1472 device_mgr_->ClearContainers(containers); 1473 return ::tensorflow::Status::OK(); 1474 } 1475 1476 ::tensorflow::Status DirectSession::Close() { 1477 cancellation_manager_->StartCancel(); 1478 { 1479 mutex_lock l(closed_lock_); 1480 if (closed_) return ::tensorflow::Status::OK(); 1481 closed_ = true; 1482 } 1483 if (factory_ != nullptr) factory_->Deregister(this); 1484 return ::tensorflow::Status::OK(); 1485 } 1486 1487 DirectSession::RunState::RunState( 1488 const std::vector<string>& pending_input_names, 1489 const std::vector<string>& pending_output_names, int64 step_id, 1490 const std::vector<Device*>* devices) 1491 : step_container(step_id, [devices](const string& name) { 1492 for (auto d : *devices) { 1493 if (!d->resource_manager()->Cleanup(name).ok()) { 1494 // Do nothing... 1495 } 1496 } 1497 }) { 1498 // Initially all the feeds and fetches are pending. 1499 for (auto& name : pending_input_names) { 1500 pending_inputs[name] = false; 1501 } 1502 for (auto& name : pending_output_names) { 1503 pending_outputs[name] = false; 1504 } 1505 } 1506 1507 DirectSession::RunState::RunState(int64 step_id, 1508 const std::vector<Device*>* devices) 1509 : RunState({}, {}, step_id, devices) {} 1510 1511 DirectSession::RunState::~RunState() { 1512 if (rendez != nullptr) { 1513 if (!executors_done.HasBeenNotified()) { 1514 rendez->StartAbort(errors::Cancelled("PRun cancellation")); 1515 executors_done.WaitForNotification(); 1516 } 1517 rendez->Unref(); 1518 } 1519 } 1520 1521 bool DirectSession::RunState::PendingDone() const { 1522 for (const auto& it : pending_inputs) { 1523 if (!it.second) return false; 1524 } 1525 for (const auto& it : pending_outputs) { 1526 if (!it.second) return false; 1527 } 1528 return true; 1529 } 1530 1531 void DirectSession::WaitForNotification(RunState* run_state, 1532 CancellationManager* cm, 1533 int64 timeout_in_ms) { 1534 const Status status = 1535 WaitForNotification(&run_state->executors_done, timeout_in_ms); 1536 if (!status.ok()) { 1537 { 1538 mutex_lock l(run_state->mu_); 1539 run_state->status.Update(status); 1540 } 1541 cm->StartCancel(); 1542 // We must wait for the executors to complete, because they have borrowed 1543 // references to `cm` and other per-step state. After this notification, it 1544 // is safe to clean up the step. 1545 run_state->executors_done.WaitForNotification(); 1546 } 1547 } 1548 1549 ::tensorflow::Status DirectSession::WaitForNotification( 1550 Notification* notification, int64 timeout_in_ms) { 1551 if (timeout_in_ms > 0) { 1552 const int64 timeout_in_us = timeout_in_ms * 1000; 1553 const bool notified = 1554 WaitForNotificationWithTimeout(notification, timeout_in_us); 1555 if (!notified) { 1556 return Status(error::DEADLINE_EXCEEDED, 1557 "Timed out waiting for notification"); 1558 } 1559 } else { 1560 notification->WaitForNotification(); 1561 } 1562 return Status::OK(); 1563 } 1564 1565 } // namespace tensorflow 1566