/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/worker.h"

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/platform/tracing.h"

namespace tensorflow {

Worker::Worker(WorkerEnv* env)
    : env_(env), cancellation_manager_(new CancellationManager) {}

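// Reports the attributes of every device registered with this worker's
// DeviceMgr.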
void Worker::GetStatusAsync(const GetStatusRequest* request,
                            GetStatusResponse* response, StatusCallback done) {
  DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(Status::OK());
}

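// Creates a WorkerSession for the given session handle by delegating to the
// SessionMgr.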
void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->CreateSession(request->session_handle(),
                                              request->server_def(),
                                              request->isolate_session_state());
  done(s);
}

void Worker::DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request,
                                      DeleteWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->DeleteSession(request->session_handle());
  done(s);
}

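// Registers a subgraph with the session's GraphMgr and returns the handle
// assigned to it in the response.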
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) {
  auto session =
      env_->session_mgr->WorkerSessionForSession(request->session_handle());
  Status s = session->graph_mgr->Register(
      request->session_handle(), request->graph_def(), request->graph_options(),
      request->debug_options(), session->cluster_flr.get(),
      response->mutable_graph_handle());
  done(s);
}

void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                  DeregisterGraphResponse* response,
                                  StatusCallback done) {
  auto session =
      env_->session_mgr->WorkerSessionForSession(request->session_handle());
  Status s = session->graph_mgr->Deregister(request->graph_handle());

  done(s);
}

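// Aborts the rendezvous for `step_id` after a short delay.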
void Worker::AbortStep(int64 step_id) {
  Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
  SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
    // Delay a bit before aborting the step, so that the original failure (the
    // root cause) is more likely to reach the client before the Aborted error
    // generated by this cancellation.
    rendez->StartAbort(errors::Aborted("Step ", step_id));
    rendez->Unref();
  });
}

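// Fills `in` with the tensors to be sent for this call and seeds `out` with a
// placeholder entry for every key the caller wants to receive.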
Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
                               GraphMgr::NamedTensors* in,
                               GraphMgr::NamedTensors* out) {
  static Tensor empty_tensor(DT_FLOAT);
  if (req->num_sends() > 0) {
    Tensor val;
    for (size_t i = 0; i < req->num_sends(); ++i) {
      TF_RETURN_IF_ERROR(req->SendValue(i, &val));
      in->insert({req->send_key(i), val});
    }
  }
  for (size_t i = 0; i < req->num_recvs(); ++i) {
    out->insert({req->recv_key(i), empty_tensor});
  }
  return Status::OK();
}

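// Entry point for RunGraph calls. If the caller asked for errors to be stored
// in the response body, wrap `done` so the RPC itself always reports OK, then
// dispatch to the partial-run or full-run path.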
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  if (request->store_errors_in_response_body()) {
    done = [response, done](const Status& status) {
      response->set_status(status);
      done(Status::OK());
    };
  }
  if (request->is_partial()) {
    DoPartialRunGraph(opts, request, response, std::move(done));
  } else {
    DoRunGraph(opts, request, response, std::move(done));
  }
}

MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
  return new InMemoryRunGraphRequest;
}

MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
  return new InMemoryRunGraphResponse;
}

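// Runs one complete (non-partial) step: feeds the send tensors, executes the
// registered graph, and returns the received tensors in the response.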
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64 step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  auto session =
      env_->session_mgr->WorkerSessionForSession(request->session_handle());
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  Status s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
    // TODO(mrry,pbar): GPU tracing for distributed steps.
  }
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    cm->StartCancel();
    AbortStep(step_id);
  });
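  // Also register this step's CancellationManager with the worker-wide one so
  // that cancelling the worker cancels this step too. If the worker has
  // already been cancelled, abort the call immediately.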
  CancellationToken token;
  {
    mutex_lock l(mu_);
    token = cancellation_manager_->get_cancellation_token();
    bool already_cancelled = !cancellation_manager_->RegisterCallback(
        token, [cm]() { cm->StartCancel(); });
    if (already_cancelled) {
      opts->ClearCancelCallback();
      delete cm;
      delete collector;
      delete out;
      done(errors::Aborted("Call was aborted"));
      return;
    }
  }
  session->graph_mgr->ExecuteAsync(
      request->graph_handle(), step_id, session.get(), request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector, opts,
       done](Status s) {
        if (s.ok()) {
          s = session->graph_mgr->RecvOutputs(step_id, out);
        }
        opts->ClearCancelCallback();
        {
          mutex_lock l(mu_);
          cancellation_manager_->DeregisterCallback(token);
        }
        delete cm;

        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (collector) collector->Finalize();
        delete collector;
        delete out;
        done(s);
      });
}

// TODO(suharshs): Add stats collection support to partial run.
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64 step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  auto session =
      env_->session_mgr->WorkerSessionForSession(request->session_handle());

  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  Status s = PrepareRunGraph(request, &in, out);
  auto finish = [this, done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

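  // A partial run spans several RunGraph calls that share one step_id, so the
  // per-step CancellationManager lives in the PartialRunMgr rather than being
  // created fresh for every call.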
  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before doing anything else, set up the RPC cancellation callback.
  opts->SetCancelCallback([this, cm, step_id]() {
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, we need to start the executors.
  if (is_new_partial_run) {
    CancellationToken token;
    {
      mutex_lock l(mu_);
      token = cancellation_manager_->get_cancellation_token();
      cancellation_manager_->RegisterCallback(token,
                                              [cm]() { cm->StartCancel(); });
    }
    session->graph_mgr->ExecuteAsync(
        graph_handle, step_id, session.get(), request->exec_opts(),
        nullptr /* collector */, nullptr /* response */, cm, in,
        [this, token, step_id, session, cm](Status s) {
          {
            mutex_lock l(mu_);
            cancellation_manager_->DeregisterCallback(token);
          }
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // This is a continuation of an existing partial run: send the new inputs.
    s = session->graph_mgr->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

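  // In either case, asynchronously wait for the outputs requested by this call
  // and copy them into the response.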
  session->graph_mgr->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the response.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}

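// Releases the per-step rendezvous state for `step_id`.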
void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
                               CleanupGraphResponse* response,
                               StatusCallback done) {
  const int64 step_id = request->step_id();
  env_->rendezvous_mgr->Cleanup(step_id);
  done(Status::OK());
}

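// Clears the requested resource containers on all devices owned by this
// worker.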
void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(Status::OK());
}

void Worker::LoggingAsync(const LoggingRequest* request,
                          LoggingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Logging"));
}

void Worker::TracingAsync(const TracingRequest* request,
                          TracingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Tracing"));
}

// Helper for RecvTensor. Validates the parsed rendezvous key and returns the
// source device in "*src_dev".
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
                                 Device** src_dev) {
  // Figure out which device the tensor is hosted on.
  string local_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));

  // Check that the device has the incarnation number we expect.
  if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
    return errors::Aborted(
        "RecvTensor expects a different device incarnation: ",
        parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
        ". Your worker job was probably restarted. Check your "
        "worker job for the reason why it was restarted.");
  }

  return Status::OK();
}

void Worker::RecvTensorAsync(CallOptions* opts,
                             const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvTensorAsync, because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvTensorAsync()"));
}

}  // namespace tensorflow