1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h" 17 18 #include "google/protobuf/any.pb.h" 19 #include "tensorflow/contrib/gdr/gdr_memory_manager.h" 20 #include "tensorflow/core/common_runtime/device.h" 21 #include "tensorflow/core/common_runtime/device_mgr.h" 22 #include "tensorflow/core/common_runtime/process_util.h" 23 #include "tensorflow/core/distributed_runtime/request_id.h" 24 #include "tensorflow/core/distributed_runtime/tensor_coding.h" 25 #include "tensorflow/core/distributed_runtime/worker_cache.h" 26 #include "tensorflow/core/distributed_runtime/worker_interface.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 #include "tensorflow/core/lib/strings/numbers.h" 30 #include "tensorflow/core/lib/strings/str_util.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/macros.h" 33 #include "tensorflow/core/platform/types.h" 34 35 namespace tensorflow { 36 37 namespace { 38 39 class GdrRecvTensorCall : public BaseRecvTensorCall { 40 public: 41 GdrRecvTensorCall(WorkerInterface* wi, Device* dst_device, 42 RemoteMemoryManager* remote_memory_manager, 43 const Rendezvous::Args& recv_args, int64 step_id, 44 StringPiece key) 45 : wi_(wi), 46 dst_device_(dst_device), 47 remote_memory_manager_(remote_memory_manager), 48 recv_args_(recv_args) { 49 req_.set_step_id(step_id); 50 req_.set_rendezvous_key(key.data(), key.size()); 51 req_.set_request_id(GetUniqueRequestId()); 52 } 53 54 ~GdrRecvTensorCall() override {} 55 56 void Start(std::function<void()> recv_done) override { 57 req_.set_dma_ok(true); 58 resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs); 59 StatusCallback cb = [this, recv_done](const Status& s) { 60 bool dma_ok = resp_.metadata().has_transport_options(); 61 if (s.ok() && tensor().TotalBytes() > 0 && (!is_dead()) && dma_ok) { 62 auto transport_options = resp_.metadata().transport_options(); 63 const bool on_host = 64 (dst_device_->tensorflow_gpu_device_info() == nullptr) || 65 recv_args_.alloc_attrs.on_host(); 66 remote_memory_manager_->TensorFromTransportOptions( 67 const_cast<Tensor*>(&tensor()), transport_options, dst_device_, 68 recv_args_.device_context, on_host, 69 [this, recv_done](const Status& s) { 70 if (!s.ok()) { 71 mutex_lock l(mu_); 72 status_.Update(s); 73 LOG(ERROR) << "Cannot find pinned memory region from allocator " 74 << dst_device_->GetAllocator(recv_args_.alloc_attrs) 75 ->Name(); 76 } 77 recv_done(); 78 }); 79 return; 80 } 81 if (!s.ok()) { 82 mutex_lock l(mu_); 83 status_.Update(s); 84 } 85 recv_done(); 86 }; 87 wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb)); 88 } 89 90 void StartAbort(const Status& s) override { 91 { 92 mutex_lock l(mu_); 93 status_.Update(s); 94 } 95 opts_.StartCancel(); 96 } 97 98 Status status() const override { 99 mutex_lock l(mu_); 100 return status_; 101 } 102 103 const Tensor& tensor() const { return resp_.tensor(); } 104 105 bool is_dead() const { return resp_.metadata().is_dead(); } 106 107 Device* dst_device() const { return dst_device_; } 108 109 const Rendezvous::Args& recv_args() const { return recv_args_; } 110 111 private: 112 WorkerInterface* wi_; 113 Device* dst_device_; 114 RemoteMemoryManager* remote_memory_manager_; 115 CallOptions opts_; 116 RecvTensorRequest req_; 117 TensorResponse resp_; 118 Rendezvous::Args recv_args_; 119 120 mutable mutex mu_; 121 Status status_ GUARDED_BY(mu_); 122 123 TF_DISALLOW_COPY_AND_ASSIGN(GdrRecvTensorCall); 124 }; 125 126 class GdrRemoteRendezvous : public BaseRemoteRendezvous { 127 public: 128 GdrRemoteRendezvous(const WorkerEnv* env, int64 step_id, 129 RemoteMemoryManager* remote_memory_manager) 130 : BaseRemoteRendezvous(env, step_id), 131 remote_memory_manager_(remote_memory_manager) {} 132 133 protected: 134 void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, 135 const Rendezvous::Args& recv_args, 136 DoneCallback done) override { 137 CHECK(is_initialized()); 138 139 string src_worker; 140 string src_rel_device; 141 if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker, 142 &src_rel_device)) { 143 Status s = errors::Internal(parsed.src_device, 144 " is invalid remote source device."); 145 done(s, Args(), recv_args, Tensor{}, false); 146 return; 147 } 148 149 WorkerSession* sess = session(); 150 WorkerInterface* rwi = sess->worker_cache->CreateWorker(src_worker); 151 if (rwi == nullptr) { 152 Status s = errors::Internal("No worker known as ", src_worker); 153 done(s, Args(), recv_args, Tensor{}, false); 154 return; 155 } 156 157 Device* dst_device; 158 Status s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device); 159 if (!s.ok()) { 160 sess->worker_cache->ReleaseWorker(src_worker, rwi); 161 done(s, Args(), recv_args, Tensor{}, false); 162 return; 163 } 164 165 // Prepare a RecvTensor call that can handle being aborted. 166 GdrRecvTensorCall* call = 167 new GdrRecvTensorCall(rwi, dst_device, remote_memory_manager_, 168 recv_args, step_id_, parsed.FullKey()); 169 170 // Record "call" in active_ so that it can be aborted cleanly. 171 RegisterCall(call); 172 173 // Start "call". 174 Ref(); 175 call->Start([this, call, src_worker, rwi, done]() { 176 // Removes "call" from active_. Prevent StartAbort(). 177 DeregisterCall(call); 178 // If StartAbort was called prior to DeregisterCall, then the 179 // current status should be bad. 180 Status s = call->status(); 181 done(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); 182 session()->worker_cache->ReleaseWorker(src_worker, rwi); 183 delete call; 184 Unref(); 185 }); 186 } 187 188 private: 189 ~GdrRemoteRendezvous() override {} 190 191 RemoteMemoryManager* remote_memory_manager_; 192 193 TF_DISALLOW_COPY_AND_ASSIGN(GdrRemoteRendezvous); 194 }; 195 196 } // namespace 197 198 GdrRendezvousMgr::GdrRendezvousMgr(const WorkerEnv* env, 199 RemoteMemoryManager* remote_memory_manager) 200 : BaseRendezvousMgr(env), remote_memory_manager_(remote_memory_manager) {} 201 202 BaseRemoteRendezvous* GdrRendezvousMgr::Create(int64 step_id, 203 const WorkerEnv* worker_env) { 204 return new GdrRemoteRendezvous(worker_env, step_id, remote_memory_manager_); 205 } 206 207 } // end namespace tensorflow 208