/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/direct_session.h"

#include <atomic>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/device_tracer.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

namespace {

auto* direct_session_runs = monitoring::Counter<0>::New(
    "/tensorflow/core/direct_session_runs",
    "The number of times DirectSession::Run() has been called.");

int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
  const int32 t = options.config.inter_op_parallelism_threads();
  if (t != 0) return t;
  // Default to using the number of cores available in the process.
  return port::NumSchedulableCPUs();
}

thread::ThreadPool* NewThreadPoolFromSessionOptions(
    const SessionOptions& options) {
  const int32 num_threads = NumInterOpThreadsFromSessionOptions(options);
  VLOG(1) << "Direct session inter op parallelism threads: " << num_threads;
  return new thread::ThreadPool(options.env, "Compute", num_threads);
}

Status NewThreadPoolFromThreadPoolOptions(
    const SessionOptions& options,
    const ThreadPoolOptionProto& thread_pool_options, int pool_number,
    thread::ThreadPool** pool, bool* owned) {
  int32 num_threads = thread_pool_options.num_threads();
  if (num_threads == 0) {
    num_threads = NumInterOpThreadsFromSessionOptions(options);
  }
  const string& name = thread_pool_options.global_name();
  if (name.empty()) {
    // Session-local threadpool.
    VLOG(1) << "Direct session inter op parallelism threads for pool "
            << pool_number << ": " << num_threads;
    *pool = new thread::ThreadPool(
        options.env, strings::StrCat("Compute", pool_number), num_threads);
    *owned = true;
    return Status::OK();
  }

  // Global, named threadpool.
  typedef std::pair<int32, thread::ThreadPool*> MapValue;
  static std::map<string, MapValue>* global_pool_map =
      new std::map<string, MapValue>;
  static mutex* mu = new mutex();
  mutex_lock l(*mu);
  MapValue* mvalue = &(*global_pool_map)[name];
  if (mvalue->second == nullptr) {
    mvalue->first = thread_pool_options.num_threads();
    mvalue->second = new thread::ThreadPool(
        options.env, strings::StrCat("Compute", pool_number), num_threads);
  } else {
    if (mvalue->first != thread_pool_options.num_threads()) {
      return errors::InvalidArgument(
          "Pool ", name,
          " configured previously with num_threads=", mvalue->first,
          "; cannot re-configure with num_threads=",
          thread_pool_options.num_threads());
    }
  }
  *owned = false;
  *pool = mvalue->second;
  return Status::OK();
}
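
// For reference, a client selects the pool path above through
// ConfigProto.session_inter_op_thread_pool. A minimal sketch (field names
// per config.proto; the pool name "shared" is illustrative):
//
//   SessionOptions options;
//   ThreadPoolOptionProto* proto =
//       options.config.add_session_inter_op_thread_pool();
//   proto->set_num_threads(8);
//   proto->set_global_name("shared");  // Omit for a session-local pool.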

thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
  static thread::ThreadPool* const thread_pool =
      NewThreadPoolFromSessionOptions(options);
  return thread_pool;
}
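
// NOTE: Because of the static local above, the global pool is created at
// most once per process, sized by whichever session's options reach it
// first; later sessions with different settings share that same pool.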

// TODO(vrv): Figure out how to unify the many different functions
// that generate RendezvousKey, since many of them have to be
// consistent with each other.
string GetRendezvousKey(const string& tensor_name,
                        const DeviceAttributes& device_info,
                        const FrameAndIter& frame_iter) {
  return strings::StrCat(device_info.name(), ";",
                         strings::FpToString(device_info.incarnation()), ";",
                         device_info.name(), ";", tensor_name, ";",
                         frame_iter.frame_id, ":", frame_iter.iter_id);
}
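
// For example, with tensor_name "x:0" on the client CPU device and
// FrameAndIter(0, 0), the generated key looks roughly like (a sketch; the
// incarnation component is a runtime value):
//
//   /job:localhost/replica:0/task:0/device:CPU:0;<incarnation>;
//   /job:localhost/replica:0/task:0/device:CPU:0;x:0;0:0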

}  // namespace

class DirectSessionFactory : public SessionFactory {
 public:
  DirectSessionFactory() {}

  bool AcceptsOptions(const SessionOptions& options) override {
    return options.target.empty();
  }

  Session* NewSession(const SessionOptions& options) override {
    // Must do this before the CPU allocator is created.
    if (options.config.graph_options().build_cost_model() > 0) {
      EnableCPUAllocatorFullStats(true);
    }
    std::vector<Device*> devices;
    const Status s = DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!s.ok()) {
      LOG(ERROR) << s;
      return nullptr;
    }

    DirectSession* session =
        new DirectSession(options, new DeviceMgr(devices), this);
    {
      mutex_lock l(sessions_lock_);
      sessions_.push_back(session);
    }
    return session;
  }

  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    std::vector<DirectSession*> sessions_to_reset;
    {
      mutex_lock l(sessions_lock_);
      // We create a copy to ensure that we don't deadlock when
      // session->Close() calls DirectSessionFactory::Deregister(), which
      // acquires sessions_lock_.
      std::swap(sessions_to_reset, sessions_);
    }
    Status s;
    for (auto session : sessions_to_reset) {
      s.Update(session->Reset(containers));
    }
    // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
    // it doesn't close the sessions?
    for (auto session : sessions_to_reset) {
      s.Update(session->Close());
    }
    return s;
  }

  void Deregister(const DirectSession* session) {
    mutex_lock l(sessions_lock_);
    sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
                    sessions_.end());
  }

 private:
  mutex sessions_lock_;
  std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_);
};

class DirectSessionRegistrar {
 public:
  DirectSessionRegistrar() {
    SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
  }
};
static DirectSessionRegistrar registrar;
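
// With the registration above, any session whose target is empty resolves
// to a DirectSession. A usage sketch via the public Session API:
//
//   SessionOptions options;  // options.target is "" by default.
//   std::unique_ptr<Session> session(NewSession(options));
//   // AcceptsOptions() matches, so `session` wraps a DirectSession.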

std::atomic_int_fast64_t DirectSession::step_id_counter_(1);

// NOTE: On Android with a single device, there is never
// a risk of an OpKernel blocking indefinitely:
//
// 1) No operations do I/O that depends on other simultaneous kernels,
//
// 2) Recv nodes always complete immediately: The inputs are sent into
//    the local rendezvous before we start the executor, so the
//    corresponding recvs will not block.
//
// Based on these assumptions, we can use the same thread pool for
// both "non-blocking" and "blocking" OpKernels on Android.
//
// This may change down the road when we add support for multiple
// devices that run concurrently, in which case we will need to
// revisit this decision.
void DirectSession::SchedClosure(thread::ThreadPool* pool,
                                 std::function<void()> c) {
// TODO(sanjay): Get rid of __ANDROID__ path
#ifdef __ANDROID__
  // On Android, there is no implementation of ThreadPool that takes
  // std::function, only Closure, which we cannot easily convert.
  //
  // Instead, we just run the function in-line, which is currently
  // safe given the reasoning above.
  c();
#else
  pool->Schedule(std::move(c));
#endif  // __ANDROID__
}

DirectSession::DirectSession(const SessionOptions& options,
                             const DeviceMgr* device_mgr,
                             DirectSessionFactory* const factory)
    : options_(options),
      device_mgr_(device_mgr),
      factory_(factory),
      cancellation_manager_(new CancellationManager()),
      operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
  const int thread_pool_size =
      options_.config.session_inter_op_thread_pool_size();
  if (thread_pool_size > 0) {
    for (int i = 0; i < thread_pool_size; ++i) {
      thread::ThreadPool* pool = nullptr;
      bool owned = false;
      init_error_.Update(NewThreadPoolFromThreadPoolOptions(
          options_, options_.config.session_inter_op_thread_pool(i), i, &pool,
          &owned));
      thread_pools_.emplace_back(pool, owned);
    }
  } else if (options_.config.use_per_session_threads()) {
    thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_),
                               true /* owned */);
  } else {
    thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
  }
  // The default value of sync_on_finish will be flipped soon and this
  // environment variable will be removed as well.
  const Status status =
      ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
  // NOTE(mrry): We do not need to use a unique string for the session
  // handle, because DirectSession owns its devices. This may change
  // in future versions.
  session_handle_ = "direct";
  int devices_added = 0;
  if (options.config.log_device_placement()) {
    const string mapping_str = device_mgr_->DeviceMappingString();
    if (mapping_str.empty()) {
      printf("Device mapping: no known devices.\n");
    } else {
      printf("Device mapping:\n%s", mapping_str.c_str());
    }
    LOG(INFO) << "Device mapping:\n" << mapping_str;
  }
  for (auto d : device_mgr_->ListDevices()) {
    devices_.push_back(d);
    device_set_.AddDevice(d);
    d->op_segment()->AddHold(session_handle_);

    // The first device added is special: it is the 'client device' (a
    // CPU device) from which we feed and fetch Tensors.
    if (devices_added == 0) {
      device_set_.set_client_device(d);
    }
    ++devices_added;
  }
}
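
// A sketch of opting into the device-mapping log emitted by the
// constructor above:
//
//   SessionOptions options;
//   options.config.set_log_device_placement(true);
//   std::unique_ptr<Session> session(NewSession(options));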

DirectSession::~DirectSession() {
  if (!closed_) Close().IgnoreError();
  for (auto& it : partial_runs_) {
    it.second.reset(nullptr);
  }
  for (auto& it : executors_) {
    it.second.reset();
  }
  for (auto d : device_mgr_->ListDevices()) {
    d->op_segment()->RemoveHold(session_handle_);
  }
  for (auto d : device_mgr_->ListDevices()) {
    d->ClearResourceMgr();
  }
  functions_.clear();
  delete cancellation_manager_;
  for (const auto& p_and_owned : thread_pools_) {
    if (p_and_owned.second) delete p_and_owned.first;
  }

  execution_state_.reset(nullptr);
  flib_def_.reset(nullptr);
}

Status DirectSession::MaybeInitializeExecutionState(
    const GraphDef& graph, bool* out_already_initialized) {
  // If already initialized, do nothing.
  if (flib_def_ && execution_state_) {
    *out_already_initialized = true;
    return Status::OK();
  }
  // Set up the per-session execution state.
  // NOTE(mrry): The function library created here will be used for
  // all subsequent extensions of the graph.
  flib_def_.reset(
      new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
  GraphExecutionStateOptions options;
  options.device_set = &device_set_;
  options.session_options = &options_;
  // TODO(mrry,suharshs): We explicitly copy `graph` so that
  // `MakeForBaseGraph()` can take ownership of its
  // contents. Previously this happened implicitly in calls to the
  // `GraphExecutionState`. Other sessions call
  // `MakeForBaseGraph` in such a way that we can destructively read
  // the passed-in `GraphDef`. In principle we could do the same here,
  // with a wider refactoring; we might revise the direct session so
  // that it copies the graph fewer times.
  GraphDef temp(graph);
  TF_RETURN_IF_ERROR(
      GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
  graph_created_ = true;
  *out_already_initialized = false;
  return Status::OK();
}

Status DirectSession::Create(const GraphDef& graph) {
  TF_RETURN_IF_ERROR(init_error_);
  if (graph.node_size() > 0) {
    mutex_lock l(graph_def_lock_);
    if (graph_created_) {
      return errors::AlreadyExists(
          "A Graph has already been created for this session.");
    }
    return ExtendLocked(graph);
  }
  return Status::OK();
}

Status DirectSession::Extend(const GraphDef& graph) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(graph_def_lock_);
  return ExtendLocked(graph);
}

Status DirectSession::ExtendLocked(const GraphDef& graph) {
  bool already_initialized;
  // If this is the first call, we can initialize the execution state
  // with `graph` and do not need to call `Extend()`.
  TF_RETURN_IF_ERROR(
      MaybeInitializeExecutionState(graph, &already_initialized));
  if (already_initialized) {
    TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
    std::unique_ptr<GraphExecutionState> state;
    TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
    execution_state_.swap(state);
  }
  return Status::OK();
}
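
// Typical client-side use of the Create()/Extend() pair above (a sketch;
// the GraphDefs are placeholders):
//
//   GraphDef base_graph, more_nodes;
//   TF_CHECK_OK(session->Create(base_graph));  // First graph for the session.
//   TF_CHECK_OK(session->Extend(more_nodes));  // Merged into the same graph.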

Status DirectSession::Run(const NamedTensorList& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs) {
  RunMetadata run_metadata;
  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
             &run_metadata);
}
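
// A sketch of calling the overload above: feed "x:0", fetch "y:0", no
// target nodes (the tensor names and `x_tensor` are placeholders):
//
//   std::vector<std::pair<string, Tensor>> inputs = {{"x:0", x_tensor}};
//   std::vector<Tensor> outputs;
//   TF_CHECK_OK(session->Run(inputs, {"y:0"}, {}, &outputs));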

Status DirectSession::CreateDebuggerState(
    const DebugOptions& debug_options, int64 session_run_index,
    int64 executor_step_index, const std::vector<string>& input_names,
    const std::vector<string>& output_names,
    const std::vector<string>& target_names,
    std::unique_ptr<DebuggerStateInterface>* debugger_state) {
  TF_RETURN_IF_ERROR(
      DebuggerStateRegistry::CreateState(debug_options, debugger_state));
  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
      debug_options.global_step(), session_run_index, executor_step_index,
      input_names, output_names, target_names));
  return Status::OK();
}

Status DirectSession::DecorateAndPublishGraphForDebug(
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
  TF_RETURN_IF_ERROR(
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));

  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
  return Status::OK();
}

Status DirectSession::Run(const RunOptions& run_options,
                          const NamedTensorList& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs,
                          RunMetadata* run_metadata) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  direct_session_runs->GetCell()->IncrementBy(1);
  {
    mutex_lock l(graph_def_lock_);
    if (!graph_created_) {
      return errors::InvalidArgument(
          "Session was not created with a graph before Run()!");
    }
  }
  // Extract the input names for this run of the session.
  std::vector<string> input_tensor_names;
  input_tensor_names.reserve(inputs.size());
  for (const auto& it : inputs) {
    input_tensor_names.push_back(it.first);
  }

  if (run_options.inter_op_thread_pool() < 0 ||
      run_options.inter_op_thread_pool() >= thread_pools_.size()) {
    return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
                                   run_options.inter_op_thread_pool());
  }
  thread::ThreadPool* pool =
      thread_pools_[run_options.inter_op_thread_pool()].first;

  // Check if we already have an executor for these arguments.
  ExecutorsAndKeys* executors_and_keys;
  RunStateArgs run_state_args(run_options.debug_options());

  Executor::Args args;
  args.step_id = step_id_counter_.fetch_add(1);

  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));
  const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);

  std::unique_ptr<DebuggerStateInterface> debugger_state;
  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
    TF_RETURN_IF_ERROR(CreateDebuggerState(
        run_options.debug_options(), args.step_id, executor_step_count,
        input_tensor_names, output_names, target_nodes, &debugger_state));
  }

  // Configure a call frame for the step, which we use to feed and
  // fetch values to and from the executors.
  FunctionCallFrame call_frame(executors_and_keys->input_types,
                               executors_and_keys->output_types);
  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
  for (const auto& it : inputs) {
    if (it.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      TF_RETURN_IF_ERROR(
          ResourceHandleToInputTensor(it.second, &tensor_from_handle));
      feed_args[executors_and_keys->input_name_to_index[it.first]] =
          tensor_from_handle;
    } else {
      feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
    }
  }
  const Status s = call_frame.SetArgs(feed_args);
  if (errors::IsInternal(s)) {
    return errors::InvalidArgument(s.error_message());
  } else if (!s.ok()) {
    return s;
  }

  // Create a run state and start execution.
  RunState run_state(args.step_id, &devices_);
  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
  CancellationManager step_cancellation_manager;
  args.call_frame = &call_frame;

  // Start parallel Executors.
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state.rendez, [&run_state](const Status& ret) {
        {
          mutex_lock l(run_state.mu_);
          run_state.status.Update(ret);
        }
        run_state.executors_done.Notify();
      });

  args.rendezvous = run_state.rendez;
  args.cancellation_manager = &step_cancellation_manager;

  args.session_state = &session_state_;
  args.tensor_store = &run_state.tensor_store;
  args.step_container = &run_state.step_container;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, run_state_args.handle);
  }
  args.sync_on_finish = sync_on_finish_;

  const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);

  bool update_cost_model = false;
  if (options_.config.graph_options().build_cost_model() > 0) {
    const int64 build_cost_model_every =
        options_.config.graph_options().build_cost_model();
    const int64 build_cost_model_after =
        options_.config.graph_options().build_cost_model_after();
    int64 measure_step_count = executor_step_count - build_cost_model_after;
    if (measure_step_count >= 0) {
      update_cost_model =
          ((measure_step_count + 1) % build_cost_model_every == 0);
    }
  }
  if (do_trace || update_cost_model ||
      run_options.report_tensor_allocations_upon_oom()) {
    run_state.collector.reset(
        new StepStatsCollector(run_metadata->mutable_step_stats()));
    args.stats_collector = run_state.collector.get();
  }

  std::unique_ptr<DeviceTracer> tracer;
  if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
    tracer = CreateDeviceTracer();
    // tracer may be NULL on platforms without accelerators.
    if (tracer) {
      Status s = tracer->Start();
      if (!s.ok()) {
        run_state.executors_done.Notify();
        delete barrier;
        return s;
      }
    }
  }

  // Register this step with the session's cancellation manager, so that
  // `Session::Close()` will cancel the step.
  const CancellationToken cancellation_token =
      cancellation_manager_->get_cancellation_token();
  const bool already_cancelled = !cancellation_manager_->RegisterCallback(
      cancellation_token, [&step_cancellation_manager]() {
        step_cancellation_manager.StartCancel();
      });
  if (already_cancelled) {
    // NOTE(mrry): If we don't explicitly notify
    // `run_state.executors_done`, the RunState destructor would
    // block on this notification.
    run_state.executors_done.Notify();
    delete barrier;
    return errors::Cancelled("Run call was cancelled");
  }

  Executor::Args::Runner default_runner = [this,
                                           pool](Executor::Args::Closure c) {
    SchedClosure(pool, std::move(c));
  };
  for (const auto& item : executors_and_keys->items) {
    // TODO(zhengxq): support partial run.
    // TODO(zhengxq): if the device picks its own threadpool, we need to assign
    //     fewer threads to the main compute pool by default.
    thread::ThreadPool* device_thread_pool =
        item.device->tensorflow_device_thread_pool();
    if (!device_thread_pool) {
      args.runner = default_runner;
    } else {
      args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
        SchedClosure(device_thread_pool, std::move(c));
      };
    }
    item.executor->RunAsync(args, barrier->Get());
  }

  WaitForNotification(&run_state, &step_cancellation_manager,
                      run_options.timeout_in_ms() > 0
                          ? run_options.timeout_in_ms()
                          : operation_timeout_in_ms_);

  if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
    // The step has been cancelled: make sure we don't attempt to receive the
    // outputs as this would make it block forever.
    mutex_lock l(run_state.mu_);
    run_state.status.Update(errors::Cancelled("Run call was cancelled"));
  }

  if (tracer) {
    TF_RETURN_IF_ERROR(tracer->Stop());
    TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector));
  }

  {
    mutex_lock l(run_state.mu_);
    TF_RETURN_IF_ERROR(run_state.status);
  }

  // Receive outputs.
  if (outputs) {
    std::vector<Tensor> sorted_outputs;
    const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
    if (errors::IsInternal(s)) {
      return errors::InvalidArgument(s.error_message());
    } else if (!s.ok()) {
      return s;
    }
    const bool unique_outputs =
        output_names.size() == executors_and_keys->output_name_to_index.size();
    // first_indices[i] = j implies that j is the smallest value for which
    // output_names[i] == output_names[j].
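    // E.g. output_names = {"a:0", "b:0", "a:0"} gives first_indices =
    // {0, 1, 0}, so the repeated fetch below reuses the entry at index 0.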
    std::vector<int> first_indices;
    if (!unique_outputs) {
      first_indices.resize(output_names.size());
      for (int i = 0; i < output_names.size(); ++i) {
        for (int j = 0; j <= i; ++j) {
          if (output_names[i] == output_names[j]) {
            first_indices[i] = j;
            break;
          }
        }
      }
    }
    outputs->clear();
    outputs->reserve(sorted_outputs.size());
    for (int i = 0; i < output_names.size(); ++i) {
      const string& output_name = output_names[i];
      if (first_indices.empty() || first_indices[i] == i) {
        outputs->emplace_back(
            std::move(sorted_outputs[executors_and_keys
                                         ->output_name_to_index[output_name]]));
      } else {
        outputs->push_back((*outputs)[first_indices[i]]);
      }
    }
  }

  // Save the output tensors of this run we choose to keep.
  TF_RETURN_IF_ERROR(
      run_state.tensor_store.SaveTensors(output_names, &session_state_));
  if (args.stats_collector) {
    args.stats_collector->Finalize();
  }

  // Build and return the cost model as instructed.
  mutex_lock l(executor_lock_);
  if (update_cost_model) {
    // Build the cost model.
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const PerPartitionExecutorsAndLib& partition :
         executors_and_keys->items) {
      const Graph* graph = partition.graph;
      const string device = partition.flib->device()->name();
      device_to_graph[device] = graph;
    }
    args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);

    // Annotate stats onto the cost graph.
    CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
    for (const auto& item : executors_and_keys->items) {
      TF_RETURN_IF_ERROR(
          cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
    }
  }

  // If requested via RunOptions, output the partition graphs.
  if (run_options.output_partition_graphs()) {
    protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
        run_metadata->mutable_partition_graphs();
    for (const PerPartitionExecutorsAndLib& exec_and_lib :
         executors_and_keys->items) {
      GraphDef* partition_graph_def = partition_graph_defs->Add();
      exec_and_lib.graph->ToGraphDef(partition_graph_def);
    }
  }

  return Status::OK();
}

Status DirectSession::PRunSetup(const std::vector<string>& input_names,
                                const std::vector<string>& output_names,
                                const std::vector<string>& target_nodes,
                                string* handle) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  {
    mutex_lock l(graph_def_lock_);
    if (!graph_created_) {
      return errors::InvalidArgument(
          "Session was not created with a graph before PRunSetup()!");
    }
  }

  // RunOptions is not available in PRunSetup, so use thread pool 0.
  thread::ThreadPool* pool = thread_pools_[0].first;

  // Check if we already have an executor for these arguments.
  ExecutorsAndKeys* executors_and_keys;
  // TODO(cais): TFDBG support for partial runs.
  DebugOptions debug_options;
  RunStateArgs run_state_args(debug_options);
  run_state_args.is_partial_run = true;
  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));

  // Create the run state and save it for future PRun calls.
  Executor::Args args;
  args.step_id = step_id_counter_.fetch_add(1);
  RunState* run_state =
      new RunState(input_names, output_names, args.step_id, &devices_);
  run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
  {
    mutex_lock l(executor_lock_);
    if (!partial_runs_
             .emplace(run_state_args.handle,
                      std::unique_ptr<RunState>(run_state))
             .second) {
      return errors::Internal("The handle '", run_state_args.handle,
                              "' created for this partial run is not unique.");
    }
  }

  // Start parallel Executors.
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state->rendez, [run_state](const Status& ret) {
        if (!ret.ok()) {
          mutex_lock l(run_state->mu_);
          run_state->status.Update(ret);
        }
        run_state->executors_done.Notify();
      });

  args.rendezvous = run_state->rendez;
  args.cancellation_manager = cancellation_manager_;
  args.runner = [this, pool](Executor::Args::Closure c) {
    SchedClosure(pool, std::move(c));
  };
  args.session_state = &session_state_;
  args.tensor_store = &run_state->tensor_store;
  args.step_container = &run_state->step_container;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, run_state_args.handle);
  }
  args.sync_on_finish = sync_on_finish_;

  if (options_.config.graph_options().build_cost_model()) {
    run_state->collector.reset(new StepStatsCollector(nullptr));
    args.stats_collector = run_state->collector.get();
  }

  for (auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }

  *handle = run_state_args.handle;
  return Status::OK();
}

Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
                           const std::vector<string>& output_names,
                           std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  std::vector<string> parts = str_util::Split(handle, ';');
  const string& key = parts[0];
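  // NOTE: A partial-run handle has the form "<executors-key>;<counter>"
  // (assigned in GetOrCreateExecutors()), e.g. roughly "x:0->y:0//1/;7",
  // so parts[0] above recovers the executors cache key.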
  // Get the executors for this partial run.
  ExecutorsAndKeys* executors_and_keys;
  RunState* run_state;
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto exc_it = executors_.find(key);
    if (exc_it == executors_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    }
    executors_and_keys = exc_it->second.get();

    auto prun_it = partial_runs_.find(handle);
    if (prun_it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    }
    run_state = prun_it->second.get();

    // Make sure that this is a new set of feeds that are still pending.
    for (const auto& input : inputs) {
      auto it = run_state->pending_inputs.find(input.first);
      if (it == run_state->pending_inputs.end()) {
        return errors::InvalidArgument(
            "The feed ", input.first,
            " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The feed ", input.first,
                                       " has already been fed.");
      }
    }
    // Check that this is a new set of fetches that are still pending.
    for (const auto& output : output_names) {
      auto it = run_state->pending_outputs.find(output);
      if (it == run_state->pending_outputs.end()) {
        return errors::InvalidArgument(
            "The fetch ", output, " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The fetch ", output,
                                       " has already been fetched.");
      }
    }
  }

  // Check that this new set of fetches can be computed from all the
  // feeds we have supplied.
  TF_RETURN_IF_ERROR(
      CheckFetch(inputs, output_names, executors_and_keys, run_state));

  // Send inputs.
  Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);

  // Receive outputs.
  if (s.ok()) {
    s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
  }

  // Save the output tensors of this run we choose to keep.
  if (s.ok()) {
    s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
  }

  {
    mutex_lock l(executor_lock_);
    // Delete the run state if there is an error or all fetches are done.
    bool done = true;
    if (s.ok()) {
      {
        mutex_lock l(run_state->mu_);
        if (!run_state->status.ok()) {
          LOG(WARNING) << "An error unrelated to this prun has been detected. "
                       << run_state->status;
        }
      }
      for (const auto& input : inputs) {
        auto it = run_state->pending_inputs.find(input.first);
        it->second = true;
      }
      for (const auto& name : output_names) {
        auto it = run_state->pending_outputs.find(name);
        it->second = true;
      }
      done = run_state->PendingDone();
    }
    if (done) {
      WaitForNotification(run_state, cancellation_manager_,
                          operation_timeout_in_ms_);
      partial_runs_.erase(handle);
    }
  }

  return s;
}
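
// A sketch of the partial-run flow implemented by PRunSetup()/PRun() above
// (tensor names and `x_tensor` are placeholders):
//
//   string handle;
//   TF_CHECK_OK(session->PRunSetup({"x:0"}, {"y:0"}, {}, &handle));
//   std::vector<Tensor> outputs;
//   TF_CHECK_OK(session->PRun(handle, {{"x:0", x_tensor}}, {"y:0"}, &outputs));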

Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
                                                  Tensor* retrieved_tensor) {
  if (resource_tensor.dtype() != DT_RESOURCE) {
    return errors::InvalidArgument(strings::StrCat(
        "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
        resource_tensor.dtype()));
  }

  const ResourceHandle& resource_handle =
      resource_tensor.scalar<ResourceHandle>()();

  if (resource_handle.container() ==
      SessionState::kTensorHandleResourceTypeName) {
    return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
  } else {
    return errors::InvalidArgument(strings::StrCat(
        "Invalid resource type hash code: ", resource_handle.hash_code(),
        "(name: ", resource_handle.name(),
        " type: ", resource_handle.maybe_type_name(),
        "). Perhaps a resource tensor was being provided as a feed? That is "
        "not currently allowed. Please file an issue at "
        "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
        "short code snippet that leads to this error message."));
  }
}

Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
                                     const ExecutorsAndKeys* executors_and_keys,
                                     IntraProcessRendezvous* rendez) {
  Status s;
  Rendezvous::ParsedKey parsed;
  // Insert the input tensors into the local rendezvous by their
  // rendezvous key.
  for (const auto& input : inputs) {
    auto it =
        executors_and_keys->input_name_to_rendezvous_key.find(input.first);
    if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
      return errors::Internal("'", input.first, "' is not a pre-defined feed.");
    }
    const string& input_key = it->second;

    s = Rendezvous::ParseKey(input_key, &parsed);
    if (!s.ok()) {
      rendez->StartAbort(s);
      return s;
    }

    if (input.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
      if (s.ok()) {
        s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
      }
    } else {
      s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
    }

    if (!s.ok()) {
      rendez->StartAbort(s);
      return s;
    }
  }
  return Status::OK();
}

Status DirectSession::RecvPRunOutputs(
    const std::vector<string>& output_names,
    const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
    std::vector<Tensor>* outputs) {
  Status s;
  if (!output_names.empty()) {
    outputs->resize(output_names.size());
  }

  Rendezvous::ParsedKey parsed;
  // Get the outputs from the rendezvous.
  for (size_t output_offset = 0; output_offset < output_names.size();
       ++output_offset) {
    const string& output_name = output_names[output_offset];
    auto it =
        executors_and_keys->output_name_to_rendezvous_key.find(output_name);
    if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
      return errors::Internal("'", output_name,
                              "' is not a pre-defined fetch.");
    }
    const string& output_key = it->second;
    Tensor output_tensor;
    bool is_dead;
    IntraProcessRendezvous* rendez = run_state->rendez;

    s = Rendezvous::ParseKey(output_key, &parsed);
    if (s.ok()) {
      // Fetch data from the Rendezvous.
      s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
                       operation_timeout_in_ms_);
      if (is_dead && s.ok()) {
        s = errors::InvalidArgument("The tensor returned for ", output_name,
                                    " was not valid.");
      }
    }
    if (!s.ok()) {
      rendez->StartAbort(s);
      outputs->clear();
      return s;
    }

    (*outputs)[output_offset] = output_tensor;
  }
  return Status::OK();
}

Status DirectSession::CheckFetch(const NamedTensorList& feeds,
                                 const std::vector<string>& fetches,
                                 const ExecutorsAndKeys* executors_and_keys,
                                 const RunState* run_state) {
  const Graph* graph = executors_and_keys->graph.get();
  const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;

  // Build the set of pending feeds that we haven't seen.
  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
  {
    mutex_lock l(executor_lock_);
    for (const auto& input : run_state->pending_inputs) {
      // Skip if the feed has already been fed.
      if (input.second) continue;
      TensorId id(ParseTensorName(input.first));
      auto it = name_to_node->find(id.first);
      if (it == name_to_node->end()) {
        return errors::NotFound("Feed ", input.first, ": not found");
      }
      pending_feeds.insert(id);
    }
  }
  for (const auto& it : feeds) {
    TensorId id(ParseTensorName(it.first));
    pending_feeds.erase(id);
  }

  // Initialize the stack with the fetch nodes.
  std::vector<const Node*> stack;
  for (const string& fetch : fetches) {
    TensorId id(ParseTensorName(fetch));
    auto it = name_to_node->find(id.first);
    if (it == name_to_node->end()) {
      return errors::NotFound("Fetch ", fetch, ": not found");
    }
    stack.push_back(it->second);
  }

  // Any tensor needed for fetches can't be in pending_feeds.
  std::vector<bool> visited(graph->num_node_ids(), false);
  while (!stack.empty()) {
    const Node* n = stack.back();
    stack.pop_back();

    for (const Edge* in_edge : n->in_edges()) {
      const Node* in_node = in_edge->src();
      if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
        return errors::InvalidArgument("Fetch ", in_node->name(), ":",
                                       in_edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
      }
      if (!visited[in_node->id()]) {
        visited[in_node->id()] = true;
        stack.push_back(in_node);
      }
    }
  }
  return Status::OK();
}

Status DirectSession::GetOrCreateExecutors(
    gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
    RunStateArgs* run_state_args) {
  int64 handle_name_counter_value = -1;
  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
    handle_name_counter_value = handle_name_counter_.fetch_add(1);
  }

  string debug_tensor_watches_summary;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
    debug_tensor_watches_summary = SummarizeDebugTensorWatches(
        run_state_args->debug_options.debug_tensor_watch_opts());
  }

  // Fast lookup path, no sorting.
  const string key = strings::StrCat(
      str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
      str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
      "/", debug_tensor_watches_summary);
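  // E.g. inputs {"x:0"}, outputs {"y:0"}, no targets and no debug watches
  // in a non-partial run yield a key like "x:0->y:0//0/".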
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto it = executors_.find(key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      return Status::OK();
    }
  }

  // Slow lookup path: the unsorted key missed the cache.
  // Sort the inputs and outputs, and look up with the sorted key in case an
  // earlier call used a different order of inputs and outputs.
  //
  // We could consider some other signature instead of sorting that
  // preserves the same property to avoid the sort in the future.
  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
  std::sort(inputs_sorted.begin(), inputs_sorted.end());
  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
  std::sort(outputs_sorted.begin(), outputs_sorted.end());
  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
  std::sort(tn_sorted.begin(), tn_sorted.end());

  const string sorted_key = strings::StrCat(
      str_util::Join(inputs_sorted, ","), "->",
      str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
      "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(sorted_key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);
    auto it = executors_.find(sorted_key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      // Insert this under the original key.
      executors_.emplace(key, it->second);
      return Status::OK();
    }
  }

  // Nothing found, so create the executors and store in the cache.
  BuildGraphOptions options;
  options.feed_endpoints = inputs_sorted;
  options.fetch_endpoints = outputs_sorted;
  options.target_nodes = tn_sorted;
  options.use_function_convention = !run_state_args->is_partial_run;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
    options.debug_options = run_state_args->debug_options;
  }

  std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
  std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);

  // The executor_lock_ is intentionally released while the executors are
  // being created.
   1150   std::unordered_map<string, std::unique_ptr<Graph>> graphs;
   1151   TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
   1152                                   run_state_args, &ek->input_types,
   1153                                   &ek->output_types));
   1154 
   1155   if (run_state_args->is_partial_run) {
   1156     ek->graph = std::move(run_state_args->graph);
   1157     std::unordered_set<StringPiece, StringPieceHasher> names;
   1158     for (const string& input : inputs) {
   1159       TensorId id(ParseTensorName(input));
   1160       names.emplace(id.first);
   1161     }
   1162     for (const string& output : outputs) {
   1163       TensorId id(ParseTensorName(output));
   1164       names.emplace(id.first);
   1165     }
   1166     for (Node* n : ek->graph->nodes()) {
   1167       if (names.count(n->name()) > 0) {
   1168         ek->name_to_node.insert({n->name(), n});
   1169       }
   1170     }
   1171   }
   1172   ek->items.reserve(graphs.size());
   1173   const auto& optimizer_opts =
   1174       options_.config.graph_options().optimizer_options();
   1175 
   1176   int graph_def_version;
   1177   {
   1178     mutex_lock l(graph_def_lock_);
   1179     graph_def_version =
   1180         execution_state_->original_graph_def().versions().producer();
   1181   }
   1182   func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
   1183       device_mgr_.get(), options_.env, graph_def_version,
   1184       func_info->flib_def.get(), optimizer_opts));
   1185 
   1186   GraphOptimizer optimizer(optimizer_opts);
   1187   for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
   1188     const string& partition_name = iter->first;
   1189     std::unique_ptr<Graph>& partition_graph = iter->second;
   1190 
   1191     Device* device;
   1192     TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
   1193 
   1194     ek->items.resize(ek->items.size() + 1);
   1195     auto* item = &(ek->items.back());
   1196     auto lib = func_info->proc_flr->GetFLR(partition_name);
   1197     if (lib == nullptr) {
   1198       return errors::Internal("Could not find device: ", partition_name);
   1199     }
   1200     item->flib = lib;
   1201 
   1202     LocalExecutorParams params;
   1203     params.device = device;
   1204     params.function_library = lib;
   1205     auto opseg = device->op_segment();
   1206     params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
   1207                                               OpKernel** kernel) {
   1208       // We do not share the kernel via the OpSegment if the node is
   1209       // stateless, or a function.
   1210       // NOTE(mrry): We must not share function kernels (implemented
   1211       // using `CallOp`) between subgraphs, because `CallOp::handle_`
   1212       // is tied to a particular subgraph. Even if the function itself
   1213       // is stateful, the `CallOp` that invokes it is not.
   1214       if (!lib->IsStateful(ndef.op()) ||
   1215           lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
   1216         return lib->CreateKernel(ndef, kernel);
   1217       }
   1218       auto create_fn = [lib, &ndef](OpKernel** kernel) {
   1219         return lib->CreateKernel(ndef, kernel);
   1220       };
   1221       // Kernels created for subgraph nodes need to be cached.  On
   1222       // cache miss, create_fn() is invoked to create a kernel based
   1223       // on the function library here + global op registry.
   1224       return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
   1225                                  create_fn);
   1226     };
   1227     params.delete_kernel = [lib](OpKernel* kernel) {
   1228       // If the node is stateful, opseg owns it. Otherwise, delete it.
   1229       if (kernel && !lib->IsStateful(kernel->type_string())) {
   1230         delete kernel;
   1231       }
   1232     };
    params.node_outputs_cb = node_outputs_callback_;

    optimizer.Optimize(lib, options_.env, device, &partition_graph,
                       /*shape_map=*/nullptr);

    // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
    if (!options.debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          options.debug_options, partition_graph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                         device->name(),
                                         partition_graph.get()));
    // NewLocalExecutor takes ownership of partition_graph.
    item->graph = partition_graph.get();
    item->executor = nullptr;
    item->device = device;
    Executor* executor;
    TF_RETURN_IF_ERROR(
        NewLocalExecutor(params, std::move(partition_graph), &executor));
    item->executor.reset(executor);
  }
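  // At this point `ek->items` holds one fully-built executor (plus its
  // graph, device, and function library) per device partition.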

  // Cache the mapping from input/output names to graph elements to
  // avoid recomputing it every time.
  if (!run_state_args->is_partial_run) {
    // For regular `Run()`, we use the function calling convention, and so
    // maintain a mapping from input/output names to
    // argument/return-value ordinal index.
    for (size_t i = 0; i < inputs_sorted.size(); ++i) {
      const string& input = inputs_sorted[i];
      ek->input_name_to_index[input] = i;
    }
    for (size_t i = 0; i < outputs_sorted.size(); ++i) {
      const string& output = outputs_sorted[i];
      ek->output_name_to_index[output] = i;
    }
  } else {
    // For `PRun()`, we use the rendezvous calling convention, and so
    // maintain a mapping from input/output names to rendezvous keys.
    //
    // We always use the first device as the device name portion of the
    // key, even if we're feeding another graph.
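    // A key produced by GetRendezvousKey() is roughly of the form
    //   "<client_device>;<incarnation>;<client_device>;<tensor_name>;<frame>:<iter>",
    // so feeds and fetches on distinct tensors map to distinct keys.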
    for (size_t i = 0; i < inputs_sorted.size(); ++i) {
      const string& input = inputs_sorted[i];
      ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
          input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
    }
    for (size_t i = 0; i < outputs_sorted.size(); ++i) {
      const string& output = outputs_sorted[i];
      ek->output_name_to_rendezvous_key[output] =
          GetRendezvousKey(output, device_set_.client_device()->attributes(),
                           FrameAndIter(0, 0));
    }
  }

  // Reacquire the lock and try to insert the new entry into the map.
  mutex_lock l(executor_lock_);
  functions_.push_back(std::move(func_info));

  // Another thread may have created the entry before us, in which case we
  // reuse the already-created one; emplace() leaves the map unchanged then.
  auto insert_result = executors_.emplace(sorted_key, ek);
  // Also insert the value under the original key, so the fast-path lookup
  // succeeds if the user passes the same order of inputs, outputs, and
  // targets again.
  executors_.emplace(key, insert_result.first->second);
  *executors_and_keys = insert_result.first->second.get();

  return Status::OK();
}

Status DirectSession::CreateGraphs(
    const BuildGraphOptions& subgraph_options,
    std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    RunStateArgs* run_state_args, DataTypeVector* input_types,
    DataTypeVector* output_types) {
  mutex_lock l(graph_def_lock_);
  std::unique_ptr<ClientGraph> client_graph;

  std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
  GraphExecutionState* execution_state = nullptr;
  if (options_.config.graph_options().place_pruned_graph()) {
    // Because we are placing pruned graphs, we need to create a
    // new GraphExecutionState for every new unseen graph,
    // and then place it.
    GraphExecutionStateOptions prune_options;
    prune_options.device_set = &device_set_;
    prune_options.session_options = &options_;
    prune_options.stateful_placements = stateful_placements_;
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
        execution_state_->original_graph_def().library(), prune_options,
        execution_state_->original_graph_def(), subgraph_options,
        &temp_exec_state_holder, &client_graph));
    execution_state = temp_exec_state_holder.get();
  } else {
    execution_state = execution_state_.get();
    TF_RETURN_IF_ERROR(
        execution_state->BuildGraph(subgraph_options, &client_graph));
  }

  if (subgraph_options.feed_endpoints.size() !=
      client_graph->feed_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of feed endpoints = ",
        subgraph_options.feed_endpoints.size(),
        " versus number of pruned feed endpoints = ",
        client_graph->feed_types.size());
  }
  if (subgraph_options.fetch_endpoints.size() !=
      client_graph->fetch_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of fetch endpoints = ",
        subgraph_options.fetch_endpoints.size(),
        " versus number of pruned fetch endpoints = ",
        client_graph->fetch_types.size());
  }

  auto current_stateful_placements = execution_state->GetStatefulPlacements();
  // Update our current state based on the execution_state's placements.
  // If a node's recorded placement disagrees with its new placement, fail:
  // this should never happen.
  for (auto placement_pair : current_stateful_placements) {
    const string& node_name = placement_pair.first;
    const string& placement = placement_pair.second;
    auto iter = stateful_placements_.find(node_name);
    if (iter == stateful_placements_.end()) {
      stateful_placements_.insert(std::make_pair(node_name, placement));
    } else if (iter->second != placement) {
      return errors::Internal(
          "Stateful placement mismatch. "
          "Current assignment of ",
          node_name, " to ", iter->second, " does not match ", placement);
    }
  }
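  // Stateful nodes (e.g. variables and queues) keep their state on the
  // device where they were first placed, so their placement must remain
  // stable across Run() calls.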

  stateful_placements_ = execution_state->GetStatefulPlacements();

  // Remember the graph in run state if this is a partial run.
  if (run_state_args->is_partial_run) {
    run_state_args->graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
  }

  // Partition the graph across devices.
  PartitionOptions popts;
  popts.node_to_loc = [](const Node* node) {
    return node->assigned_device_name();
  };
  popts.new_name = [this](const string& prefix) {
    return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
  };
  popts.get_incarnation = [](const string& name) {
    // The direct session does not have changing incarnation numbers.
    // Just return '1'.
    return 1;
  };
  popts.flib_def = &client_graph->graph.flib_def();
  popts.control_flow_added = false;

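  // Partition() splits the client graph into one GraphDef per device,
  // inserting paired _Send/_Recv ops across each edge that crosses a
  // device boundary.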
  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));

  std::vector<string> device_names;
  for (auto device : devices_) {
    // Extract the LocalName from the device.
    device_names.push_back(DeviceNameUtils::LocalName(device->name()));
  }

  // Check for valid partitions.
  for (const auto& partition : partitions) {
    const string local_partition_name =
        DeviceNameUtils::LocalName(partition.first);
    if (std::count(device_names.begin(), device_names.end(),
                   local_partition_name) == 0) {
      return errors::InvalidArgument(
          "Creating a partition for ", local_partition_name,
          " which doesn't exist in the list of available devices. Available "
          "devices: ",
          str_util::Join(device_names, ","));
    }
  }

  for (const auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(
        new Graph(client_graph->flib_def.get()));
    GraphConstructorOptions device_opts;
    // Partitioned graphs contain internal operations (e.g., _Send/_Recv)
    // that ordinary graph construction would reject, so allow them here.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
                                               device_graph.get()));
    outputs->emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = &options_;
  optimization_options.flib_def = client_graph->flib_def.get();
  optimization_options.partition_graphs = outputs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
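  // Passes registered for the POST_PARTITIONING phase may rewrite the
  // per-device graphs in place via optimization_options.partition_graphs.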

  Status s;
  for (auto& partition : *outputs) {
    const string& partition_name = partition.first;
    std::unique_ptr<Graph>* graph = &partition.second;

    VLOG(2) << "Created " << DebugString(graph->get()) << " for "
            << partition_name;

    // Give the device an opportunity to rewrite its subgraph.
    Device* d;
    s = device_mgr_->LookupDevice(partition_name, &d);
    if (!s.ok()) break;
    s = d->MaybeRewriteGraph(graph);
    if (!s.ok()) {
      break;
    }
  }
  *flib_def = std::move(client_graph->flib_def);
  std::swap(*input_types, client_graph->feed_types);
  std::swap(*output_types, client_graph->fetch_types);
  return s;
}

::tensorflow::Status DirectSession::ListDevices(
    std::vector<DeviceAttributes>* response) {
  response->clear();
  response->reserve(devices_.size());
  for (Device* d : devices_) {
    const DeviceAttributes& attrs = d->attributes();
    response->emplace_back(attrs);
  }
  return ::tensorflow::Status::OK();
}

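// Reset() releases the resources held by the named containers on every
// device owned by this session; an empty list clears the default container.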
::tensorflow::Status DirectSession::Reset(
    const std::vector<string>& containers) {
  device_mgr_->ClearContainers(containers);
  return ::tensorflow::Status::OK();
}

::tensorflow::Status DirectSession::Close() {
  cancellation_manager_->StartCancel();
  {
    mutex_lock l(closed_lock_);
    if (closed_) return ::tensorflow::Status::OK();
    closed_ = true;
  }
  if (factory_ != nullptr) factory_->Deregister(this);
  return ::tensorflow::Status::OK();
}

DirectSession::RunState::RunState(
    const std::vector<string>& pending_input_names,
    const std::vector<string>& pending_output_names, int64 step_id,
    const std::vector<Device*>* devices)
    : step_container(step_id, [devices](const string& name) {
        for (auto d : *devices) {
          // Ignore the cleanup status: the step's resources may already
          // have been freed.
          d->resource_manager()->Cleanup(name).IgnoreError();
        }
      }) {
  // Initially all the feeds and fetches are pending.
  for (auto& name : pending_input_names) {
    pending_inputs[name] = false;
  }
  for (auto& name : pending_output_names) {
    pending_outputs[name] = false;
  }
}

DirectSession::RunState::RunState(int64 step_id,
                                  const std::vector<Device*>* devices)
    : RunState({}, {}, step_id, devices) {}

DirectSession::RunState::~RunState() {
  if (rendez != nullptr) {
    if (!executors_done.HasBeenNotified()) {
      rendez->StartAbort(errors::Cancelled("PRun cancellation"));
      executors_done.WaitForNotification();
    }
    rendez->Unref();
  }
}

bool DirectSession::RunState::PendingDone() const {
  for (const auto& it : pending_inputs) {
    if (!it.second) return false;
  }
  for (const auto& it : pending_outputs) {
    if (!it.second) return false;
  }
  return true;
}

void DirectSession::WaitForNotification(RunState* run_state,
                                        CancellationManager* cm,
                                        int64 timeout_in_ms) {
  const Status status =
      WaitForNotification(&run_state->executors_done, timeout_in_ms);
  if (!status.ok()) {
    {
      mutex_lock l(run_state->mu_);
      run_state->status.Update(status);
    }
    cm->StartCancel();
    // We must wait for the executors to complete, because they have borrowed
    // references to `cm` and other per-step state. After this notification, it
    // is safe to clean up the step.
    run_state->executors_done.WaitForNotification();
  }
}

::tensorflow::Status DirectSession::WaitForNotification(
    Notification* notification, int64 timeout_in_ms) {
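  // A timeout of zero or less means block until the notification fires.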
  if (timeout_in_ms > 0) {
    const int64 timeout_in_us = timeout_in_ms * 1000;
    const bool notified =
        WaitForNotificationWithTimeout(notification, timeout_in_us);
    if (!notified) {
      return Status(error::DEADLINE_EXCEEDED,
                    "Timed out waiting for notification");
    }
  } else {
    notification->WaitForNotification();
  }
  return Status::OK();
}

}  // namespace tensorflow