/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_alg.h"

#include <stdlib.h>
#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

// Set true for greater intelligibility of debug mode log messages.
#define READABLE_KEYS false
// A ring algorithm exchanges chunks of tensor between devices.  The chunk size
// depends on the number of subdivisions specified in the algorithm.  If the
// user does not specify the number of subdivisions we may infer the number
// dynamically so that the resulting chunk size does not exceed
// kMaxChunkSizeBytes, empirically set at 4 MiB.
constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
// kMaxSubdivsPerDevice gives an upper bound on the number of subdivisions
// generated dynamically.  A reasonable value would be a small multiple of
// the number of NICs adjacent to each device.
constexpr int kMaxSubdivsPerDevice = 2;

namespace tensorflow {
namespace {
// Each CollectiveOp implementation is free to define its own
// BufRendezvous key format.  This function produces the key used by
// RingAlg instances.  Note that the exec_key already differentiates
// between instances, so we don't need to further differentiate between
// subclasses of RingAlg.
string RingAlgBufKey(const string& name, const string& exec_key, int pass,
                     int section, int source_rank) {
  if (READABLE_KEYS) {
    return strings::StrCat(name, "(", exec_key, "):pass(", pass, "):section(",
                           section, "):srcrank(", source_rank, ")");
  } else {
    // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
    // hash.
    return strings::StrCat(exec_key, ":", pass, ":", section, ":", source_rank);
  }
}
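
// For illustration (hypothetical arguments): with READABLE_KEYS set to true,
// a call like RingAlgBufKey("RingReduce", "17", 0, 3, 2) would yield
// "RingReduce(17):pass(0):section(3):srcrank(2)"; with the default
// READABLE_KEYS false the same call yields the denser "17:0:3:2".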

}  // namespace

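// PCQueue is a minimal producer/consumer queue of RingField pointers:
// Enqueue never blocks, and Dequeue blocks until an item is available.
// The loop around cv_.wait() guards against spurious wakeups, and
// waiter_count_ lets Enqueue skip the notify when no thread is waiting.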
void RingAlg::PCQueue::Enqueue(RingField* rf) {
  mutex_lock l(pcq_mu_);
  deque_.push_back(rf);
  if (waiter_count_ > 0) {
    cv_.notify_one();
  }
}

RingAlg::RingField* RingAlg::PCQueue::Dequeue() {
  mutex_lock l(pcq_mu_);
  if (deque_.empty()) {
    ++waiter_count_;
    while (deque_.empty()) {
      cv_.wait(l);
    }
    --waiter_count_;
  }
  RingField* rf = deque_.front();
  deque_.pop_front();
  return rf;
}

RingAlg::RingAlg(CollectiveType type, const string& name)
    : type_(type),
      name_(name),
      col_ctx_(nullptr),
      col_params_(nullptr),
      done_(nullptr),
      group_size_(-1),
      num_subdivs_(-1) {}

namespace {
Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
  if (col_params->instance.shape.num_elements() == 0) {
    return errors::Internal("shape in CollectiveParams should be non-empty");
  }
  const int kAvgDevPerTask =
      col_params->group.group_size / col_params->group.num_tasks;
  const int kMaxNumSubdivs = kMaxSubdivsPerDevice * kAvgDevPerTask;
  if (kMaxNumSubdivs <= 0) {
    return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
                            " in ",
                            col_params->instance.impl_details.collective_name);
  }
  // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
  // as many offsets as needed so that the size of tensor chunks <=
  // kMaxChunkSizeBytes.  Empirically, chunks that are too small or too large
  // lead to worse performance.
  int num_subdivs = 0;
  const size_t tensor_size = col_params->instance.shape.num_elements() *
                             DataTypeSize(col_params->instance.data_type);
  size_t chunk_size;
  do {
    ++num_subdivs;
    int num_chunks = col_params->group.group_size * num_subdivs;
    chunk_size = tensor_size / num_chunks;
    VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
            << " chunk_size " << chunk_size;
  } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
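  // A worked trace under assumed values: with tensor_size = 64 MiB and
  // group_size = 8, the loop first tries num_subdivs = 1 (chunk_size = 8
  // MiB, still too large), then num_subdivs = 2 (chunk_size = 4 MiB <=
  // kMaxChunkSizeBytes) and stops, provided kMaxNumSubdivs allows it.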
  if (num_subdivs <= 0) {
    return errors::Internal("Unexpected num_subdivs ", num_subdivs, " in ",
                            col_params->instance.impl_details.collective_name);
  }

  int subdiv_stride = kAvgDevPerTask / num_subdivs;
  if (subdiv_stride == 0) subdiv_stride = 1;
  col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
  for (int sdi = 0; sdi < num_subdivs; ++sdi) {
    int subdiv_offset = subdiv_stride * sdi;
    if (sdi % 2 == 1) subdiv_offset *= -1;
    col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
  }
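  // For example (assumed values), with num_subdivs = 2 and subdiv_stride = 4
  // this loop generates the offsets {0, -4}: even-numbered subdivisions get
  // a positive offset and odd-numbered ones its negation, which reverses the
  // ring direction (see the negative-offset handling in
  // InitializeCollectiveParams).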

  if (VLOG_IS_ON(2)) {
    string subdiv_buf;
    for (const int subdiv_offset :
         col_params->instance.impl_details.subdiv_offsets) {
      strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
    }
    VLOG(2) << "Dynamically generated " << num_subdivs
            << " subdiv_offsets:" << subdiv_buf << " tensor_size "
            << tensor_size << " chunk_size " << chunk_size;
  }

  return Status::OK();
}
}  // namespace

Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
  const string& device_name =
      col_params->instance.device_names[col_params->default_rank];
  // Each subdiv permutation is a ring formed by rotating each
  // single-task subsequence of devices by an offset.  This makes the
  // most sense when each task has the same number of devices, but we
  // can't depend on that being the case, so we compute a permutation
  // that works in any case.

  // Start by counting the devices in each task.
  // Precondition: device_names must be sorted so that all devices in
  // the same task are adjacent.
  VLOG(2) << "Sorted task names: "
          << str_util::Join(col_params->instance.task_names, ", ");
  std::vector<int> dev_per_task;
  const string* prior_task_name = &col_params->instance.task_names[0];
  int dev_count = 1;
  for (int di = 1; di < col_params->group.group_size; ++di) {
    if (col_params->instance.task_names[di] != *prior_task_name) {
      dev_per_task.push_back(dev_count);
      dev_count = 1;
      prior_task_name = &col_params->instance.task_names[di];
    } else {
      ++dev_count;
    }
  }
  dev_per_task.push_back(dev_count);
  DCHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
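  // For example (assumed values), with group_size = 8 and task_names sorted
  // as four devices on task 0 followed by four on task 1, dev_per_task
  // becomes {4, 4}.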

  if (col_params->instance.impl_details.subdiv_offsets.empty()) {
    TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
  }

  // Generate a ring permutation for each requested offset.
  VLOG(2) << "Setting up perms for col_params " << col_params
          << " subdiv_permutations "
          << &col_params->instance.impl_details.subdiv_permutations;
  col_params->instance.impl_details.subdiv_permutations.resize(
      col_params->instance.impl_details.subdiv_offsets.size());
  col_params->subdiv_rank.resize(
      col_params->instance.impl_details.subdiv_offsets.size(), -1);
  for (int sdi = 0;
       sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
    std::vector<int>& perm =
        col_params->instance.impl_details.subdiv_permutations[sdi];
    DCHECK_EQ(perm.size(), 0);
    int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
    // A negative subdivision offset is interpreted as follows:
    //  1. Reverse the local device ordering.
    //  2. Begin the subdivision at abs(offset) in the reversed ordering.
    bool reverse = false;
    if (offset < 0) {
      offset = abs(offset);
      reverse = true;
    }
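    // For example (assumed values), with 4 devices in a task and offset = -1
    // the reversed local ordering is {3, 2, 1, 0}; beginning at index 1 of
    // that ordering yields the local order {2, 1, 0, 3}.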
    int prior_dev_count = 0;  // sum over prior worker device counts
    for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
      for (int di = 0; di < dev_per_task[ti]; ++di) {
        int di_offset = (di + offset) % dev_per_task[ti];
        int offset_di =
            reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
        // Device index in global subdivision permutation.
        int permuted_di = prior_dev_count + offset_di;
        int rank = static_cast<int>(perm.size());
        perm.push_back(permuted_di);
        if (col_params->instance.device_names[permuted_di] == device_name) {
          DCHECK_EQ(permuted_di, col_params->default_rank);
          col_params->subdiv_rank[sdi] = rank;
        }
      }
      prior_dev_count += dev_per_task[ti];
    }
    DCHECK_EQ(col_params->group.group_size, perm.size());
  }

  VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
  return Status::OK();
}

Status RingAlg::InitializeCollectiveContext(CollectiveContext* col_ctx) {
  DCHECK(col_ctx->dev_mgr);
  col_ctx_ = col_ctx;
  col_params_ = &col_ctx->col_params;
  return collective_util::InitializeDeviceAndLocality(
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
      &col_ctx->device_locality);
}

string RingAlg::TensorDebugString(const Tensor& tensor) {
  const DeviceBase::GpuDeviceInfo* gpu_device_info =
      col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
  if (gpu_device_info) {
    Tensor cpu_tensor(tensor.dtype(), tensor.shape());
    Notification note;
    gpu_device_info->default_context->CopyDeviceTensorToCPU(
        &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor,
        [&note](const Status& s) {
          DCHECK(s.ok());
          note.Notify();
        });
    note.WaitForNotification();
    return cpu_tensor.SummarizeValue(64);
  } else {
    return tensor.SummarizeValue(64);
  }
}

void RingAlg::StartAbort(const Status& s) {
  // In abort mode we stop issuing additional ProvideBuf
  // and ConsumeBuf calls, but we need to wait for all of the
  // outstanding callbacks to be invoked before quitting.
  bool abort_started = false;
  {
    mutex_lock l(status_mu_);
    if (status_.ok()) {
      LOG(ERROR) << "Aborting Ring" << name_ << " with " << s;
      abort_started = true;
      status_.Update(s);
    }
  }
  // If this is the initial entry to abort mode then invoke StartAbort
  // on the CollectiveExecutor that invoked us.  That should start
  // cancellation on all of the outstanding CollectiveRemoteAccess
  // actions.
  if (abort_started) {
    col_ctx_->col_exec->StartAbort(s);
  }
}

void RingAlg::Finish(bool ok) {
  if (ok) {
    // Recover the output from the adaptor.
    ca_->ConsumeFinalValue(col_ctx_->output);
  }
  Status s;
  {
    mutex_lock l(status_mu_);
    s = status_;
  }
  rfv_.clear();  // Give up Refs on output tensor.
  done_(s);
}

// At the beginning of the algorithm initialize a RingField struct for
// every independent field of the tensor.
void RingAlg::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                            int field_idx) {
  // Note on field indexing: There are group_size_ devices in the
  // instance, implying the same number of chunks per tensor, where a
  // chunk is the unit of data transferred in a time step.  However, if
  // a device can simultaneously send data by 2 or more independent
  // channels we can speed up the transfer by subdividing chunks and
  // processing multiple subdivisions at once.  So the actual number
  // of RingFields is group_size_ * num_subdivs_.
  DCHECK_EQ(field_idx, (chunk_idx * num_subdivs_) + subdiv_idx);
  rf->chunk_idx = chunk_idx;
  rf->subdiv_idx = subdiv_idx;
  rf->sc_idx = field_idx;
  rf->rank = col_params_->subdiv_rank[subdiv_idx];
  rf->second_pass = false;
  rf->action = RF_INIT;
  // Recv from the device with preceding rank within the subdivision.
  int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
  int send_to_rank = (rf->rank + 1) % group_size_;
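  // For example, with group_size_ = 4 a field at rank 1 receives from rank 0
  // and sends to rank 2, while rank 0 receives from rank 3, closing the ring.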
  rf->recv_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][recv_from_rank];
  int send_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][send_to_rank];
  rf->recv_is_remote = !col_params_->task.is_local[rf->recv_dev_idx];
  rf->send_is_remote = !col_params_->task.is_local[send_dev_idx];
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 0 we skip Recv when rank = chunk_idx
    rf->do_recv = (rf->chunk_idx != rf->rank);
    // In pass 0 we skip Send when rank = chunk_idx-1
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  if (rf->do_send || rf->do_recv) {
    rf->chunk = ca_->ChunkAlias(rf->sc_idx);
  }
  VLOG(2) << this << " InitRingField " << rf->DebugString() << " chunk "
          << ca_->TBounds(rf->chunk);
}

// When a RingField transitions from the first to the second pass,
// recompute the do_send and do_recv values.
void RingAlg::AdvanceToSecondPass(RingField* rf) {
  VLOG(3) << "IncrRingField old value " << rf->DebugString();
  DCHECK(!rf->second_pass);
  rf->second_pass = true;
  rf->action = RF_INIT;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 1 the send/no-send boundary moves down 1 place.
    rf->do_recv =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  VLOG(3) << "IncrRingField new value " << rf->DebugString();
}
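
// A small example of the two passes (assumed values): with group_size_ = 4
// and chunk_idx = 2, pass 0 skips the Recv at rank 2 and the Send at rank 1;
// in pass 1 the boundary moves down one place, so rank 1 skips the Recv,
// rank 0 skips the Send, and the chunk is final at rank 0.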

string RingAlg::RingField::DebugString() const {
  string rv = strings::StrCat("RingField rank=", rank, " chunk_idx=", chunk_idx,
                              " subdiv=", subdiv_idx, " sc_idx=", sc_idx,
                              " action=", action);
  strings::StrAppend(&rv, " pass=", second_pass);
  strings::StrAppend(&rv, " do_send=", do_send, " do_recv=", do_recv,
                     " is_final=", is_final, " recv_is_remote=", recv_is_remote,
                     " recv_dev_idx=", recv_dev_idx);
  return rv;
}

void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_send);
  string send_buf_key = RingAlgBufKey(name_, col_ctx_->exec_key,
                                      rf->second_pass, rf->sc_idx, rf->rank);
  VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
          << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
          << rf->sc_idx;
  int send_to_rank = (rf->rank + 1) % group_size_;
  int send_to_dev_idx = col_params_->instance.impl_details
                            .subdiv_permutations[rf->subdiv_idx][send_to_rank];
  col_ctx_->col_exec->PostToPeer(
      col_params_->instance.device_names[send_to_dev_idx],
      col_params_->instance.task_names[send_to_dev_idx], send_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
      col_ctx_->device_locality, done);
}

void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_recv);
  string recv_buf_key =
      RingAlgBufKey(name_, col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
                    (rf->rank + (group_size_ - 1)) % group_size_);
  VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
          << recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
          << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
  Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
                           ? &rf->tmp_chunk
                           : &rf->chunk;
  col_ctx_->col_exec->RecvFromPeer(
      col_params_->instance.device_names[rf->recv_dev_idx],
      col_params_->instance.task_names[rf->recv_dev_idx],
      col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
      col_ctx_->device_locality, rf->subdiv_idx, done);
}
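
// Note the key symmetry between the dispatch routines: DispatchRecv builds
// its recv_buf_key with source_rank = (rf->rank + group_size_ - 1) %
// group_size_, which is exactly the rank the upstream neighbor passes as
// rf->rank when it builds the matching send_buf_key in DispatchSend.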

string RingAlg::FieldState() {
  string s = strings::StrCat(
      "Ring", name_, " ", strings::Hex(reinterpret_cast<uint64>(this)),
      " exec ", col_ctx_->exec_key, " step_id=", col_ctx_->step_id,
      " state of all ", rfv_.size(), " fields:");
  for (int i = 0; i < rfv_.size(); ++i) {
    s.append("\n");
    s.append(rfv_[i].DebugString());
  }
  return s;
}

}  // namespace tensorflow