Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
     17 
     18 #include <unordered_set>
     19 #include <vector>
     20 
     21 #include "tensorflow/core/common_runtime/copy_tensor.h"
     22 #include "tensorflow/core/common_runtime/device.h"
     23 #include "tensorflow/core/common_runtime/device_mgr.h"
     24 #include "tensorflow/core/common_runtime/dma_helper.h"
     25 #include "tensorflow/core/common_runtime/process_util.h"
     26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
     27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/lib/core/errors.h"
     30 #include "tensorflow/core/lib/strings/numbers.h"
     31 #include "tensorflow/core/lib/strings/str_util.h"
     32 #include "tensorflow/core/platform/logging.h"
     33 #include "tensorflow/core/platform/mutex.h"
     34 #include "tensorflow/core/platform/types.h"
     35 
     36 namespace tensorflow {
     37 
     38 static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
     39   rendez->StartAbort(s);
     40   rendez->Unref();
     41 }
     42 
     43 BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
     44     : worker_env_(worker_env) {}
     45 
     46 BaseRendezvousMgr::~BaseRendezvousMgr() {
     47   for (auto& p : table_) {
     48     auto rendez = p.second;
     49     StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
     50   }
     51 }
     52 
     53 RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
     54   return FindOrCreate(step_id);
     55 }
     56 
     57 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
     58   mutex_lock l(mu_);
     59   auto iter = table_.find(step_id);
     60   if (iter == table_.end()) {
     61     auto rr = Create(step_id, worker_env_);
     62     iter = table_.insert({step_id, rr}).first;
     63   }
     64   iter->second->Ref();
     65   return iter->second;
     66 }
     67 
     68 void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
     69                                        const Rendezvous::ParsedKey& parsed,
     70                                        Rendezvous::DoneCallback done) {
     71   auto rendez = FindOrCreate(step_id);
     72   using namespace std::placeholders;
     73   Rendezvous::DoneCallback done_cb = std::bind(
     74       [rendez](Rendezvous::DoneCallback done,
     75                // Begin unbound arguments.
     76                const Status& s, const Rendezvous::Args& send_args,
     77                const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
     78         rendez->Unref();
     79         done(s, send_args, recv_args, v, dead);
     80       },
     81       std::move(done), _1, _2, _3, _4, _5);
     82   rendez->RecvLocalAsync(parsed, std::move(done_cb));
     83 }
     84 
     85 Status BaseRendezvousMgr::RecvLocal(int64 step_id,
     86                                     const Rendezvous::ParsedKey& parsed,
     87                                     Tensor* val, bool* is_dead) {
     88   Status ret;
     89   Notification n;
     90   RecvLocalAsync(step_id, parsed,
     91                  [val, is_dead, &ret, &n](const Status& s,
     92                                           const Rendezvous::Args& send_args,
     93                                           const Rendezvous::Args& recv_args,
     94                                           const Tensor& v, const bool dead) {
     95                    ret = s;
     96                    *val = v;
     97                    *is_dead = dead;
     98                    n.Notify();
     99                  });
    100   n.WaitForNotification();
    101   return ret;
    102 }
    103 
    104 void BaseRendezvousMgr::Cleanup(int64 step_id) {
    105   Rendezvous* rendez = nullptr;
    106   {
    107     mutex_lock l(mu_);
    108     auto iter = table_.find(step_id);
    109     if (iter != table_.end()) {
    110       rendez = iter->second;
    111       table_.erase(iter);
    112     }
    113   }
    114   if (rendez) {
    115     StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
    116   }
    117 }
    118 
    119 void BaseRendezvousMgr::CleanupAll() {
    120   std::vector<Rendezvous*> rendezs;
    121   {
    122     mutex_lock l(mu_);
    123     for (const auto& entry : table_) {
    124       rendezs.push_back(entry.second);
    125     }
    126     table_.clear();
    127   }
    128   for (auto rendez : rendezs) {
    129     StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
    130   }
    131 }
    132 
    133 BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
    134     : env_(env),
    135       step_id_(step_id),
    136       local_(NewLocalRendezvous()),
    137       session_(nullptr) {}
    138 
    139 BaseRemoteRendezvous::~BaseRemoteRendezvous() {
    140   CHECK(active_.empty());
    141   local_->Unref();
    142 }
    143 
    144 // Returns true if "device_name" is a valid full name of local device
    145 // of the "worker".  This helper is purely based on the worker name
    146 // and device name and does no lookups in the worker->device_mgr.
    147 static bool IsLocalDevice(const string& worker_name,
    148                           const StringPiece device_name) {
    149   return device_name.starts_with(worker_name);
    150 }
    151 
    152 Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
    153   CHECK_NE(session, nullptr) << "session must not be null!";
    154   std::vector<DeferredCall> deferred_calls;
    155   {
    156     mutex_lock l(mu_);
    157     if (session_ != nullptr) {
    158       if (session_->worker_name == session->worker_name) {
    159         LOG(INFO) << "Skipping rendezvous re-initialization.";
    160         return Status::OK();
    161       }
    162       Status s = errors::Internal(
    163           "Double init! Worker names would have changed from: ",
    164           session_->worker_name, " -> ", session->worker_name);
    165       LOG(WARNING) << s;
    166       return s;
    167     }
    168     session_ = session;
    169     std::swap(deferred_calls, deferred_calls_);
    170   }
    171   for (auto& call : deferred_calls) {
    172     RecvLocalAsyncInternal(call.parsed, std::move(call.done));
    173   }
    174   return Status::OK();
    175 }
    176 
    177 WorkerSession* BaseRemoteRendezvous::session() {
    178   mutex_lock l(mu_);
    179   return session_;
    180 }
    181 
    182 bool BaseRemoteRendezvous::is_initialized() {
    183   mutex_lock l(mu_);
    184   return is_initialized_locked();
    185 }
    186 
    187 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
    188                                   const Rendezvous::Args& args,
    189                                   const Tensor& val, const bool is_dead) {
    190   VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
    191   {
    192     mutex_lock l(mu_);
    193     if (!status_.ok()) return status_;
    194     DCHECK(is_initialized_locked());
    195     if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
    196       return errors::InvalidArgument(
    197           "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
    198           session_->worker_name);
    199     }
    200   }
    201   // Buffers "val" and "device_context" in local_.
    202   return local_->Send(parsed, args, val, is_dead);
    203 }
    204 
    205 Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
    206                                              bool is_src) {
    207   // Cache session pointer to avoid repeatedly taking & releasing the lock
    208   // (e.g. calling session())
    209   WorkerSession* sess = nullptr;
    210   {
    211     mutex_lock l(mu_);
    212     if (!status_.ok()) return status_;
    213     if (!is_initialized_locked()) {
    214       return errors::Internal("ValidateDevices called before initialization.");
    215     }
    216     sess = session_;
    217   }
    218   if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) {
    219     return errors::InvalidArgument("Invalid rendezvous key (src): ",
    220                                    parsed.FullKey(), " @ ", sess->worker_name);
    221   }
    222   if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) {
    223     return errors::InvalidArgument("Invalid rendezvous key (dst): ",
    224                                    parsed.FullKey(), " @ ", sess->worker_name);
    225   }
    226   return Status::OK();
    227 }
    228 
    229 void BaseRemoteRendezvous::SameWorkerRecvDone(
    230     const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
    231     const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
    232     StatusCallback done) {
    233   // Do a quick copy (sharing the underlying buffer) if both tensors
    234   // are on host memory.
    235   const bool src_host =
    236       (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
    237   const bool dst_host =
    238       (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
    239   if (src_host && dst_host) {
    240     *out = in;
    241     done(Status::OK());
    242     return;
    243   }
    244 
    245   // This copy must involve a GPU. Hence, "in" must support DMA
    246   // (e.g., string tensors do not work on GPU).  Variant copy DMA
    247   // checks happen inside CopyTensor::ViaDMA.
    248   if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT) {
    249     done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
    250                                  " tensor may not be copied from/to a GPU."));
    251     return;
    252   }
    253 
    254   WorkerSession* sess = session();
    255   Device* src_device;
    256   Status s = sess->device_mgr->LookupDevice(parsed.src_device, &src_device);
    257   if (!s.ok()) {
    258     done(s);
    259     return;
    260   }
    261   Device* dst_device;
    262   s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
    263   if (!s.ok()) {
    264     done(s);
    265     return;
    266   }
    267 
    268   AllocatorAttributes attr = recv_args.alloc_attrs;
    269   attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
    270                           recv_args.alloc_attrs.gpu_compatible());
    271   Allocator* out_allocator = dst_device->GetAllocator(attr);
    272 
    273   if (in.dtype() != DT_VARIANT) {
    274     // Variants are handled by CopyTensor::ViaDMA.
    275     Tensor copy(out_allocator, in.dtype(), in.shape());
    276     *out = copy;
    277   }
    278 
    279   // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
    280   // etc.
    281   CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
    282                      recv_args.device_context, src_device, dst_device,
    283                      send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
    284                      std::move(done));
    285 }
    286 
    287 bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
    288                                         DeviceNameUtils::ParsedName dst) {
    289   return DeviceNameUtils::IsSameAddressSpace(src, dst);
    290 }
    291 
    292 void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
    293                                      const Rendezvous::Args& recv_args,
    294                                      DoneCallback done) {
    295   VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
    296   CHECK(is_initialized()) << "RecvAsync called when uninitialized.";
    297   Status s = ValidateDevices(parsed, false /*!is_src*/);
    298   if (!s.ok()) {
    299     done(s, Args(), recv_args, Tensor(), false);
    300     return;
    301   }
    302 
    303   // Are src and dst in the same worker?
    304   if (IsSameWorker(parsed.src, parsed.dst)) {
    305     // Recv the tensor from local_.
    306     local_->RecvAsync(
    307         parsed, recv_args,
    308         [this, parsed, done](
    309             const Status& status, const Rendezvous::Args& send_args,
    310             const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
    311           Tensor* out = new Tensor;
    312           StatusCallback final_callback = [done, send_args, recv_args, out,
    313                                            is_dead](const Status& s) {
    314             done(s, send_args, recv_args, *out, is_dead);
    315             delete out;
    316           };
    317 
    318           if (status.ok()) {
    319             SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
    320                                std::move(final_callback));
    321           } else {
    322             final_callback(status);
    323           }
    324         });
    325     return;
    326   } else {
    327     RecvFromRemoteAsync(parsed, recv_args, std::move(done));
    328   }
    329 }
    330 
    331 void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
    332                                           DoneCallback done) {
    333   {
    334     mutex_lock l(mu_);
    335     if (!is_initialized_locked()) {
    336       // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
    337       // remote worker) before the RunStep (or PartialRunStep) RPC from the
    338       // master arrives. RecvLocalAsync thus buffers the arguments until after
    339       // the RemoteRendezvous is Initialize()'d, when it completes the
    340       // rendezvous logic. At some point after Initialize() is called, a Tensor
    341       // is produced locally that will then be sent in response to the incoming
    342       // RPC.
    343       DeferredCall call(parsed, std::move(done));
    344       deferred_calls_.push_back(call);
    345       return;
    346     }
    347   }
    348   RecvLocalAsyncInternal(parsed, std::move(done));
    349 }
    350 
    351 void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
    352                                                   DoneCallback done) {
    353   Status s = ValidateDevices(parsed, true /* is_src */);
    354   if (!s.ok()) {
    355     done(s, Args(), Args(), Tensor(), false);
    356     return;
    357   }
    358   local_->RecvAsync(parsed, Args(), std::move(done));
    359 }
    360 
    361 void BaseRemoteRendezvous::StartAbort(const Status& s) {
    362   CHECK(!s.ok());
    363   local_->StartAbort(s);
    364   {
    365     // Aborts all active RecvTensor calls.
    366     mutex_lock l(mu_);
    367     if (status_.ok()) {
    368       status_ = s;
    369       for (BaseRecvTensorCall* call : active_) {
    370         call->StartAbort(s);
    371       }
    372       active_.clear();
    373     }
    374   }
    375 }
    376 
    377 void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call) {
    378   mutex_lock l(mu_);
    379   if (!status_.ok()) {
    380     call->StartAbort(status_);
    381   } else {
    382     CHECK(active_.insert(call).second);
    383   }
    384 }
    385 
    386 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
    387   mutex_lock l(mu_);
    388   active_.erase(call);
    389 }
    390 
    391 BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
    392                                                  DoneCallback done)
    393     : parsed(parsed), done(std::move(done)) {}
    394 
    395 }  // end namespace tensorflow
    396