1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" 17 18 #include <unordered_set> 19 #include <vector> 20 21 #include "tensorflow/core/common_runtime/copy_tensor.h" 22 #include "tensorflow/core/common_runtime/device.h" 23 #include "tensorflow/core/common_runtime/device_mgr.h" 24 #include "tensorflow/core/common_runtime/dma_helper.h" 25 #include "tensorflow/core/common_runtime/process_util.h" 26 #include "tensorflow/core/distributed_runtime/worker_cache.h" 27 #include "tensorflow/core/distributed_runtime/worker_interface.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/lib/core/errors.h" 30 #include "tensorflow/core/lib/strings/numbers.h" 31 #include "tensorflow/core/lib/strings/str_util.h" 32 #include "tensorflow/core/platform/logging.h" 33 #include "tensorflow/core/platform/mutex.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace tensorflow { 37 38 static void StartAbortRendevous(Rendezvous* rendez, const Status& s) { 39 rendez->StartAbort(s); 40 rendez->Unref(); 41 } 42 43 BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env) 44 : worker_env_(worker_env) {} 45 46 BaseRendezvousMgr::~BaseRendezvousMgr() { 47 for (auto& p : table_) { 48 auto rendez = p.second; 49 StartAbortRendevous(rendez, errors::Aborted("Shutdown")); 50 } 51 } 52 53 RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) { 54 return FindOrCreate(step_id); 55 } 56 57 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) { 58 mutex_lock l(mu_); 59 auto iter = table_.find(step_id); 60 if (iter == table_.end()) { 61 auto rr = Create(step_id, worker_env_); 62 iter = table_.insert({step_id, rr}).first; 63 } 64 iter->second->Ref(); 65 return iter->second; 66 } 67 68 void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, 69 const Rendezvous::ParsedKey& parsed, 70 Rendezvous::DoneCallback done) { 71 auto rendez = FindOrCreate(step_id); 72 using namespace std::placeholders; 73 Rendezvous::DoneCallback done_cb = std::bind( 74 [rendez](Rendezvous::DoneCallback done, 75 // Begin unbound arguments. 76 const Status& s, const Rendezvous::Args& send_args, 77 const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { 78 rendez->Unref(); 79 done(s, send_args, recv_args, v, dead); 80 }, 81 std::move(done), _1, _2, _3, _4, _5); 82 rendez->RecvLocalAsync(parsed, std::move(done_cb)); 83 } 84 85 Status BaseRendezvousMgr::RecvLocal(int64 step_id, 86 const Rendezvous::ParsedKey& parsed, 87 Tensor* val, bool* is_dead) { 88 Status ret; 89 Notification n; 90 RecvLocalAsync(step_id, parsed, 91 [val, is_dead, &ret, &n](const Status& s, 92 const Rendezvous::Args& send_args, 93 const Rendezvous::Args& recv_args, 94 const Tensor& v, const bool dead) { 95 ret = s; 96 *val = v; 97 *is_dead = dead; 98 n.Notify(); 99 }); 100 n.WaitForNotification(); 101 return ret; 102 } 103 104 void BaseRendezvousMgr::Cleanup(int64 step_id) { 105 Rendezvous* rendez = nullptr; 106 { 107 mutex_lock l(mu_); 108 auto iter = table_.find(step_id); 109 if (iter != table_.end()) { 110 rendez = iter->second; 111 table_.erase(iter); 112 } 113 } 114 if (rendez) { 115 StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id)); 116 } 117 } 118 119 void BaseRendezvousMgr::CleanupAll() { 120 std::vector<Rendezvous*> rendezs; 121 { 122 mutex_lock l(mu_); 123 for (const auto& entry : table_) { 124 rendezs.push_back(entry.second); 125 } 126 table_.clear(); 127 } 128 for (auto rendez : rendezs) { 129 StartAbortRendevous(rendez, errors::Aborted("Shutdown")); 130 } 131 } 132 133 BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id) 134 : env_(env), 135 step_id_(step_id), 136 local_(NewLocalRendezvous()), 137 session_(nullptr) {} 138 139 BaseRemoteRendezvous::~BaseRemoteRendezvous() { 140 CHECK(active_.empty()); 141 local_->Unref(); 142 } 143 144 // Returns true if "device_name" is a valid full name of local device 145 // of the "worker". This helper is purely based on the worker name 146 // and device name and does no lookups in the worker->device_mgr. 147 static bool IsLocalDevice(const string& worker_name, 148 const StringPiece device_name) { 149 return device_name.starts_with(worker_name); 150 } 151 152 Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { 153 CHECK_NE(session, nullptr) << "session must not be null!"; 154 std::vector<DeferredCall> deferred_calls; 155 { 156 mutex_lock l(mu_); 157 if (session_ != nullptr) { 158 if (session_->worker_name == session->worker_name) { 159 LOG(INFO) << "Skipping rendezvous re-initialization."; 160 return Status::OK(); 161 } 162 Status s = errors::Internal( 163 "Double init! Worker names would have changed from: ", 164 session_->worker_name, " -> ", session->worker_name); 165 LOG(WARNING) << s; 166 return s; 167 } 168 session_ = session; 169 std::swap(deferred_calls, deferred_calls_); 170 } 171 for (auto& call : deferred_calls) { 172 RecvLocalAsyncInternal(call.parsed, std::move(call.done)); 173 } 174 return Status::OK(); 175 } 176 177 WorkerSession* BaseRemoteRendezvous::session() { 178 mutex_lock l(mu_); 179 return session_; 180 } 181 182 bool BaseRemoteRendezvous::is_initialized() { 183 mutex_lock l(mu_); 184 return is_initialized_locked(); 185 } 186 187 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, 188 const Rendezvous::Args& args, 189 const Tensor& val, const bool is_dead) { 190 VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey(); 191 { 192 mutex_lock l(mu_); 193 if (!status_.ok()) return status_; 194 DCHECK(is_initialized_locked()); 195 if (!IsLocalDevice(session_->worker_name, parsed.src_device)) { 196 return errors::InvalidArgument( 197 "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ", 198 session_->worker_name); 199 } 200 } 201 // Buffers "val" and "device_context" in local_. 202 return local_->Send(parsed, args, val, is_dead); 203 } 204 205 Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, 206 bool is_src) { 207 // Cache session pointer to avoid repeatedly taking & releasing the lock 208 // (e.g. calling session()) 209 WorkerSession* sess = nullptr; 210 { 211 mutex_lock l(mu_); 212 if (!status_.ok()) return status_; 213 if (!is_initialized_locked()) { 214 return errors::Internal("ValidateDevices called before initialization."); 215 } 216 sess = session_; 217 } 218 if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) { 219 return errors::InvalidArgument("Invalid rendezvous key (src): ", 220 parsed.FullKey(), " @ ", sess->worker_name); 221 } 222 if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) { 223 return errors::InvalidArgument("Invalid rendezvous key (dst): ", 224 parsed.FullKey(), " @ ", sess->worker_name); 225 } 226 return Status::OK(); 227 } 228 229 void BaseRemoteRendezvous::SameWorkerRecvDone( 230 const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args, 231 const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out, 232 StatusCallback done) { 233 // Do a quick copy (sharing the underlying buffer) if both tensors 234 // are on host memory. 235 const bool src_host = 236 (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU"); 237 const bool dst_host = 238 (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU"); 239 if (src_host && dst_host) { 240 *out = in; 241 done(Status::OK()); 242 return; 243 } 244 245 // This copy must involve a GPU. Hence, "in" must support DMA 246 // (e.g., string tensors do not work on GPU). Variant copy DMA 247 // checks happen inside CopyTensor::ViaDMA. 248 if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT) { 249 done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()), 250 " tensor may not be copied from/to a GPU.")); 251 return; 252 } 253 254 WorkerSession* sess = session(); 255 Device* src_device; 256 Status s = sess->device_mgr->LookupDevice(parsed.src_device, &src_device); 257 if (!s.ok()) { 258 done(s); 259 return; 260 } 261 Device* dst_device; 262 s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device); 263 if (!s.ok()) { 264 done(s); 265 return; 266 } 267 268 AllocatorAttributes attr = recv_args.alloc_attrs; 269 attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() || 270 recv_args.alloc_attrs.gpu_compatible()); 271 Allocator* out_allocator = dst_device->GetAllocator(attr); 272 273 if (in.dtype() != DT_VARIANT) { 274 // Variants are handled by CopyTensor::ViaDMA. 275 Tensor copy(out_allocator, in.dtype(), in.shape()); 276 *out = copy; 277 } 278 279 // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies, 280 // etc. 281 CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context, 282 recv_args.device_context, src_device, dst_device, 283 send_args.alloc_attrs, recv_args.alloc_attrs, &in, out, 284 std::move(done)); 285 } 286 287 bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src, 288 DeviceNameUtils::ParsedName dst) { 289 return DeviceNameUtils::IsSameAddressSpace(src, dst); 290 } 291 292 void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, 293 const Rendezvous::Args& recv_args, 294 DoneCallback done) { 295 VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey(); 296 CHECK(is_initialized()) << "RecvAsync called when uninitialized."; 297 Status s = ValidateDevices(parsed, false /*!is_src*/); 298 if (!s.ok()) { 299 done(s, Args(), recv_args, Tensor(), false); 300 return; 301 } 302 303 // Are src and dst in the same worker? 304 if (IsSameWorker(parsed.src, parsed.dst)) { 305 // Recv the tensor from local_. 306 local_->RecvAsync( 307 parsed, recv_args, 308 [this, parsed, done]( 309 const Status& status, const Rendezvous::Args& send_args, 310 const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) { 311 Tensor* out = new Tensor; 312 StatusCallback final_callback = [done, send_args, recv_args, out, 313 is_dead](const Status& s) { 314 done(s, send_args, recv_args, *out, is_dead); 315 delete out; 316 }; 317 318 if (status.ok()) { 319 SameWorkerRecvDone(parsed, send_args, recv_args, in, out, 320 std::move(final_callback)); 321 } else { 322 final_callback(status); 323 } 324 }); 325 return; 326 } else { 327 RecvFromRemoteAsync(parsed, recv_args, std::move(done)); 328 } 329 } 330 331 void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, 332 DoneCallback done) { 333 { 334 mutex_lock l(mu_); 335 if (!is_initialized_locked()) { 336 // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a 337 // remote worker) before the RunStep (or PartialRunStep) RPC from the 338 // master arrives. RecvLocalAsync thus buffers the arguments until after 339 // the RemoteRendezvous is Initialize()'d, when it completes the 340 // rendezvous logic. At some point after Initialize() is called, a Tensor 341 // is produced locally that will then be sent in response to the incoming 342 // RPC. 343 DeferredCall call(parsed, std::move(done)); 344 deferred_calls_.push_back(call); 345 return; 346 } 347 } 348 RecvLocalAsyncInternal(parsed, std::move(done)); 349 } 350 351 void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed, 352 DoneCallback done) { 353 Status s = ValidateDevices(parsed, true /* is_src */); 354 if (!s.ok()) { 355 done(s, Args(), Args(), Tensor(), false); 356 return; 357 } 358 local_->RecvAsync(parsed, Args(), std::move(done)); 359 } 360 361 void BaseRemoteRendezvous::StartAbort(const Status& s) { 362 CHECK(!s.ok()); 363 local_->StartAbort(s); 364 { 365 // Aborts all active RecvTensor calls. 366 mutex_lock l(mu_); 367 if (status_.ok()) { 368 status_ = s; 369 for (BaseRecvTensorCall* call : active_) { 370 call->StartAbort(s); 371 } 372 active_.clear(); 373 } 374 } 375 } 376 377 void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call) { 378 mutex_lock l(mu_); 379 if (!status_.ok()) { 380 call->StartAbort(status_); 381 } else { 382 CHECK(active_.insert(call).second); 383 } 384 } 385 386 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) { 387 mutex_lock l(mu_); 388 active_.erase(call); 389 } 390 391 BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed, 392 DoneCallback done) 393 : parsed(parsed), done(std::move(done)) {} 394 395 } // end namespace tensorflow 396