Home | History | Annotate | Download | only in gdr
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"
     17 
     18 #include "google/protobuf/any.pb.h"
     19 #include "tensorflow/contrib/gdr/gdr_memory_manager.h"
     20 #include "tensorflow/core/common_runtime/device.h"
     21 #include "tensorflow/core/common_runtime/device_mgr.h"
     22 #include "tensorflow/core/common_runtime/process_util.h"
     23 #include "tensorflow/core/distributed_runtime/request_id.h"
     24 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
     25 #include "tensorflow/core/distributed_runtime/worker_cache.h"
     26 #include "tensorflow/core/distributed_runtime/worker_interface.h"
     27 #include "tensorflow/core/framework/types.h"
     28 #include "tensorflow/core/lib/core/errors.h"
     29 #include "tensorflow/core/lib/strings/numbers.h"
     30 #include "tensorflow/core/lib/strings/str_util.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/macros.h"
     33 #include "tensorflow/core/platform/types.h"
     34 
     35 namespace tensorflow {
     36 
     37 namespace {
     38 
     39 class GdrRecvTensorCall : public BaseRecvTensorCall {
     40  public:
     41   GdrRecvTensorCall(WorkerInterface* wi, Device* dst_device,
     42                     RemoteMemoryManager* remote_memory_manager,
     43                     const Rendezvous::Args& recv_args, int64 step_id,
     44                     StringPiece key)
     45       : wi_(wi),
     46         dst_device_(dst_device),
     47         remote_memory_manager_(remote_memory_manager),
     48         recv_args_(recv_args) {
     49     req_.set_step_id(step_id);
     50     req_.set_rendezvous_key(key.data(), key.size());
     51     req_.set_request_id(GetUniqueRequestId());
     52   }
     53 
     54   ~GdrRecvTensorCall() override {}
     55 
     56   void Start(std::function<void()> recv_done) override {
     57     req_.set_dma_ok(true);
     58     resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs);
     59     StatusCallback cb = [this, recv_done](const Status& s) {
     60       bool dma_ok = resp_.metadata().has_transport_options();
     61       if (s.ok() && tensor().TotalBytes() > 0 && (!is_dead()) && dma_ok) {
     62         auto transport_options = resp_.metadata().transport_options();
     63         const bool on_host =
     64             (dst_device_->tensorflow_gpu_device_info() == nullptr) ||
     65             recv_args_.alloc_attrs.on_host();
     66         remote_memory_manager_->TensorFromTransportOptions(
     67             const_cast<Tensor*>(&tensor()), transport_options, dst_device_,
     68             recv_args_.device_context, on_host,
     69             [this, recv_done](const Status& s) {
     70               if (!s.ok()) {
     71                 mutex_lock l(mu_);
     72                 status_.Update(s);
     73                 LOG(ERROR) << "Cannot find pinned memory region from allocator "
     74                            << dst_device_->GetAllocator(recv_args_.alloc_attrs)
     75                                   ->Name();
     76               }
     77               recv_done();
     78             });
     79         return;
     80       }
     81       if (!s.ok()) {
     82         mutex_lock l(mu_);
     83         status_.Update(s);
     84       }
     85       recv_done();
     86     };
     87     wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
     88   }
     89 
     90   void StartAbort(const Status& s) override {
     91     {
     92       mutex_lock l(mu_);
     93       status_.Update(s);
     94     }
     95     opts_.StartCancel();
     96   }
     97 
     98   Status status() const override {
     99     mutex_lock l(mu_);
    100     return status_;
    101   }
    102 
    103   const Tensor& tensor() const { return resp_.tensor(); }
    104 
    105   bool is_dead() const { return resp_.metadata().is_dead(); }
    106 
    107   Device* dst_device() const { return dst_device_; }
    108 
    109   const Rendezvous::Args& recv_args() const { return recv_args_; }
    110 
    111  private:
    112   WorkerInterface* wi_;
    113   Device* dst_device_;
    114   RemoteMemoryManager* remote_memory_manager_;
    115   CallOptions opts_;
    116   RecvTensorRequest req_;
    117   TensorResponse resp_;
    118   Rendezvous::Args recv_args_;
    119 
    120   mutable mutex mu_;
    121   Status status_ GUARDED_BY(mu_);
    122 
    123   TF_DISALLOW_COPY_AND_ASSIGN(GdrRecvTensorCall);
    124 };
    125 
    126 class GdrRemoteRendezvous : public BaseRemoteRendezvous {
    127  public:
    128   GdrRemoteRendezvous(const WorkerEnv* env, int64 step_id,
    129                       RemoteMemoryManager* remote_memory_manager)
    130       : BaseRemoteRendezvous(env, step_id),
    131         remote_memory_manager_(remote_memory_manager) {}
    132 
    133  protected:
    134   void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
    135                            const Rendezvous::Args& recv_args,
    136                            DoneCallback done) override {
    137     CHECK(is_initialized());
    138 
    139     string src_worker;
    140     string src_rel_device;
    141     if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker,
    142                                           &src_rel_device)) {
    143       Status s = errors::Internal(parsed.src_device,
    144                                   " is invalid remote source device.");
    145       done(s, Args(), recv_args, Tensor{}, false);
    146       return;
    147     }
    148 
    149     WorkerSession* sess = session();
    150     WorkerInterface* rwi = sess->worker_cache->CreateWorker(src_worker);
    151     if (rwi == nullptr) {
    152       Status s = errors::Internal("No worker known as ", src_worker);
    153       done(s, Args(), recv_args, Tensor{}, false);
    154       return;
    155     }
    156 
    157     Device* dst_device;
    158     Status s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
    159     if (!s.ok()) {
    160       sess->worker_cache->ReleaseWorker(src_worker, rwi);
    161       done(s, Args(), recv_args, Tensor{}, false);
    162       return;
    163     }
    164 
    165     // Prepare a RecvTensor call that can handle being aborted.
    166     GdrRecvTensorCall* call =
    167         new GdrRecvTensorCall(rwi, dst_device, remote_memory_manager_,
    168                               recv_args, step_id_, parsed.FullKey());
    169 
    170     // Record "call" in active_ so that it can be aborted cleanly.
    171     RegisterCall(call);
    172 
    173     // Start "call".
    174     Ref();
    175     call->Start([this, call, src_worker, rwi, done]() {
    176       // Removes "call" from active_. Prevent StartAbort().
    177       DeregisterCall(call);
    178       // If StartAbort was called prior to DeregisterCall, then the
    179       // current status should be bad.
    180       Status s = call->status();
    181       done(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
    182       session()->worker_cache->ReleaseWorker(src_worker, rwi);
    183       delete call;
    184       Unref();
    185     });
    186   }
    187 
    188  private:
    189   ~GdrRemoteRendezvous() override {}
    190 
    191   RemoteMemoryManager* remote_memory_manager_;
    192 
    193   TF_DISALLOW_COPY_AND_ASSIGN(GdrRemoteRendezvous);
    194 };
    195 
    196 }  // namespace
    197 
    198 GdrRendezvousMgr::GdrRendezvousMgr(const WorkerEnv* env,
    199                                    RemoteMemoryManager* remote_memory_manager)
    200     : BaseRendezvousMgr(env), remote_memory_manager_(remote_memory_manager) {}
    201 
    202 BaseRemoteRendezvous* GdrRendezvousMgr::Create(int64 step_id,
    203                                                const WorkerEnv* worker_env) {
    204   return new GdrRemoteRendezvous(worker_env, step_id, remote_memory_manager_);
    205 }
    206 
    207 }  // end namespace tensorflow
    208