/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/master_session.h"

#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/profile_handler.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/distributed_runtime/scheduler.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// MasterSession wraps ClientGraph in a reference counted object.
// This way, MasterSession can clear up the cache mapping Run requests to
// compiled graphs while the compiled graph is still being used.
//
// TODO(zhifengc): Cleanup this class. It's becoming messy.
class MasterSession::ReffedClientGraph : public core::RefCounted {
 public:
  ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
                    std::unique_ptr<ClientGraph> cg,
                    const SessionOptions& session_opts,
                    const StatsPublisherFactory& stats_publisher_factory,
                    GraphExecutionState* execution_state, bool is_partial,
                    WorkerCacheInterface* worker_cache, bool should_deregister)
      : session_handle_(handle),
        client_graph_(std::move(cg)),
        session_opts_(session_opts),
        is_partial_(is_partial),
        debug_opts_(bopts.debug_options),
        worker_cache_(worker_cache),
        should_deregister_(should_deregister) {
    VLOG(1) << "Created ReffedClientGraph with "
            << client_graph()->graph.num_node_ids() << " node ids";

    stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);

    // Initialize a name to node map for testing that fetches are reachable.
    for (Node* n : execution_state->full_graph()->nodes()) {
      name_to_node_.insert({n->name(), n});
    }
  }

  ~ReffedClientGraph() override {
    if (should_deregister_) {
      DeregisterPartitions();
    }
  }

  const ClientGraph* client_graph() { return client_graph_.get(); }

  std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
                                                    int64 execution_count,
                                                    const RunOptions& ropts) {
    return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
  }

  // Turn RPC logging on or off, both at the WorkerCache used by this
  // master process, and at each remote worker in use for the current
  // partitions.
  void SetRPCLogging(bool active) {
    worker_cache_->SetLogging(active);
    // Logging is a best-effort activity, so we make async calls to turn
    // it on/off and don't make use of the responses.
    for (auto& p : partitions_) {
      LoggingRequest* req = new LoggingRequest;
      req->set_rpc_logging(active);
      LoggingResponse* resp = new LoggingResponse;
      Ref();
      p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
        delete req;
        delete resp;
        // ReffedClientGraph owns p.worker, so we need to hold a ref to
        // ensure that the method doesn't attempt to access p.worker after
        // ReffedClientGraph has deleted it.
        // TODO(suharshs): Simplify this ownership model.
        Unref();
      });
    }
  }

  // Retrieve all RPC logs data accumulated for the current step, both
  // from the local WorkerCache in use by this master process and from
  // all the remote workers executing the remote partitions.
  void RetrieveLogs(int64 step_id, StepStats* ss) {
    // Get the local data first, because it sets *ss without merging.
    worker_cache_->RetrieveLogs(step_id, ss);

    // Then merge in data from all the remote workers.
    LoggingRequest req;
    req.add_fetch_step_id(step_id);
    int waiting_for = partitions_.size();
    if (waiting_for > 0) {
      mutex scoped_mu;
      BlockingCounter all_done(waiting_for);
      for (auto& p : partitions_) {
        LoggingResponse* resp = new LoggingResponse;
        p.worker->LoggingAsync(
            &req, resp,
            [step_id, ss, resp, &scoped_mu, &waiting_for,
             &all_done](const Status& s) {
              {
                mutex_lock l(scoped_mu);
                if (s.ok()) {
                  for (auto& lss : resp->step()) {
                    if (step_id != lss.step_id()) {
                      LOG(ERROR) << "Wrong step_id in LoggingResponse";
                      continue;
                    }
                    ss->MergeFrom(lss.step_stats());
                  }
                }
                delete resp;
              }
              // Must not decrement all_done until out of critical section where
              // *ss is updated.
              all_done.DecrementCount();
            });
      }
      all_done.Wait();
    }
  }

  // Local execution methods.

  // Partitions the graph into subgraphs and registers them on
  // workers.
  Status RegisterPartitions(const PartitionOptions& popts);

  // Runs one step of all partitions.
  Status RunPartitions(const MasterEnv* env, int64 step_id,
                       int64 execution_count, PerStepState* pss,
                       CallOptions* opts, const RunStepRequestWrapper& req,
                       MutableRunStepResponseWrapper* resp,
                       CancellationManager* cm, const bool is_last_partial_run);

  // Calls workers to cleanup states for the step "step_id".  Calls
  // `done` when all cleanup RPCs have completed.
  void CleanupPartitionsAsync(int64 step_id, StatusCallback done);

  // Post-processing of any runtime statistics gathered during execution.
  void ProcessStats(int64 step_id, PerStepState* pss, ProfileHandler* ph,
                    const RunOptions& options, RunMetadata* resp);
  void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
                          bool is_rpc);
  // Checks that the requested fetches can be computed from the provided feeds.
  Status CheckFetches(const RunStepRequestWrapper& req,
                      const RunState* run_state,
                      GraphExecutionState* execution_state);

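  // Builds a one-line, human-readable summary of a node execution for
  // profiling, prefixed with its total output size when that is large, e.g.
  // (hypothetical output) "[1.5MB] MatMul_3 = MatMul(a, b)".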
  string DetailText(const Node& node, const NodeExecStats& ns) {
    int64 tot = 0;
    for (auto& no : ns.output()) {
      tot += no.tensor_description().allocation_description().requested_bytes();
    }
    string bytes;
    if (tot >= 0.1 * 1048576.0) {
      bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
    }
    return strings::StrCat(bytes, node.name(), " = ", node.type_string(), "(",
                           str_util::Join(node.requested_inputs(), ", "), ")");
  }

 private:
  const string session_handle_;
  const std::unique_ptr<ClientGraph> client_graph_;
  const SessionOptions session_opts_;
  const bool is_partial_;
  const DebugOptions& debug_opts_;
  WorkerCacheInterface* const worker_cache_;  // Not owned.
  std::unordered_map<StringPiece, Node*, StringPieceHasher> name_to_node_;
  const bool should_deregister_;

  // Graph partitioned into per-location subgraphs.
  struct Part {
    // Worker name.
    string name;

    // Maps feed names to rendezvous keys. Empty most of the time.
    std::unordered_map<string, string> feed_key;

    // Maps rendezvous keys to fetch names. Empty most of the time.
    std::unordered_map<string, string> key_fetch;

    // The interface to the worker. Owned.
    WorkerInterface* worker = nullptr;

    // After registration with the worker, graph_handle identifies
    // this partition on the worker.
    string graph_handle;

    Part() : feed_key(3), key_fetch(3) {}
  };

  // partitions_ is immutable after RegisterPartitions() call
  // finishes.  RunPartitions() can access partitions_ safely without
  // acquiring locks.
  std::vector<Part> partitions_;

  mutable mutex mu_;

  // Partition initialization and registration only needs to happen
  // once. init_started_ && !init_done_ indicates that initialization
  // is ongoing.
  bool init_started_ GUARDED_BY(mu_) = false;
  Notification init_done_;

  // init_result_ remembers the initialization error if any.
  Status init_result_ GUARDED_BY(mu_);

  std::unique_ptr<StatsPublisherInterface> stats_publisher_;

  // Send/Recv nodes that are the result of client-added
  // feeds and fetches must be tracked so that the tensors
  // can be added to the local rendezvous.
  static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
                                   const PartitionOptions& popts);

  // The actual graph partitioning and registration implementation.
  Status DoBuildPartitions(
      PartitionOptions popts,
      std::unordered_map<string, GraphDef>* out_partitions);
  Status DoRegisterPartitions(
      const PartitionOptions& popts,
      std::unordered_map<string, GraphDef> graph_partitions);

  // Deregisters the partitions on the workers.  Called in the
  // destructor and does not wait for the rpc completion.
  void DeregisterPartitions();

  TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
};

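// Registration runs at most once per ReffedClientGraph: the first caller
// performs the partitioning and worker registration without holding mu_ (so
// the registration RPCs don't block other threads), and every later caller
// blocks on init_done_ and then returns the memoized init_result_.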
Status MasterSession::ReffedClientGraph::RegisterPartitions(
    const PartitionOptions& popts) {
  {  // Ensure register once.
    mu_.lock();
    if (!init_started_) {
      init_started_ = true;
      mu_.unlock();
      std::unordered_map<string, GraphDef> graph_defs;
      Status s = DoBuildPartitions(popts, &graph_defs);
      if (s.ok()) {
        // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
        // valid after the call to DoRegisterPartitions begins, so
        // `stats_publisher_` must make a copy if it wants to retain the
        // GraphDef objects.
        std::vector<const GraphDef*> graph_defs_for_publishing;
        graph_defs_for_publishing.reserve(partitions_.size());
        for (const auto& name_def : graph_defs) {
          graph_defs_for_publishing.push_back(&name_def.second);
        }
        stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
        s = DoRegisterPartitions(popts, std::move(graph_defs));
      }
      mu_.lock();
      init_result_ = s;
      init_done_.Notify();
    } else {
      mu_.unlock();
      init_done_.WaitForNotification();
      mu_.lock();
    }
    const Status result = init_result_;
    mu_.unlock();
    return result;
  }
}

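// Maps a node to the worker task that hosts its assigned device, e.g.
// "/job:worker/replica:0/task:3". Used as the node_to_loc function for graph
// partitioning, so nodes are grouped into one partition per task (an
// assumption based on how PartitionOptions is populated elsewhere in this
// file).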
static string SplitByWorker(const Node* node) {
  string task;
  string device;
  CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
                                         &device))
      << "node: " << node->name() << " dev: " << node->assigned_device_name();
  return task;
}

void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
    Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
  for (int i = 0; i < graph_def.node_size(); ++i) {
    const NodeDef& ndef = graph_def.node(i);
    const bool is_recv = ndef.op() == "_Recv";
    const bool is_send = ndef.op() == "_Send";

    if (is_recv || is_send) {
      // Only send/recv nodes that were added as feeds and fetches
      // (client-terminated) should be tracked.  Other send/recv nodes
      // are for transferring data between partitions / memory spaces.
      bool client_terminated;
      TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
      if (client_terminated) {
        string name;
        TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
        string send_device;
        TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
        string recv_device;
        TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
        uint64 send_device_incarnation;
        TF_CHECK_OK(
            GetNodeAttr(ndef, "send_device_incarnation",
                        reinterpret_cast<int64*>(&send_device_incarnation)));
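        // The key canonically identifies this tensor transfer; it has the
        // form "<send_device>;<incarnation>;<recv_device>;<tensor_name>;
        // <frame_id>:<iter_id>" (the exact layout is an implementation
        // detail of Rendezvous::CreateKey).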
        const string& key =
            Rendezvous::CreateKey(send_device, send_device_incarnation,
                                  recv_device, name, FrameAndIter(0, 0));

        if (is_recv) {
          part->feed_key.insert({name, key});
        } else {
          part->key_fetch.insert({key, name});
        }
      }
    }
  }
}

Status MasterSession::ReffedClientGraph::DoBuildPartitions(
    PartitionOptions popts,
    std::unordered_map<string, GraphDef>* out_partitions) {
  if (popts.need_to_record_start_times) {
    CostModel cost_model(true);
    cost_model.InitFromGraph(client_graph()->graph);
    // TODO(yuanbyu): Use the real cost model.
    // execution_state_->MergeFromGlobal(&cost_model);
    SlackAnalysis sa(&client_graph()->graph, &cost_model);
    sa.ComputeAsap(&popts.start_times);
  }

  // Partition the graph.
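  // Partition() splits the client graph into one GraphDef per location
  // returned by popts.node_to_loc (here, one per worker task) and inserts
  // _Send/_Recv node pairs for every edge that crosses a partition boundary.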
  return Partition(popts, &client_graph_->graph, out_partitions);
}

Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
    const PartitionOptions& popts,
    std::unordered_map<string, GraphDef> graph_partitions) {
  partitions_.reserve(graph_partitions.size());
  Status s;
  for (auto& name_def : graph_partitions) {
    partitions_.resize(partitions_.size() + 1);
    Part* part = &partitions_.back();
    part->name = name_def.first;
    TrackFeedsAndFetches(part, name_def.second, popts);
    part->worker = worker_cache_->CreateWorker(part->name);
    if (part->worker == nullptr) {
      s = errors::NotFound("worker ", part->name);
      break;
    }
  }
  if (!s.ok()) {
    for (Part& part : partitions_) {
      worker_cache_->ReleaseWorker(part.name, part.worker);
    }
    return s;
  }
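  // Register the partitions on their workers in parallel: one
  // RegisterGraphAsync RPC per partition, joined below with a
  // BlockingCounter before the per-call statuses are merged.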
  struct Call {
    RegisterGraphRequest req;
    RegisterGraphResponse resp;
    Status status;
  };
  const int num = partitions_.size();
  gtl::InlinedVector<Call, 4> calls(num);
  BlockingCounter done(num);
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    Call* c = &calls[i];
    c->req.set_session_handle(session_handle_);
    c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
    *c->req.mutable_graph_options() = session_opts_.config.graph_options();
    *c->req.mutable_debug_options() = debug_opts_;
    VLOG(2) << "Register " << c->req.graph_def().DebugString();
    auto cb = [c, &done](const Status& s) {
      c->status = s;
      done.DecrementCount();
    };
    part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
  }
  done.Wait();
  for (int i = 0; i < num; ++i) {
    Call* c = &calls[i];
    s.Update(c->status);
    partitions_[i].graph_handle = c->resp.graph_handle();
  }
  return s;
}

// Helper class to manage "num" parallel RunGraph calls.
class RunManyGraphs {
 public:
  explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}

  ~RunManyGraphs() {}

  struct Call {
    CallOptions opts;
    std::unique_ptr<MutableRunGraphRequestWrapper> req;
    std::unique_ptr<MutableRunGraphResponseWrapper> resp;
  };

  // Returns the index-th call.
  Call* get(int index) { return &calls_[index]; }

  // When the index-th call is done, updates the overall status.
  void WhenDone(int index, const Status& s) {
    TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
    auto resp = get(index)->resp.get();
    if (resp->status_code() != error::Code::OK) {
      // resp->status_code() can only be non-OK when the RPC itself succeeded
      // (i.e. s.ok()), because step errors are stored in the response body
      // (see set_store_errors_in_response_body below).
      mutex_lock l(mu_);
      UpdateStatusLocked(
          Status(resp->status_code(), resp->status_error_message()));
    } else if (!s.ok()) {
      mutex_lock l(mu_);
      UpdateStatusLocked(s);
    }
    pending_.DecrementCount();
  }

  void StartCancel() {
    mutex_lock l(mu_);
    UpdateStatusLocked(errors::Cancelled("RunManyGraphs"));
  }

  void Wait() { pending_.Wait(); }

  Status status() const {
    mutex_lock l(mu_);
    return status_;
  }

 private:
  gtl::InlinedVector<Call, 4> calls_;

  BlockingCounter pending_;
  mutable mutex mu_;
  Status status_ GUARDED_BY(mu_);

  void UpdateStatusLocked(const Status& s) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (status_.ok()) {
      status_ = s;
      for (Call& call : calls_) {
        call.opts.StartCancel();
      }
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
};

Status MasterSession::ReffedClientGraph::RunPartitions(
    const MasterEnv* env, int64 step_id, int64 execution_count,
    PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp, CancellationManager* cm,
    const bool is_last_partial_run) {
  VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
          << execution_count;
  // Maps the names of fed tensors to their index in `req`.
  std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);

  for (size_t i = 0; i < req.num_feeds(); ++i) {
    if (!feeds.insert({req.feed_name(i), i}).second) {
      return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
    }
  }

  // Prepares a number of calls to workers. One call per partition.

  // Collect execution cost stats on a smoothly decreasing frequency.
  ExecutorOpts exec_opts;
  if (pss->report_tensor_allocations_upon_oom) {
    exec_opts.set_report_tensor_allocations_upon_oom(true);
  }
  if (pss->collect_costs) {
    exec_opts.set_record_costs(true);
  }
  if (pss->collect_timeline) {
    exec_opts.set_record_timeline(true);
  }
  if (pss->collect_rpcs) {
    SetRPCLogging(true);
  }
  if (pss->collect_partition_graphs) {
    exec_opts.set_record_partition_graphs(true);
  }
  if (pss->collect_costs || pss->collect_timeline) {
    pss->step_stats.resize(partitions_.size());
  }

  const int num = partitions_.size();
  RunManyGraphs calls(num);

  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* c = calls.get(i);
    c->req.reset(part.worker->CreateRunGraphRequest());
    c->resp.reset(part.worker->CreateRunGraphResponse());
    if (is_partial_) {
      c->req->set_is_partial(is_partial_);
      c->req->set_is_last_partial_run(is_last_partial_run);
    }
    c->req->set_session_handle(session_handle_);
    c->req->set_graph_handle(part.graph_handle);
    c->req->set_step_id(step_id);
    *c->req->mutable_exec_opts() = exec_opts;
    c->req->set_store_errors_in_response_body(true);
    // If any feeds are provided, send the feed values together
    // in the RunGraph request.
    // In the partial case, we only want to include feeds provided in the req.
    // In the non-partial case, all feeds in the request are in the part.
    // We keep these as separate paths for now, to ensure we aren't
    // inadvertently slowing down the normal run path.
    if (is_partial_) {
      for (size_t i = 0; i < req.num_feeds(); ++i) {
        const string& name = req.feed_name(i);
        const auto iter = part.feed_key.find(name);
        if (iter == part.feed_key.end()) {
          // The provided feed must be for a different partition.
          continue;
        }
        const string& key = iter->second;
        auto feeds_iter = feeds.find(name);
        if (feeds_iter == feeds.end()) {
          return errors::InvalidArgument("No feed is provided for feed=", name,
                                         ", key=", key);
        } else if (feeds_iter->second != i) {
          return errors::Internal("Cannot find feed named \"", name,
                                  "\" in request.");
        }
        TF_RETURN_IF_ERROR(c->req->AddSendFromRunStepRequest(req, i, key));
      }
      // TODO(suharshs): Make a map from fetch name to fetch key to make this
      // faster. For now, we just iterate through partitions to find the
      // matching key.
      for (size_t i = 0; i < req.num_fetches(); ++i) {
        const string& req_fetch = req.fetch_name(i);
        for (const auto& key_fetch : part.key_fetch) {
          if (key_fetch.second == req_fetch) {
            c->req->add_recv_key(key_fetch.first);
            break;
          }
        }
      }
    } else {
      for (const auto& feed_key : part.feed_key) {
        const string& feed = feed_key.first;
        const string& key = feed_key.second;
        const int64 feed_index = feeds[feed];
        TF_RETURN_IF_ERROR(
            c->req->AddSendFromRunStepRequest(req, feed_index, key));
      }
      for (const auto& key_fetch : part.key_fetch) {
        const string& key = key_fetch.first;
        c->req->add_recv_key(key);
      }
    }
  }

  // Issues RunGraph calls.
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* call = calls.get(i);
    TRACEPRINTF("Partition %d %s", i, part.name.c_str());
    part.worker->RunGraphAsync(
        &call->opts, call->req.get(), call->resp.get(),
        std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
  }

  // Waits for the RunGraph calls.
  call_opts->SetCancelCallback([&calls]() { calls.StartCancel(); });
  auto token = cm->get_cancellation_token();
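  // RegisterCallback() returns false when the cancellation manager has
  // already been cancelled; in that case we cancel the outstanding calls
  // immediately, still wait for them to drain, and then report the step as
  // cancelled.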
  const bool success =
      cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
  if (!success) {
    calls.StartCancel();
  }
  calls.Wait();
  call_opts->ClearCancelCallback();
  if (success) {
    cm->DeregisterCallback(token);
  } else {
    return errors::Cancelled("Step was cancelled");
  }

  // Collects fetches.
  Status status = calls.status();
  if (status.ok()) {
    for (int i = 0; i < num; ++i) {
      const Part& part = partitions_[i];
      MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
      for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
        auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
        if (iter == part.key_fetch.end()) {
          status.Update(errors::Internal("Unexpected fetch key: ",
                                         run_graph_resp->recv_key(j)));
          break;
        }
        const string& fetch = iter->second;
        status.Update(
            resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
        if (!status.ok()) {
          break;
        }
      }
      if (pss->collect_timeline) {
        pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
      }
      if (pss->collect_costs) {
        CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
        for (int j = 0; j < cost_graph->node_size(); ++j) {
          resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
              cost_graph->mutable_node(j));
        }
      }
      if (pss->collect_partition_graphs) {
        protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
            resp->mutable_metadata()->mutable_partition_graphs();
        for (size_t j = 0; j < run_graph_resp->num_partition_graphs(); j++) {
          partition_graph_defs->Add()->Swap(
              run_graph_resp->mutable_partition_graph(j));
        }
      }
    }
  }
  return status;
}

namespace {

class CleanupBroadcastHelper {
 public:
  CleanupBroadcastHelper(int64 step_id, int num_calls, StatusCallback done)
      : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
    req_.set_step_id(step_id);
  }

  // Returns a non-owned pointer to a request buffer for all calls.
  CleanupGraphRequest* request() { return &req_; }

  // Returns a non-owned pointer to a response buffer for the ith call.
  CleanupGraphResponse* response(int i) { return &resps_[i]; }

  // Called when the ith response is received.
  void call_done(int i, const Status& s) {
    bool run_callback = false;
    Status status_copy;
    {
      mutex_lock l(mu_);
      status_.Update(s);
      if (--num_pending_ == 0) {
        run_callback = true;
        status_copy = status_;
      }
    }
    if (run_callback) {
      done_(status_copy);
      // This is the last call, so delete the helper object.
      delete this;
    }
  }

 private:
  // A single request shared between all workers.
  CleanupGraphRequest req_;
  // One response buffer for each worker.
  gtl::InlinedVector<CleanupGraphResponse, 4> resps_;

  mutex mu_;
  // Number of requests remaining to be collected.
  int num_pending_ GUARDED_BY(mu_);
  // Aggregate status of the operation.
  Status status_ GUARDED_BY(mu_);
  // Callback to be called when all operations complete.
  StatusCallback done_;

  TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
};

}  // namespace

void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
    int64 step_id, StatusCallback done) {
  const int num = partitions_.size();
  // Helper object will be deleted when the final call completes.
  CleanupBroadcastHelper* helper =
      new CleanupBroadcastHelper(step_id, num, std::move(done));
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    part.worker->CleanupGraphAsync(
        helper->request(), helper->response(i),
        [helper, i](const Status& s) { helper->call_done(i, s); });
  }
}

void MasterSession::ReffedClientGraph::ProcessStats(int64 step_id,
                                                    PerStepState* pss,
                                                    ProfileHandler* ph,
                                                    const RunOptions& options,
                                                    RunMetadata* resp) {
  if (!pss->collect_costs && !pss->collect_timeline) return;

  // Out-of-band logging data is collected now, during post-processing.
  if (pss->collect_timeline) {
    SetRPCLogging(false);
    RetrieveLogs(step_id, &pss->rpc_stats);
  }
  for (size_t i = 0; i < partitions_.size(); ++i) {
    const StepStats& ss = pss->step_stats[i];
    if (ph) {
      for (const auto& ds : ss.dev_stats()) {
        ProcessDeviceStats(ph, ds, false /*is_rpc*/);
      }
    }
  }
  if (ph) {
    for (const auto& ds : pss->rpc_stats.dev_stats()) {
      ProcessDeviceStats(ph, ds, true /*is_rpc*/);
    }
    ph->StepDone(pss->start_micros, pss->end_micros,
                 Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
                 Status::OK());
  }
  // Assemble all stats for this timeline into a merged StepStats.
  if (pss->collect_timeline) {
    StepStats step_stats_proto;
    step_stats_proto.Swap(&pss->rpc_stats);
    for (size_t i = 0; i < partitions_.size(); ++i) {
      step_stats_proto.MergeFrom(pss->step_stats[i]);
      pss->step_stats[i].Clear();
    }
    pss->step_stats.clear();
    // Copy the stats back into the response only for on-demand (FULL_TRACE)
    // profiling, to avoid slowing down calls that triggered the automatic
    // profiling.
    if (options.trace_level() == RunOptions::FULL_TRACE) {
      resp->mutable_step_stats()->Swap(&step_stats_proto);
    } else {
      // For automatic profiling, publish the stats instead; in the
      // FULL_TRACE case they are already returned through the Session API,
      // so publishing would duplicate them.
      stats_publisher_->PublishStatsProto(step_stats_proto);
    }
  }
}

void MasterSession::ReffedClientGraph::ProcessDeviceStats(
    ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
  const string& dev_name = ds.device();
  VLOG(1) << "Device " << dev_name << " reports stats for "
          << ds.node_stats_size() << " nodes";
  for (const auto& ns : ds.node_stats()) {
    if (is_rpc) {
      // We don't have access to a good Node pointer, so we rely on
      // sufficient data being present in the NodeExecStats.
      ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
                      ns.timeline_label());
    } else {
      const Node* node = name_to_node_[ns.node_name()];
      const bool found_node_in_graph = node != nullptr;
      if (!found_node_in_graph && ns.timeline_label().empty()) {
        // The counter incrementing is not thread-safe. But we don't really
        // care.
        // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
        // more general usage.
        static int log_counter = 0;
        if (log_counter < 10) {
          log_counter++;
          LOG(WARNING) << "Failed to find node " << ns.node_name()
                       << " for dev " << dev_name;
        }
        continue;
      }
      string optype =
          found_node_in_graph ? node->type_string() : ns.node_name();
      string details;
      if (!ns.timeline_label().empty()) {
        details = ns.timeline_label();
      } else if (found_node_in_graph) {
        details = DetailText(*node, ns);
      } else {
        // Leave details string empty
      }
      ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
                      details);
    }
  }
}

// TODO(suharshs): Merge with CheckFetches in DirectSession.
// TODO(suharsh,mrry): Build a map from fetch target to the set of feeds it
// depends on once at setup time to prevent us from computing the
// dependencies every time.
// TODO(suharshs,mrry): Consider removing the need for execution_state to
// reduce contention.
Status MasterSession::ReffedClientGraph::CheckFetches(
    const RunStepRequestWrapper& req, const RunState* run_state,
    GraphExecutionState* execution_state) {
  // Build the set of pending feeds that we haven't seen.
  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
  for (const auto& input : run_state->pending_inputs) {
    // Skip if already fed.
    if (input.second) continue;
    TensorId id(ParseTensorName(input.first));
    const auto it = name_to_node_.find(id.first);
    if (it == name_to_node_.end()) {
      return errors::NotFound("Feed ", input.first, ": not found");
    }
    pending_feeds.insert(id);
  }
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    const TensorId id(ParseTensorName(req.feed_name(i)));
    pending_feeds.erase(id);
  }

  // Initialize the stack with the fetch nodes.
  std::vector<const Node*> stack;
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    const string& fetch = req.fetch_name(i);
    const TensorId id(ParseTensorName(fetch));
    auto it = name_to_node_.find(id.first);
    if (it == name_to_node_.end()) {
      return errors::NotFound("Fetch ", fetch, ": not found");
    }
    stack.push_back(it->second);
  }

  // Any tensor needed for fetches can't be in pending_feeds.
  // We need to use the original full graph from execution state.
  const Graph* graph = execution_state->full_graph();
  std::vector<bool> visited(graph->num_node_ids(), false);
  while (!stack.empty()) {
    const Node* n = stack.back();
    stack.pop_back();

    for (const Edge* in_edge : n->in_edges()) {
      const Node* in_node = in_edge->src();
      if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
        return errors::InvalidArgument("Fetch ", in_node->name(), ":",
                                       in_edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
      }
      if (!visited[in_node->id()]) {
        visited[in_node->id()] = true;
        stack.push_back(in_node);
      }
    }
  }
  return Status::OK();
}

// Asynchronously deregisters subgraphs on the workers, without waiting for the
// result.
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
  struct Call {
    DeregisterGraphRequest req;
    DeregisterGraphResponse resp;
  };
  for (Part& part : partitions_) {
    // The graph handle may be empty if we failed during partition registration.
    if (!part.graph_handle.empty()) {
      Call* c = new Call;
      c->req.set_session_handle(session_handle_);
      c->req.set_graph_handle(part.graph_handle);
      // NOTE(mrry): We must capture `worker_cache_` since `this`
      // could be deleted before the callback is called.
      WorkerCacheInterface* worker_cache = worker_cache_;
      const string name = part.name;
      WorkerInterface* w = part.worker;
      CHECK_NOTNULL(w);
      auto cb = [worker_cache, c, name, w](const Status& s) {
        if (!s.ok()) {
          // This error is potentially benign, so we don't log at the
          // error level.
          LOG(INFO) << "DeregisterGraph error: " << s;
        }
        delete c;
        worker_cache->ReleaseWorker(name, w);
      };
      w->DeregisterGraphAsync(&c->req, &c->resp, cb);
    }
  }
}

void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
                            BuildGraphOptions* opts) {
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    opts->feed_endpoints.push_back(req.feed_name(i));
  }
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    opts->fetch_endpoints.push_back(req.fetch_name(i));
  }
  for (size_t i = 0; i < req.num_targets(); ++i) {
    opts->target_nodes.push_back(req.target_name(i));
  }

  if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
    opts->debug_options = req.options().debug_options();
  }

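  // Sort the endpoint lists so that two requests with the same feeds,
  // fetches, and targets in a different order produce identical
  // BuildGraphOptions, and therefore the same HashBuildGraphOptions() value.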
  std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
  std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
  std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
}

void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
                            BuildGraphOptions* opts) {
  for (const auto& feed : req.feed()) {
    opts->feed_endpoints.push_back(feed);
  }
  for (const auto& fetch : req.fetch()) {
    opts->fetch_endpoints.push_back(fetch);
  }
  for (const auto& target : req.target()) {
    opts->target_nodes.push_back(target);
  }

  // TODO(cais): Add TFDBG support to partial runs.

  std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
  std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
  std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
}

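// Hashes the canonicalized (sorted) endpoints and any debug tensor watches.
// The result keys the run_graphs_/partial_run_graphs_ caches in
// MasterSession::StartStep(), so requests with equivalent BuildGraphOptions
// share a single compiled ReffedClientGraph.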
uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
  uint64 h = 0x2b992ddfa23249d6ull;
  for (const string& name : opts.feed_endpoints) {
    h = Hash64(name.c_str(), name.size(), h);
  }
  for (const string& name : opts.target_nodes) {
    h = Hash64(name.c_str(), name.size(), h);
  }
  for (const string& name : opts.fetch_endpoints) {
    h = Hash64(name.c_str(), name.size(), h);
  }

  if (!opts.debug_options.debug_tensor_watch_opts().empty()) {
    const string watch_summary = SummarizeDebugTensorWatches(
        opts.debug_options.debug_tensor_watch_opts());
    h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
  }

  return h;
}

string BuildGraphOptionsString(const BuildGraphOptions& opts) {
  string buf;
  for (const string& name : opts.feed_endpoints) {
    strings::StrAppend(&buf, " FdE: ", name);
  }
  strings::StrAppend(&buf, "\n");
  for (const string& name : opts.target_nodes) {
    strings::StrAppend(&buf, " TN: ", name);
  }
  strings::StrAppend(&buf, "\n");
  for (const string& name : opts.fetch_endpoints) {
    strings::StrAppend(&buf, " FeE: ", name);
  }
  strings::StrAppend(&buf, "\n");
  return buf;
}

MasterSession::MasterSession(
    const SessionOptions& opt, const MasterEnv* env,
    std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
    std::unique_ptr<WorkerCacheInterface> worker_cache,
    std::unique_ptr<DeviceSet> device_set,
    StatsPublisherFactory stats_publisher_factory)
    : session_opts_(opt),
      env_(env),
      handle_(strings::FpToString(random::New64())),
      remote_devs_(std::move(remote_devs)),
      worker_cache_(std::move(worker_cache)),
      devices_(std::move(device_set)),
      stats_publisher_factory_(std::move(stats_publisher_factory)),
      graph_version_(0),
      run_graphs_(5),
      partial_run_graphs_(5) {
  UpdateLastAccessTime();
  CHECK(devices_) << "device_set was null!";

  VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
          << " #remote " << remote_devs_->size();

  LOG(INFO) << "Start master session " << handle_
            << " with config: " << session_opts_.config.ShortDebugString();
}

MasterSession::~MasterSession() {
  for (const auto& iter : run_graphs_) iter.second->Unref();
  for (const auto& iter : partial_run_graphs_) iter.second->Unref();
}

void MasterSession::UpdateLastAccessTime() {
  last_access_time_usec_.store(Env::Default()->NowMicros());
}

Status MasterSession::Create(GraphDef* graph_def,
                             const WorkerCacheFactoryOptions& options) {
  if (session_opts_.config.use_per_session_threads() ||
      session_opts_.config.session_inter_op_thread_pool_size() > 0) {
    return errors::InvalidArgument(
        "Distributed session does not support session thread pool options.");
  }
  if (session_opts_.config.graph_options().place_pruned_graph()) {
    // TODO(b/29900832): Fix this or remove the option.
    LOG(WARNING) << "Distributed session does not support the "
                    "place_pruned_graph option.";
    session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
  }

  GraphExecutionStateOptions execution_options;
  execution_options.device_set = devices_.get();
  execution_options.session_options = &session_opts_;
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
        graph_def, execution_options, &execution_state_));
  }
  // TODO(b/36574172): Remove these conditions when ClusterSpec
  // propagation is supported in all servers.
  if (options.cluster_def != nullptr ||
      session_opts_.config.isolate_session_state()) {
    should_delete_worker_sessions_ = true;
    return CreateWorkerSessions(options);
  }
  return Status::OK();
}

Status MasterSession::CreateWorkerSessions(
    const WorkerCacheFactoryOptions& options) {
  std::vector<string> worker_names;
  WorkerCacheInterface* worker_cache = get_worker_cache();
  worker_cache->ListWorkers(&worker_names);

  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;

    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;

    // Request and responses used for a given worker.
    CreateWorkerSessionRequest request;
    CreateWorkerSessionResponse response;
    Status status = Status::OK();
  };
  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  Status status = Status::OK();
  // Create all the workers and prepare the per-worker requests.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
    workers[i].request.set_session_handle(handle_);
    if (options.cluster_def) {
      *workers[i].request.mutable_server_def()->mutable_cluster() =
          *options.cluster_def;
      workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
      // Session state is always isolated when ClusterSpec propagation
      // is in use.
      workers[i].request.set_isolate_session_state(true);
    } else {
      workers[i].request.set_isolate_session_state(
          session_opts_.config.isolate_session_state());
    }

    DeviceNameUtils::ParsedName name;
    if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
      status = errors::Internal("Could not parse name ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }
    if (!name.has_job || !name.has_task) {
      status = errors::Internal("Incomplete worker name ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }

    workers[i].request.mutable_server_def()->set_job_name(name.job);
    workers[i].request.mutable_server_def()->set_task_index(name.task);
  }

  for (size_t i = 0; i < worker_names.size(); ++i) {
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
                                                &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}

Status MasterSession::DeleteWorkerSessions() {
  WorkerCacheInterface* worker_cache = get_worker_cache();
  std::vector<string> worker_names;
  worker_cache->ListWorkers(&worker_names);

  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;

    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;

    // Request and responses used for a given worker.
    DeleteWorkerSessionRequest request;
    DeleteWorkerSessionResponse response;
    Status status = Status::OK();
  };
  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  Status status = Status::OK();
  // Create all the workers and prepare the per-worker requests.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
    workers[i].request.set_session_handle(handle_);
  }

  for (size_t i = 0; i < worker_names.size(); ++i) {
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->DeleteWorkerSessionAsync(&workers[i].request,
                                                &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}

Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
  if (worker_cache_) {
    // This is a ClusterSpec-propagated session, and thus env_->local_devices
    // are invalid.

    // Mark the "client_device" as the sole local device.
    const Device* client_device = devices_->client_device();
    for (const Device* dev : devices_->devices()) {
      if (dev != client_device) {
        *(resp->add_remote_device()) = dev->attributes();
      }
    }
    *(resp->add_local_device()) = client_device->attributes();
  } else {
    for (Device* dev : env_->local_devices) {
      *(resp->add_local_device()) = dev->attributes();
    }
    for (auto&& dev : *remote_devs_) {
      *(resp->add_local_device()) = dev->attributes();
    }
  }
  return Status::OK();
}

Status MasterSession::Extend(const ExtendSessionRequest* req,
                             ExtendSessionResponse* resp) {
  UpdateLastAccessTime();
  std::unique_ptr<GraphExecutionState> extended_execution_state;
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }

    if (graph_version_ != req->current_graph_version()) {
      return errors::Aborted("Current version is ", graph_version_,
                             " but caller expected ",
                             req->current_graph_version(), ".");
    }

    CHECK(execution_state_);
    TF_RETURN_IF_ERROR(
        execution_state_->Extend(req->graph_def(), &extended_execution_state));

    CHECK(extended_execution_state);
    // The old execution state will be released outside the lock.
    execution_state_.swap(extended_execution_state);
    ++graph_version_;
    resp->set_new_graph_version(graph_version_);
  }
  return Status::OK();
}

WorkerCacheInterface* MasterSession::get_worker_cache() const {
  if (worker_cache_) {
    return worker_cache_.get();
  }
  return env_->worker_cache;
}

Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
                                ReffedClientGraph** rcg, bool is_partial) {
  const uint64 hash = HashBuildGraphOptions(opts);
  {
    mutex_lock l(mu_);
    // Keep track of how many times this subgraph has been executed in
    // this session.
    int64* c = &subgraph_execution_counts_[hash];
    *count = (*c)++;
    // TODO(suharshs): We cache partial run graphs and run graphs separately
    // because there is preprocessing that needs to only be run for partial
    // run calls.
    RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
    auto iter = m->find(hash);
    if (iter == m->end()) {
      // We have not seen this subgraph before. Build the subgraph and
      // cache it.
      VLOG(1) << "Unseen hash " << hash << " for "
              << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
              << "\n";
      std::unique_ptr<ClientGraph> client_graph;
      TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
      WorkerCacheInterface* worker_cache = get_worker_cache();
   1285       auto entry = new ReffedClientGraph(
   1286           handle_, opts, std::move(client_graph), session_opts_,
   1287           stats_publisher_factory_, execution_state_.get(), is_partial,
   1288           worker_cache, !should_delete_worker_sessions_);
   1289       iter = m->insert({hash, entry}).first;
   1290       VLOG(1) << "Preparing to execute new graph";
   1291     }
   1292     *rcg = iter->second;
   1293     (*rcg)->Ref();
   1294   }
   1295   return Status::OK();
   1296 }
   1297 
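// Releases every cached graph in *rcg_map. If 'to_unref' is non-null, the
// references are handed to the caller so they can be dropped outside the
// session lock; otherwise they are dropped here.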
void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                                   RCGMap* rcg_map) {
  VLOG(1) << "Discarding all reffed graphs";
  for (auto p : *rcg_map) {
    ReffedClientGraph* rcg = p.second;
    if (to_unref) {
      to_unref->push_back(rcg);
    } else {
      rcg->Unref();
    }
  }
  rcg_map->clear();
}

Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
                                      PartialRunSetupResponse* resp) {
  std::vector<string> inputs, outputs, targets;
  for (const auto& feed : req->feed()) {
    inputs.push_back(feed);
  }
  for (const auto& fetch : req->fetch()) {
    outputs.push_back(fetch);
  }
  for (const auto& target : req->target()) {
    targets.push_back(target);
  }

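  // Mint a unique handle for this partial run from a monotonically
  // increasing per-session counter; the client passes it back on each
  // subsequent RunStep of the same partial run.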
  string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));

  ReffedClientGraph* rcg = nullptr;
  int64 count = 0;

  // Prepare.
  BuildGraphOptions opts;
  BuildBuildGraphOptions(*req, &opts);
  TF_RETURN_IF_ERROR(StartStep(opts, &count, &rcg, true));
  // Keeps the highest 8 bits 0x01: we reserve some bits of the
  // step_id for future use.
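  // (Illustrative example: if random::New64() returns 0xFFFFFFFFFFFFFFFF,
  // the resulting step_id is 0x01FFFFFFFFFFFFFF -- the low 56 bits stay
  // random while the top byte is forced to 0x01.)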
  uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
  TRACEPRINTF("stepid %llu", step_id);

  rcg->Ref();
  RunState* run_state = new RunState(inputs, outputs, rcg, step_id, count);
  {
    mutex_lock l(mu_);
    partial_runs_.emplace(
        std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
  }

  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));

  resp->set_partial_run_handle(handle);
  return Status::OK();
}

Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
                          MutableRunStepResponseWrapper* resp) {
  UpdateLastAccessTime();
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    ++num_running_;
    // Note: all code paths must eventually call MarkRunCompletion()
    // in order to appropriately decrement the num_running_ counter.
  }
  Status status;
  if (!req.partial_run_handle().empty()) {
    status = DoPartialRun(opts, req, resp);
  } else {
    status = DoRunWithLocalExecution(opts, req, resp);
  }
  return status;
}

// Decrements num_running_ and broadcasts if num_running_ is zero.
void MasterSession::MarkRunCompletion() {
  mutex_lock l(mu_);
  --num_running_;
  if (num_running_ == 0) {
    num_running_is_zero_.notify_all();
  }
}

Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
  // Registers subgraphs if we haven't done so.
  PartitionOptions popts;
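  // The master partitions the client graph one piece per worker task; each
  // worker then sub-partitions its piece per device when the subgraph is
  // registered.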
  popts.node_to_loc = SplitByWorker;
  // The closures popts.{new_name,get_incarnation} are called synchronously in
  // RegisterPartitions() below, so they do not need a Ref()/Unref() pair to
  // keep "this" alive during the closure.
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_S", next_node_id_++);
  };
  popts.flib_def = rcg->client_graph()->flib_def.get();
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* d = devices_->FindDeviceByName(name);
    if (d == nullptr) {
      return PartitionOptions::kIllegalIncarnation;
    } else {
      return d->attributes().incarnation();
    }
  };
  popts.control_flow_added = false;
  const bool enable_bfloat16_sendrecv =
      session_opts_.config.graph_options().enable_bfloat16_sendrecv();
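  // Optionally cast DT_FLOAT tensors to DT_BFLOAT16 on cross-partition
  // edges, halving send/recv bandwidth at reduced precision; the
  // partitioner inserts matching cast nodes on both sides. (Control edges
  // carry no real tensor payload; the DT_FLOAT returned for them below is
  // only the placeholder type used when materializing them as sends.)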
  popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
    if (e->IsControlEdge()) {
      return DT_FLOAT;
    }
    DataType dtype = BaseType(e->src()->output_type(e->src_output()));
    if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
      return DT_BFLOAT16;
    } else {
      return dtype;
    }
  };
  if (session_opts_.config.graph_options().enable_recv_scheduling()) {
    popts.scheduling_for_recvs = true;
    popts.need_to_record_start_times = true;
  }

  TF_RETURN_IF_ERROR(rcg->RegisterPartitions(popts));

  return Status::OK();
}

Status MasterSession::DoPartialRun(CallOptions* opts,
                                   const RunStepRequestWrapper& req,
                                   MutableRunStepResponseWrapper* resp) {
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
  const string& prun_handle = req.partial_run_handle();
  RunState* run_state = nullptr;
  {
    mutex_lock l(mu_);
    auto it = partial_runs_.find(prun_handle);
    if (it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run PartialRunSetup before performing partial runs");
    }
    run_state = it->second.get();
  }

  // If this is the first partial run, initialize the PerStepState.
  if (!run_state->step_started) {
    run_state->step_started = true;
    PerStepState pss;

    const auto count = run_state->count;
    pss.collect_timeline =
        req.options().trace_level() == RunOptions::FULL_TRACE;
    pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
    pss.report_tensor_allocations_upon_oom =
        req.options().report_tensor_allocations_upon_oom();

    // Build the cost model every 'build_cost_model_every' steps after
    // skipping an initial 'build_cost_model_after' steps.
    const int64 build_cost_model_after =
        session_opts_.config.graph_options().build_cost_model_after();
    const int64 build_cost_model_every =
        session_opts_.config.graph_options().build_cost_model();
    pss.collect_costs =
        build_cost_model_every > 0 &&
        ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
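    // (For example, with build_cost_model_after = 10 and
    // build_cost_model_every = 5, costs are collected on counts 9, 14, 19,
    // ..., i.e. on the 10th, 15th, 20th, ... executions of this subgraph.)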
    pss.collect_partition_graphs = req.options().output_partition_graphs();

    std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
        run_state->step_id, count, req.options());
    if (ph) {
      pss.collect_timeline = true;
      pss.collect_rpcs = ph->should_collect_rpcs();
    }

    run_state->pss = std::move(pss);
    run_state->ph = std::move(ph);
  }

  // Make sure that this is a new set of feeds that are still pending.
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    const string& feed = req.feed_name(i);
    auto it = run_state->pending_inputs.find(feed);
    if (it == run_state->pending_inputs.end()) {
      return errors::InvalidArgument(
          "The feed ", feed, " was not specified in partial_run_setup.");
    } else if (it->second) {
      return errors::InvalidArgument("The feed ", feed,
                                     " has already been fed.");
    }
  }
  // Check that this is a new set of fetches that are still pending.
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    const string& fetch = req.fetch_name(i);
    auto it = run_state->pending_outputs.find(fetch);
    if (it == run_state->pending_outputs.end()) {
      return errors::InvalidArgument(
          "The fetch ", fetch, " was not specified in partial_run_setup.");
    } else if (it->second) {
      return errors::InvalidArgument("The fetch ", fetch,
                                     " has already been fetched.");
    }
  }

  // Ensure that the requested fetches can be computed from the provided feeds.
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(
        run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
  }

  // Determine if this partial run satisfies all the pending inputs and outputs.
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    auto it = run_state->pending_inputs.find(req.feed_name(i));
    it->second = true;
  }
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    auto it = run_state->pending_outputs.find(req.fetch_name(i));
    it->second = true;
  }
  bool is_last_partial_run = run_state->PendingDone();

  Status s = run_state->rcg->RunPartitions(
      env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
      resp, &cancellation_manager_, is_last_partial_run);

  // Delete the run state if there is an error or all fetches are done.
  if (!s.ok() || is_last_partial_run) {
    ReffedClientGraph* rcg = run_state->rcg;
    run_state->pss.end_micros = Env::Default()->NowMicros();
    // Schedule post-processing and cleanup to be done asynchronously.
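    // Take references on the session and the graph so both outlive the
    // asynchronous cleanup callback below; the callback drops them again.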
    Ref();
    rcg->Ref();
    rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
                      req.options(), resp->mutable_metadata());
    cleanup.release();  // MarkRunCompletion called in done closure.
    rcg->CleanupPartitionsAsync(
        run_state->step_id, [this, rcg, prun_handle](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Cleanup partition error: " << s;
          }
          rcg->Unref();
          MarkRunCompletion();
          Unref();
        });
    mutex_lock l(mu_);
    partial_runs_.erase(prun_handle);
  }
  return s;
}

Status MasterSession::CreateDebuggerState(
    const DebugOptions& debug_options, const RunStepRequestWrapper& req,
    int64 rcg_execution_count,
    std::unique_ptr<DebuggerStateInterface>* debugger_state) {
  TF_RETURN_IF_ERROR(
      DebuggerStateRegistry::CreateState(debug_options, debugger_state));

  std::vector<string> input_names;
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    input_names.push_back(req.feed_name(i));
  }
  std::vector<string> output_names;
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    output_names.push_back(req.fetch_name(i));
  }
  std::vector<string> target_names;
  for (size_t i = 0; i < req.num_targets(); ++i) {
    target_names.push_back(req.target_name(i));
  }

  // TODO(cais): We currently use -1 as a dummy value for session run count.
  // While this counter value is straightforward to define and obtain for
  // DirectSessions, it is less so for non-direct Sessions. Devise a better
  // way to get its value when the need arises.
  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
      debug_options.global_step(), rcg_execution_count, rcg_execution_count,
      input_names, output_names, target_names));

  return Status::OK();
}

Status MasterSession::DoRunWithLocalExecution(
    CallOptions* opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp) {
  VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
  PerStepState pss;
  pss.start_micros = Env::Default()->NowMicros();
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });

  // Prepare.
  BuildGraphOptions bgopts;
  BuildBuildGraphOptions(req, &bgopts);
  ReffedClientGraph* rcg = nullptr;
  int64 count = 0;
  TF_RETURN_IF_ERROR(StartStep(bgopts, &count, &rcg, false));

  // Unref "rcg" when out of scope.
  core::ScopedUnref unref(rcg);

  std::unique_ptr<DebuggerStateInterface> debugger_state;
  const DebugOptions& debug_options = req.options().debug_options();

  if (!debug_options.debug_tensor_watch_opts().empty()) {
    TF_RETURN_IF_ERROR(
        CreateDebuggerState(debug_options, req, count, &debugger_state));
  }
  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));

  // Keeps the highest 8 bits 0x01: we reserve some bits of the
  // step_id for future use.
  const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
  TRACEPRINTF("stepid %llu", step_id);

  pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE;
  pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
  pss.report_tensor_allocations_upon_oom =
      req.options().report_tensor_allocations_upon_oom();
  // Build the cost model every 'build_cost_model_every' steps after skipping
  // an initial 'build_cost_model_after' steps.
  const int64 build_cost_model_after =
      session_opts_.config.graph_options().build_cost_model_after();
  const int64 build_cost_model_every =
      session_opts_.config.graph_options().build_cost_model();
  pss.collect_costs =
      build_cost_model_every > 0 &&
      ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
  pss.collect_partition_graphs = req.options().output_partition_graphs();

  std::unique_ptr<ProfileHandler> ph =
      rcg->GetProfileHandler(step_id, count, req.options());
  if (ph) {
    pss.collect_timeline = true;
    pss.collect_rpcs = ph->should_collect_rpcs();
  }

  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                &cancellation_manager_, false);
  if (s.ok()) {
    pss.end_micros = Env::Default()->NowMicros();

    // Schedule post-processing and cleanup to be done asynchronously.
    rcg->ProcessStats(step_id, &pss, ph.get(), req.options(),
                      resp->mutable_metadata());
  } else if (errors::IsCancelled(s)) {
    mutex_lock l(mu_);
    if (closed_) {
      if (garbage_collected_) {
        s = errors::Cancelled(
            "Step was cancelled because the session was garbage collected due "
            "to inactivity.");
      } else {
        s = errors::Cancelled(
            "Step was cancelled by an explicit call to `Session::Close()`.");
      }
    }
  }
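  // As in DoPartialRun() above: hold refs on the session and graph until the
  // asynchronous partition cleanup has run.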
  Ref();
  rcg->Ref();
  cleanup.release();  // MarkRunCompletion called in done closure.
  rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
    if (!s.ok()) {
      LOG(ERROR) << "Cleanup partition error: " << s;
    }
    rcg->Unref();
    MarkRunCompletion();
    Unref();
  });
  return s;
}

Status MasterSession::Close() {
  {
    mutex_lock l(mu_);
    closed_ = true;  // All subsequent calls to Run() or Extend() will fail.
  }
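  // With closed_ set, cancel any in-flight steps and wait for num_running_
  // to drain to zero before discarding the cached graphs.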
  cancellation_manager_.StartCancel();
  std::vector<ReffedClientGraph*> to_unref;
  {
    mutex_lock l(mu_);
    while (num_running_ != 0) {
      num_running_is_zero_.wait(l);
    }
    ClearRunsTable(&to_unref, &run_graphs_);
    ClearRunsTable(&to_unref, &partial_run_graphs_);
  }
  for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
  if (should_delete_worker_sessions_) {
    Status s = DeleteWorkerSessions();
    if (!s.ok()) {
      LOG(WARNING) << s;
    }
  }
  return Status::OK();
}

void MasterSession::GarbageCollect() {
  {
    mutex_lock l(mu_);
    closed_ = true;
    garbage_collected_ = true;
  }
  cancellation_manager_.StartCancel();
  Unref();
}

MasterSession::RunState::RunState(const std::vector<string>& input_names,
                                  const std::vector<string>& output_names,
                                  ReffedClientGraph* rcg, const uint64 step_id,
                                  const int64 count)
    : rcg(rcg), step_id(step_id), count(count) {
  // Initially all the feeds and fetches are pending.
  for (auto& name : input_names) {
    pending_inputs[name] = false;
  }
  for (auto& name : output_names) {
    pending_outputs[name] = false;
  }
}

MasterSession::RunState::~RunState() {
  if (rcg) rcg->Unref();
}

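// A partial run is complete once every declared feed has been fed and every
// declared fetch has been fetched.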
bool MasterSession::RunState::PendingDone() const {
  for (const auto& it : pending_inputs) {
    if (!it.second) return false;
  }
  for (const auto& it : pending_outputs) {
    if (!it.second) return false;
  }
  return true;
}

}  // end namespace tensorflow