/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_alg.h"

#include <stdlib.h>
#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

// Set true for greater intelligibility of debug mode log messages.
#define READABLE_KEYS false
// A ring algorithm exchanges chunks of tensor between devices. The chunk size
// depends on the number of subdivisions specified in the algorithm. If the
// user does not specify the number of subdivisions we may infer the number
// dynamically so that the resulting chunk size does not exceed
// kMaxChunkSizeBytes, empirically set at 4 MiB.
constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
// kMaxSubdivsPerDevice gives an upper bound on the number of subdivisions
// dynamically generated. A reasonable value would be a small multiple of the
// number of NICs adjacent to each device.
constexpr int kMaxSubdivsPerDevice = 2;

namespace tensorflow {
namespace {
// Each CollectiveOp implementation is free to define its own
// BufRendezvous key format. This function produces the key used by
// RingAlg instances. Note that the exec_key already differentiates
// between instances, so we don't need to further differentiate
// between subclasses of RingAlg.
string RingAlgBufKey(const string& name, const string& exec_key, int pass,
                     int section, int source_rank) {
  if (READABLE_KEYS) {
    return strings::StrCat(name, "(", exec_key, "):pass(", pass, "):section(",
                           section, "):srcrank(", source_rank, ")");
  } else {
    // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
    // hash.
    return strings::StrCat(exec_key, ":", pass, ":", section, ":",
                           source_rank);
  }
}
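
// Illustrative example (values assumed here, not taken from any particular
// run): with name "R", exec_key "k", pass 1, section 5 and source_rank 2,
// the readable form above is "R(k):pass(1):section(5):srcrank(2)" and the
// dense form is "k:1:5:2".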

}  // namespace

void RingAlg::PCQueue::Enqueue(RingField* rf) {
  mutex_lock l(pcq_mu_);
  deque_.push_back(rf);
  if (waiter_count_ > 0) {
    cv_.notify_one();
  }
}

RingAlg::RingField* RingAlg::PCQueue::Dequeue() {
  mutex_lock l(pcq_mu_);
  if (deque_.empty()) {
    ++waiter_count_;
    while (deque_.empty()) {
      cv_.wait(l);
    }
    --waiter_count_;
  }
  RingField* rf = deque_.front();
  deque_.pop_front();
  return rf;
}

RingAlg::RingAlg(CollectiveType type, const string& name)
    : type_(type),
      name_(name),
      col_ctx_(nullptr),
      col_params_(nullptr),
      done_(nullptr),
      group_size_(-1),
      num_subdivs_(-1) {}

namespace {
Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
  if (col_params->instance.shape.num_elements() == 0) {
    return errors::Internal("shape in CollectiveParams should be non-empty");
  }
  const int kAvgDevPerTask =
      col_params->group.group_size / col_params->group.num_tasks;
  const int kMaxNumSubdivs = kMaxSubdivsPerDevice * kAvgDevPerTask;
  if (kMaxNumSubdivs <= 0) {
    return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
                            " in ",
                            col_params->instance.impl_details.collective_name);
  }
  // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
  // as many offsets as needed so that the size of tensor chunks <=
  // kMaxChunkSizeBytes. Empirically, chunks that are too small or too large
  // lead to worse performance.
  int num_subdivs = 0;
  const size_t tensor_size = col_params->instance.shape.num_elements() *
                             DataTypeSize(col_params->instance.data_type);
  size_t chunk_size;
  do {
    ++num_subdivs;
    int num_chunks = col_params->group.group_size * num_subdivs;
    chunk_size = tensor_size / num_chunks;
    VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
            << " chunk_size " << chunk_size;
  } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
  if (num_subdivs <= 0) {
    return errors::Internal("Unexpected num_subdivs ", num_subdivs, " in ",
                            col_params->instance.impl_details.collective_name);
  }

  int subdiv_stride = kAvgDevPerTask / num_subdivs;
  if (subdiv_stride == 0) subdiv_stride = 1;
  col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
  for (int sdi = 0; sdi < num_subdivs; ++sdi) {
    int subdiv_offset = subdiv_stride * sdi;
    if (sdi % 2 == 1) subdiv_offset *= -1;
    col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
  }

  if (VLOG_IS_ON(2)) {
    string subdiv_buf;
    for (const int subdiv_offset :
         col_params->instance.impl_details.subdiv_offsets) {
      strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
    }
    VLOG(2) << "Dynamically generated " << num_subdivs
            << " subdiv_offsets:" << subdiv_buf << " tensor_size "
            << tensor_size << " chunk_size " << chunk_size;
  }

  return Status::OK();
}
}  // namespace
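
// Worked example (illustrative numbers, not from the original code): for a
// 64 MiB tensor (16M float elements) and group_size 8 spread over 2 tasks of
// 4 devices each, GenerateSubdivsInCollectiveParams tries num_subdivs = 1
// (chunk_size 8 MiB), then num_subdivs = 2 (chunk_size 4 MiB, which is
// <= kMaxChunkSizeBytes) and stops. kAvgDevPerTask is 4, so subdiv_stride is
// 2 and the generated subdiv_offsets are {0, -2}.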

Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
  const string& device_name =
      col_params->instance.device_names[col_params->default_rank];
  // Each subdiv permutation is a ring formed by rotating each
  // single-task subsequence of devices by an offset. This makes most
  // sense when each task has the same number of devices, but we can't
  // depend on that being the case, so we'll compute something that
  // works in any case.

  // Start by counting the devices in each task.
  // Precondition: device_names must be sorted so that all devices in
  // the same task are adjacent.
  VLOG(2) << "Sorted task names: "
          << str_util::Join(col_params->instance.task_names, ", ");
  std::vector<int> dev_per_task;
  const string* prior_task_name = &col_params->instance.task_names[0];
  int dev_count = 1;
  for (int di = 1; di < col_params->group.group_size; ++di) {
    if (col_params->instance.task_names[di] != *prior_task_name) {
      dev_per_task.push_back(dev_count);
      dev_count = 1;
      prior_task_name = &col_params->instance.task_names[di];
    } else {
      ++dev_count;
    }
  }
  dev_per_task.push_back(dev_count);
  DCHECK_EQ(col_params->group.num_tasks, dev_per_task.size());

  if (col_params->instance.impl_details.subdiv_offsets.empty()) {
    TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
  }

  // Generate a ring permutation for each requested offset.
  VLOG(2) << "Setting up perms for col_params " << col_params
          << " subdiv_permutations "
          << &col_params->instance.impl_details.subdiv_permutations;
  col_params->instance.impl_details.subdiv_permutations.resize(
      col_params->instance.impl_details.subdiv_offsets.size());
  col_params->subdiv_rank.resize(
      col_params->instance.impl_details.subdiv_offsets.size(), -1);
  for (int sdi = 0;
       sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
    std::vector<int>& perm =
        col_params->instance.impl_details.subdiv_permutations[sdi];
    DCHECK_EQ(perm.size(), 0);
    int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
    // A negative subdivision offset is interpreted as follows:
    //  1. Reverse the local device ordering.
    //  2. Begin the subdivision at abs(offset) in the reversed ordering.
    bool reverse = false;
    if (offset < 0) {
      offset = abs(offset);
      reverse = true;
    }
    int prior_dev_count = 0;  // sum over prior worker device counts
    for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
      for (int di = 0; di < dev_per_task[ti]; ++di) {
        int di_offset = (di + offset) % dev_per_task[ti];
        int offset_di =
            reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
        // Device index in global subdivision permutation.
        int permuted_di = prior_dev_count + offset_di;
        int rank = static_cast<int>(perm.size());
        perm.push_back(permuted_di);
        if (col_params->instance.device_names[permuted_di] == device_name) {
          DCHECK_EQ(permuted_di, col_params->default_rank);
          col_params->subdiv_rank[sdi] = rank;
        }
      }
      prior_dev_count += dev_per_task[ti];
    }
    DCHECK_EQ(col_params->group.group_size, perm.size());
  }
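
  // Illustrative example (assumed values, not from the original code): with
  // dev_per_task = {4, 4} and subdiv_offsets = {0, -2}, the loop above
  // produces
  //   subdiv 0 (offset  0): perm = {0, 1, 2, 3, 4, 5, 6, 7}
  //   subdiv 1 (offset -2): perm = {1, 0, 3, 2, 5, 4, 7, 6}
  // i.e. each task's device subsequence is rotated (and, for a negative
  // offset, reversed) independently, then the pieces are concatenated into
  // one global ring.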

  VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
  return Status::OK();
}

Status RingAlg::InitializeCollectiveContext(CollectiveContext* col_ctx) {
  DCHECK(col_ctx->dev_mgr);
  col_ctx_ = col_ctx;
  col_params_ = &col_ctx->col_params;
  return collective_util::InitializeDeviceAndLocality(
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
      &col_ctx->device_locality);
}

string RingAlg::TensorDebugString(const Tensor& tensor) {
  const DeviceBase::GpuDeviceInfo* gpu_device_info =
      col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
  if (gpu_device_info) {
    Tensor cpu_tensor(tensor.dtype(), tensor.shape());
    Notification note;
    gpu_device_info->default_context->CopyDeviceTensorToCPU(
        &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor,
        [&note](const Status& s) {
          DCHECK(s.ok());
          note.Notify();
        });
    note.WaitForNotification();
    return cpu_tensor.SummarizeValue(64);
  } else {
    return tensor.SummarizeValue(64);
  }
}

void RingAlg::StartAbort(const Status& s) {
  // In abort mode we stop issuing additional ProvideBuf
  // and ConsumeBuf calls, but we need to wait for all of the
  // outstanding callbacks to be invoked before quitting.
  bool abort_started = false;
  {
    mutex_lock l(status_mu_);
    if (status_.ok()) {
      LOG(ERROR) << "Aborting Ring" << name_ << " with " << s;
      abort_started = true;
      status_.Update(s);
    }
  }
  // If this is the initial entry to abort mode then invoke StartAbort
  // on the CollectiveExecutor that invoked us. That should start
  // cancellation on all of the outstanding CollectiveRemoteAccess
  // actions.
  if (abort_started) {
    col_ctx_->col_exec->StartAbort(s);
  }
}

void RingAlg::Finish(bool ok) {
  if (ok) {
    // Recover the output from the adaptor.
    ca_->ConsumeFinalValue(col_ctx_->output);
  }
  Status s;
  {
    mutex_lock l(status_mu_);
    s = status_;
  }
  rfv_.clear();  // Give up Refs on output tensor.
  done_(s);
}

// At the beginning of the algorithm initialize a RingField struct for
// every independent field of the tensor.
void RingAlg::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                            int field_idx) {
  // Note on field indexing: There are group_size_ devices in the
  // instance, implying the same number of chunks per tensor, where a
  // chunk is the unit of data transferred in a time step. However, if
  // a device can simultaneously send data over 2 or more independent
  // channels we can speed up the transfer by subdividing chunks and
  // processing multiple subdivisions at once. So the actual number
  // of RingFields is group_size_ * num_subdivs_.
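  //
  // Illustrative example (not from the original code): with group_size_ = 4
  // and num_subdivs_ = 2 there are 8 RingFields, and field_idx 5 corresponds
  // to chunk_idx 2, subdiv_idx 1 (5 == 2 * num_subdivs_ + 1).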
  DCHECK_EQ(field_idx, (chunk_idx * num_subdivs_) + subdiv_idx);
  rf->chunk_idx = chunk_idx;
  rf->subdiv_idx = subdiv_idx;
  rf->sc_idx = field_idx;
  rf->rank = col_params_->subdiv_rank[subdiv_idx];
  rf->second_pass = false;
  rf->action = RF_INIT;
  // Recv from the device with preceding rank within the subdivision.
  int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
  int send_to_rank = (rf->rank + 1) % group_size_;
  rf->recv_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][recv_from_rank];
  int send_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][send_to_rank];
  rf->recv_is_remote = !col_params_->task.is_local[rf->recv_dev_idx];
  rf->send_is_remote = !col_params_->task.is_local[send_dev_idx];
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 0 we skip Recv when rank = chunk_idx
    rf->do_recv = (rf->chunk_idx != rf->rank);
    // In pass 0 we skip Send when rank = chunk_idx-1
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  if (rf->do_send || rf->do_recv) {
    rf->chunk = ca_->ChunkAlias(rf->sc_idx);
  }
  VLOG(2) << this << " InitRingField " << rf->DebugString() << " chunk "
          << ca_->TBounds(rf->chunk);
}

// When a RingField transitions from the first pass to the second, recompute
// the do_send and do_recv values.
void RingAlg::AdvanceToSecondPass(RingField* rf) {
  VLOG(3) << "IncrRingField old value " << rf->DebugString();
  DCHECK(!rf->second_pass);
  rf->second_pass = true;
  rf->action = RF_INIT;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 1 the send/no-send boundary moves down 1 place.
    rf->do_recv =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  VLOG(3) << "IncrRingField new value " << rf->DebugString();
}
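
// Illustrative example (not from the original code): for group_size_ = 4 and
// a given subdivision, the rank holding chunk c skips the Recv in pass 0 when
// rank == c and skips the Send when rank == (c + 3) % 4; in pass 1 both
// boundaries shift down by one rank, so the skip conditions become
// rank == (c + 3) % 4 for Recv and rank == (c + 2) % 4 for Send.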

string RingAlg::RingField::DebugString() const {
  string rv = strings::StrCat("RingField rank=", rank, " chunk_idx=", chunk_idx,
                              " subdiv=", subdiv_idx, " sc_idx=", sc_idx,
                              " action=", action);
  strings::StrAppend(&rv, " pass=", second_pass);
  strings::StrAppend(&rv, " do_send=", do_send, " do_recv=", do_recv,
                     " is_final=", is_final, " recv_is_remote=", recv_is_remote,
                     " recv_dev_idx=", recv_dev_idx, " sc_idx=", sc_idx);
  return rv;
}

void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_send);
  string send_buf_key = RingAlgBufKey(name_, col_ctx_->exec_key,
                                      rf->second_pass, rf->sc_idx, rf->rank);
  VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
          << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
          << rf->sc_idx;
  int send_to_rank = (rf->rank + 1) % group_size_;
  int send_to_dev_idx = col_params_->instance.impl_details
                            .subdiv_permutations[rf->subdiv_idx][send_to_rank];
  col_ctx_->col_exec->PostToPeer(
      col_params_->instance.device_names[send_to_dev_idx],
      col_params_->instance.task_names[send_to_dev_idx], send_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
      col_ctx_->device_locality, done);
}

void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_recv);
  string recv_buf_key =
      RingAlgBufKey(name_, col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
                    (rf->rank + (group_size_ - 1)) % group_size_);
  VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
          << recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
          << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
  Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
                           ? &rf->tmp_chunk
                           : &rf->chunk;
  col_ctx_->col_exec->RecvFromPeer(
      col_params_->instance.device_names[rf->recv_dev_idx],
      col_params_->instance.task_names[rf->recv_dev_idx],
      col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
      col_ctx_->device_locality, rf->subdiv_idx, done);
}

string RingAlg::FieldState() {
  string s = strings::StrCat(
      "Ring", name_, " ", strings::Hex(reinterpret_cast<uint64>(this)),
      " exec ", col_ctx_->exec_key, " step_id=", col_ctx_->step_id,
      " state of all ", rfv_.size(), " fields:");
  for (int i = 0; i < rfv_.size(); ++i) {
    s.append("\n");
    s.append(rfv_[i].DebugString());
  }
  return s;
}

}  // namespace tensorflow