/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/nccl/nccl_manager.h"

#include <utility>

#ifdef GOOGLE_CUDA

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cuda.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

#define NCCL_RETURN_IF_ERROR(...)                               \
  do {                                                          \
    ncclResult_t nccl_status = (__VA_ARGS__);                   \
    if (nccl_status != ncclSuccess) {                           \
      return errors::Internal(ncclGetErrorString(nccl_status)); \
    }                                                           \
  } while (0)

#define CUDA_RETURN_IF_ERROR(...)                               \
  do {                                                          \
    cudaError_t cuda_status = (__VA_ARGS__);                    \
    if (cuda_status != cudaSuccess) {                           \
      return errors::Internal(cudaGetErrorString(cuda_status)); \
    }                                                           \
  } while (0)
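
// Illustrative use of the macros above (a hypothetical helper, shown only as
// a sketch): any function returning Status can wrap NCCL/CUDA calls, e.g.
//
//   Status SyncStream(cudaStream_t stream) {
//     CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
//     return Status::OK();
//   }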

using se::cuda::ScopedActivateExecutorContext;

// Contains data for a single stream used for nccl communication; this includes
// a background thread that calls NcclManager::LoopKernelLaunches.
struct NcclManager::NcclStream {
 public:
  NcclStream() {}
  ~NcclStream() {
    mutex_lock l(mu);
    shutdown_requested = true;
    cv.notify_all();
  }

  se::StreamExecutor* executor = nullptr;

  // The stream on which to run the nccl collective.
  // This is a different stream than the tensorflow compute stream.
  std::unique_ptr<se::Stream> stream;

  // See NcclManager::LoopKernelLaunches for information on these.
  std::unique_ptr<Thread> thread;
  mutex mu;
  condition_variable cv;
  // Holds (collective, participant_idx) pairs.
  std::deque<std::pair<Collective*, int>> pending_launches_ GUARDED_BY(mu);
  bool shutdown_requested GUARDED_BY(mu) = false;
};

struct NcclManager::CommunicatorMember {
 public:
  CommunicatorMember() {}
  ~CommunicatorMember() {
    if (nccl_comm != nullptr) ncclCommDestroy(nccl_comm);
  }
  ncclComm_t nccl_comm = nullptr;

  // Owned by NcclManager::device_to_comm_streams_.
  NcclStream* nccl_stream = nullptr;
};

struct NcclManager::Communicator {
 public:
  explicit Communicator(std::vector<CommunicatorMember> members,
                        const string& key)
      : num_devices(members.size()), members(std::move(members)), key(key) {}

  const int num_devices;
  const std::vector<CommunicatorMember> members;
  const string key;
};

namespace {

ncclDataType_t ToNcclType(DataType t) {
  switch (t) {
    case DT_HALF:
      return ncclHalf;
    case DT_FLOAT:
      return ncclFloat;
    case DT_DOUBLE:
      return ncclDouble;
    case DT_INT32:
      return ncclInt;
    case DT_INT64:
      return ncclInt64;
    default:
      return ncclFloat;
  }
}
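
// Note: unsupported DataTypes silently fall back to ncclFloat above; callers
// are expected to pass only the types handled explicitly.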

void StringToNcclUniqueId(const string& str_id, ncclUniqueId* nccl_id) {
  if (str_id.size() == NCCL_UNIQUE_ID_BYTES) {
    memcpy(nccl_id->internal, str_id.data(), NCCL_UNIQUE_ID_BYTES);
  }
}

}  // namespace

// A `Collective` encapsulates state for a collective instance at one node.
// Typically, an instance in TensorFlow context would be defined by a collective
// group and the (step, frame iteration) for that execution.
//
// For each collective instance there will be one `Collective` object per node.
// For example, a NCCL collective that runs on a single node with 4 GPUs would
// have a single `Collective` per step.  However, a collective that executes on
// 3 nodes with 4 GPUs each would have a `Collective` per node, each of which is
// tracking the 4 GPUs local to that node.
struct NcclManager::Collective {
  Collective(DataType data_type_in, CollectiveType type_in,
             ncclRedOp_t reduction_op_in, int num_local_devices_in,
             int num_global_devices_in, const string& communicator_key_in)
      : data_type(data_type_in),
        type(type_in),
        reduction_op(reduction_op_in),
        num_local_devices(num_local_devices_in),
        num_global_devices(num_global_devices_in),
        single_node(num_local_devices_in == num_global_devices_in),
        communicator_key(communicator_key_in),
        remaining_participants(num_local_devices_in) {
    participants.reserve(num_local_devices_in);
  }

  const DataType data_type;
  const CollectiveType type;
  const ncclRedOp_t reduction_op;  // applies when <type> is a reduction.
  const int num_local_devices;     // devices local to this node
  const int num_global_devices;    // devices across all nodes
  const bool single_node;          // true if all devices are at one node
  const string communicator_key;

  Communicator* communicator = nullptr;

  // All collective participants.
  //
  // Adding values in this vector is guarded by the mutex of the containing
  // NcclManager.
  std::vector<std::unique_ptr<Participant>> participants;

  // For collective types that have a root (e.g. the root of broadcast is the
  // sender), this is the rank of the root.
  int root_rank = -1;

  // How many participants have been registered so far. The Collective is
  // eligible for running with <available_participants> == num_local_devices.
  //
  // If this is a multi-node collective, we additionally have to synchronize
  // across nodes.  The caller would need to signal multi node readiness by
  // calling NcclManager::SignalMultiNodeReady, which sets `multi_node_ready` to
  // true.
  //
  // Guarded by the mutex of the containing Communicator.
  int available_participants = 0;
  bool multi_node_ready = false;

  mutable std::atomic_int_fast32_t remaining_participants;

  Status status;
};
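
// Example (illustrative, based on the comment above): an all-reduce over one
// node with 4 GPUs produces a single Collective per step with
// num_local_devices == num_global_devices == 4, so single_node is true and no
// communicator_key is needed.  The same all-reduce over 3 nodes with 4 GPUs
// each produces one Collective per node, each with num_local_devices == 4 and
// num_global_devices == 12, and all three share the same communicator_key.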

NcclManager::NcclManager() {}
NcclManager::~NcclManager() {}
NcclManager* NcclManager::instance() {
  static NcclManager* instance = new NcclManager();
  return instance;
}

string NcclManager::GenerateCommunicatorKey() {
  ncclUniqueId nccl_id;
  ncclGetUniqueId(&nccl_id);
  return string(nccl_id.internal, NCCL_UNIQUE_ID_BYTES);
}
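
// For multi-node collectives, a caller is expected to generate this key once
// (it is simply the serialized ncclUniqueId), share the identical bytes with
// every participating worker, and pass it as the communicator_key in each
// node's Context; GetCommunicator() below looks up the matching Communicator
// by this key.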

Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
                                    NcclManager::Communicator** communicator) {
  // Sort by executor to make ordering of executors deterministic.
  std::sort(collective->participants.begin(), collective->participants.end(),
            [](const std::unique_ptr<Participant>& a,
               const std::unique_ptr<Participant>& b) {
              return a->executor < b->executor;
            });

  mutex_lock l(mu_);

  if (collective->single_node) {
    // For single-node collectives, we identify a communicator uniquely by the
    // set of devices participating in the collective.  For example, if a
    // collective is for GPUs 0, 1, and 2 then this will scan to find the
    // communicator for GPUs 0, 1, and 2.
    //
    // Note that each executor identifies a context on one device, so this is
    // the same as getting the communicator connecting the devices in the
    // collective. A device can be in different communicators as well - for
    // example, a communicator for GPUs 0 and 1 is separate from one for GPUs 0,
    // 1, and 2.
    //
    // Since it's expected that a small number of distinct communicators will
    // be needed, communicators_ is not garbage collected currently.
    //
    // Launching of kernels must be serialized so that, given collectives A and
    // B, and an order of them (e.g., A before B), then for each comm_stream
    // involved, the kernel for A is launched before the kernel for B. This is
    // currently guaranteed by a global mutex controlling additions of the
    // kernels to per-stream launch queues.  The launch queues are processed by
    // LoopKernelLaunches.
    for (auto& comm : communicators_) {
      if (comm->num_devices == collective->num_global_devices) {
        int i;
        for (i = 0; i < collective->num_local_devices; ++i) {
          if (comm->members[i].nccl_stream->executor !=
              collective->participants[i]->executor) {
            break;
          }
        }
        if (i == collective->num_local_devices) {
          *communicator = comm.get();
          return Status::OK();
        }
      }
    }
  } else {
#if NCCL_MAJOR < 2
    return errors::Internal(
        "Cannot use multi-node NCCL collectives with NCCL 1.x");
#endif
    if (collective->communicator_key.size() != NCCL_UNIQUE_ID_BYTES) {
      return errors::Internal("Expected communicator_key of size ",
                              NCCL_UNIQUE_ID_BYTES, " but found size ",
                              collective->communicator_key.size());
    }
    // This is an instance of a multi-node collective.  We have previously
    // created a NCCL unique id and shared it with all workers.  Now we find
    // the `Communicator` corresponding to this id.
    for (auto& comm : communicators_) {
      if (comm->key == collective->communicator_key) {
        *communicator = comm.get();
        return Status::OK();
      }
    }
  }

  auto* env = Env::Default();
  std::set<NcclStream*> used_streams;

  // Create and initialize a new communicator.
  // Note that this is done under the lock; performance is not expected to
  // matter as this happens a very small number of times.
  std::vector<CommunicatorMember> members(collective->num_local_devices);
  std::vector<int> devices(collective->num_local_devices);
  for (int i = 0; i < collective->num_local_devices; ++i) {
    auto* executor = collective->participants[i]->executor;

    // Find a communication stream to use for the device.
    auto& streams = device_to_comm_streams_[executor];
    NcclStream* nccl_stream = nullptr;
    for (const auto& s : streams) {
      if (used_streams.insert(s.get()).second) {
        nccl_stream = s.get();
        break;
      }
    }
    if (nccl_stream == nullptr) {
      nccl_stream = new NcclStream();
      nccl_stream->executor = executor;
      nccl_stream->stream.reset(new se::Stream(executor));
      nccl_stream->stream->Init();

      streams.emplace_back(nccl_stream);
      used_streams.insert(nccl_stream);

      nccl_stream->thread.reset(env->StartThread(
          ThreadOptions(), "nccl_kernel_launch",
          [this, nccl_stream] { LoopKernelLaunches(nccl_stream); }));
    }

    members[i].nccl_stream = nccl_stream;
    devices[i] = collective->participants[i]->gpu_device_id;
  }

  std::vector<ncclComm_t> nccl_comms(collective->num_local_devices);
#if NCCL_MAJOR >= 2
  // For NCCL 2, we always initialize using ncclCommInitRank guarded by NCCL
  // group primitives.
  ncclUniqueId nccl_id;
  if (collective->single_node) {
    NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
  } else {
    StringToNcclUniqueId(collective->communicator_key, &nccl_id);
  }
  int saved_device = 0;
  CUDA_RETURN_IF_ERROR(cudaGetDevice(&saved_device));
  NCCL_RETURN_IF_ERROR(ncclGroupStart());
  for (int i = 0; i < collective->num_local_devices; ++i) {
    // Set rank to `participant->global_rank` if provided, else `i`.
    const int rank = collective->participants[i]->global_rank >= 0
                         ? collective->participants[i]->global_rank
                         : i;
    CUDA_RETURN_IF_ERROR(cudaSetDevice(devices[i]));
    NCCL_RETURN_IF_ERROR(ncclCommInitRank(
        nccl_comms.data() + i, collective->num_global_devices, nccl_id, rank));
  }
  NCCL_RETURN_IF_ERROR(ncclGroupEnd());
  CUDA_RETURN_IF_ERROR(cudaSetDevice(saved_device));
#else
  // Since NCCL 1 is single node only, we use ncclCommInitAll.  We could have
  // used ncclCommInitRank with NCCL 1 as well, but then we would have to
  // issue each init call from a different thread
  // (https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/nccl1.html).
  NCCL_RETURN_IF_ERROR(ncclCommInitAll(
      nccl_comms.data(), collective->num_local_devices, devices.data()));
#endif

  for (int i = 0; i < collective->num_local_devices; ++i) {
    members[i].nccl_comm = nccl_comms[i];
  }
  communicators_.emplace_back(
      new Communicator(std::move(members), collective->communicator_key));
  *communicator = communicators_.back().get();
  return Status::OK();
}

void NcclManager::AddToAllReduce(std::unique_ptr<Participant> participant,
                                 const Context& context,
                                 ncclRedOp_t reduction_op) {
  AddParticipant(std::move(participant), context, kAllReduce, reduction_op);
}

void NcclManager::AddToAllGather(std::unique_ptr<Participant> participant,
                                 const Context& context) {
  AddParticipant(std::move(participant), context, kAllGather,
                 ncclSum /* unused */);
}

void NcclManager::AddBroadcastSend(std::unique_ptr<Participant> participant,
                                   const Context& context) {
  participant->root = true;
  AddParticipant(std::move(participant), context, kBroadcast,
                 ncclSum /* unused */);
}

void NcclManager::AddBroadcastRecv(std::unique_ptr<Participant> participant,
                                   const Context& context) {
  AddParticipant(std::move(participant), context, kBroadcast,
                 ncclSum /* unused */);
}

void NcclManager::AddReduceSend(std::unique_ptr<Participant> participant,
                                const Context& context,
                                ncclRedOp_t reduction_op) {
  AddParticipant(std::move(participant), context, kReduce, reduction_op);
}

void NcclManager::AddReduceRecv(std::unique_ptr<Participant> participant,
                                const Context& context,
                                ncclRedOp_t reduction_op) {
  AddParticipant(std::move(participant), context, kReduce, reduction_op);
}
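
// Illustrative caller sketch (hypothetical; the exact Participant and Context
// constructors are declared in nccl_manager.h):
//
//   auto participant = absl::make_unique<NcclManager::Participant>(
//       /* executor, tensor_stream, event_mgr, gpu_device_id, input, output,
//          global_rank, done_callback */ ...);
//   NcclManager::instance()->AddToAllReduce(std::move(participant), context,
//                                           ncclSum);
//
// Each local GPU registers one Participant under the same collective_key; the
// collective runs once num_local_devices participants have been added (and,
// for multi-node collectives, SignalMultiNodeReady has been called).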

void NcclManager::SignalMultiNodeReady(const string& collective_key) {
  Collective* to_run = nullptr;
  {
    mutex_lock l(mu_);
    auto collective_it = collectives_.find(collective_key);
    if (collective_it != collectives_.end()) {
      Collective* collective = collective_it->second.get();
      collective->multi_node_ready = true;
      to_run = CheckReady(collective_key, collective);
    }
  }

  if (to_run != nullptr) RunCollective(to_run);
}

void NcclManager::AddParticipant(std::unique_ptr<Participant> participant,
                                 const Context& context,
                                 CollectiveType collective_type,
                                 ncclRedOp_t reduction_op) {
  Collective* to_run = nullptr;
  const DataType data_type = participant->input->dtype();
  {
    mutex_lock l(mu_);
    auto collective_it = collectives_.find(context.collective_key);
    Collective* collective = nullptr;
    if (collective_it == collectives_.end()) {
      auto collective_unique_ptr = absl::make_unique<Collective>(
          data_type, collective_type, reduction_op, context.num_local_devices,
          context.num_global_devices, context.communicator_key);
      collective = collective_unique_ptr.get();
      collectives_.emplace(context.collective_key,
                           std::move(collective_unique_ptr));
    } else {
      collective = collective_it->second.get();
    }

    // Check `collective` is correct and consistent.
    if (collective->status.ok() && collective->single_node &&
        !collective->communicator_key.empty()) {
      collective->status =
          errors::Internal("Collective ", reduction_op,
                           " is single node but has communicator_key of size ",
                           collective->communicator_key.size());
    }
    if (collective->status.ok() && collective->communicator_key.size() !=
                                       context.communicator_key.size()) {
      collective->status =
          errors::Internal("Collective ", reduction_op,
                           " mismatch in member communicator_key with size ",
                           collective->communicator_key.size(),
                           " and arg communicator_key with size ",
                           context.communicator_key.size());
    }
    if (collective->status.ok() && collective->type != collective_type) {
      collective->status = errors::Internal(
          "Collective ", reduction_op, " previously initialized with type ",
          collective->type, " but now got type ", collective_type);
    }
    if (collective->status.ok() &&
        collective->num_global_devices != context.num_global_devices) {
      collective->status =
          errors::Internal("Collective ", reduction_op,
                           " previously initialized with num_global_devices ",
                           collective->num_global_devices, " but now got ",
                           context.num_global_devices);
    }
    if (collective->status.ok() &&
        collective->num_local_devices != context.num_local_devices) {
      collective->status =
          errors::Internal("Collective ", reduction_op,
                           " previously initialized with num_local_devices ",
                           collective->num_local_devices, " but now got ",
                           context.num_local_devices);
    }
    if (collective->status.ok() &&
        collective->participants.size() >= collective->num_local_devices) {
      collective->status = errors::Internal(
          "Collective ", reduction_op, " expected ",
          collective->num_local_devices, " participants but now has ",
          collective->participants.size(),
          " with one more participant being added");
    }

    collective->participants.emplace_back(std::move(participant));
    ++collective->available_participants;

    to_run = CheckReady(context.collective_key, collective);
  }

  if (to_run != nullptr) RunCollective(to_run);
}

NcclManager::Collective* NcclManager::CheckReady(const string& collective_key,
                                                 Collective* collective) {
  Collective* to_run = nullptr;
  if (collective->available_participants == collective->num_local_devices) {
    if (collective->num_global_devices == collective->num_local_devices ||
        collective->multi_node_ready) {
      // Ownership transferred to callee.
      to_run = collective;
      auto collectives_it = collectives_.find(collective_key);
      collectives_it->second.release();
      collectives_.erase(collectives_it);
    }
  }
  return to_run;
}
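
// Ownership note: once CheckReady returns a non-null Collective, it has been
// released from collectives_ and ownership passes to the caller chain; the
// Collective is deleted either by RunCollective on error or by the last
// participant's done callback in LoopKernelLaunches.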

void NcclManager::RunCollective(Collective* collective) {
  static mutex collective_mu(LINKER_INITIALIZED);

  Status s = collective->status;
  if (s.ok()) {
    s = GetCommunicator(collective, &collective->communicator);
  }
  if (!s.ok()) {
    for (int i = 0; i < collective->num_local_devices; ++i) {
      collective->participants[i]->done_callback(s);
    }
    delete collective;
    return;
  }

  for (int i = 0; i < collective->num_local_devices; ++i) {
    Participant* p = collective->participants[i].get();
    NcclStream* nccl_stream = collective->communicator->members[i].nccl_stream;
    CHECK(nccl_stream != nullptr);
    const int rank = p->global_rank >= 0 ? p->global_rank : i;

    if (p->input != nullptr) {
      // Wait to ensure that the kernel that produces the data in the input
      // tensor has finished running before the nccl kernel runs on the
      // communication stream.
      nccl_stream->stream->ThenWaitFor(p->tensor_stream);
    }
    if (p->root) {
      CHECK_EQ(collective->root_rank, -1);
      collective->root_rank = rank;
    }
  }

  if (collective->type == kBroadcast) {
    CHECK_NE(collective->root_rank, -1);
  }

  {
    // Allow only one collective at a time to queue kernels for launching. This
    // is to prevent collectives from deadlocking each other.
    // Note that it would be possible to run multiple collectives at once, if
    // they have non-intersecting sets of devices.
    mutex_lock l(collective_mu);
    for (int i = 0; i < collective->num_local_devices; ++i) {
      NcclStream* nccl_stream =
          collective->communicator->members[i].nccl_stream;
      mutex_lock l(nccl_stream->mu);
      nccl_stream->pending_launches_.push_front(std::make_pair(collective, i));
      nccl_stream->cv.notify_all();
    }
  }
}
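
// Queue-ordering note: RunCollective push_front()s each
// (collective, participant_idx) pair and LoopKernelLaunches pops from the
// back, so every comm stream processes collectives in the order in which they
// were queued under collective_mu.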

void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
  se::Stream* comm_stream = nccl_stream->stream.get();
  ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
  const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
      comm_stream->implementation()->GpuStreamMemberHack());

  while (true) {
    // Find collective to run.
    std::pair<Collective*, int> next_launch;
    {
      mutex_lock l(nccl_stream->mu);
      while (nccl_stream->pending_launches_.empty()) {
        if (nccl_stream->shutdown_requested) {
          // No work and shutdown requested, exit.
          return;
        }
        nccl_stream->cv.wait(l);
      }
      next_launch = nccl_stream->pending_launches_.back();
      nccl_stream->pending_launches_.pop_back();
    }

    // Launch the nccl kernel.
    Collective* collective = next_launch.first;
    ncclDataType_t data_type = ToNcclType(collective->data_type);
    int p_idx = next_launch.second;
    Participant* p = collective->participants[p_idx].get();
    auto nccl_comm = collective->communicator->members[p_idx].nccl_comm;
    ncclResult_t nccl_result = ncclSuccess;
    switch (collective->type) {
      case kAllReduce: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff = const_cast<char*>(p->output->tensor_data().data());

        VLOG(2) << "call NcclAllReduce participant " << p_idx << " sendbuff "
                << sendbuff << " recvbuff " << recvbuff << " nccl_comm "
                << nccl_comm << " comm_stream " << comm_stream
                << " cuda_stream " << cu_stream;
        nccl_result = ncclAllReduce(sendbuff, recvbuff, p->input->NumElements(),
                                    data_type, collective->reduction_op,
                                    nccl_comm, *cu_stream);
        break;
      }
      case kBroadcast: {
        const Tensor* buf_t = p->input ? p->input : p->output;
        void* buf = const_cast<char*>(buf_t->tensor_data().data());
        nccl_result = ncclBcast(buf, buf_t->NumElements(), data_type,
                                collective->root_rank, nccl_comm, *cu_stream);
        break;
      }
      case kReduce: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff =
            p->output ? const_cast<char*>(p->output->tensor_data().data())
                      : nullptr;
        nccl_result = ncclReduce(sendbuff, recvbuff, p->input->NumElements(),
                                 data_type, collective->reduction_op,
                                 collective->root_rank, nccl_comm, *cu_stream);
        break;
      }
      case kAllGather: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff = const_cast<char*>(p->output->tensor_data().data());

        VLOG(2) << "call NcclAllGather participant " << p_idx << " sendbuff "
                << sendbuff << " sendcount " << p->input->NumElements()
                << " recvbuff " << recvbuff << " recvcount "
                << p->output->NumElements() << " nccl_comm " << nccl_comm
                << " comm_stream " << comm_stream << " cuda_stream "
                << cu_stream;
        nccl_result = ncclAllGather(sendbuff, recvbuff, p->input->NumElements(),
                                    data_type, nccl_comm, *cu_stream);
        break;
      }
    }

    // Run the done_callback when the nccl kernel finishes running.
    auto done_callback = [collective, p_idx, nccl_result]() {
      if (nccl_result == ncclSuccess) {
        collective->participants[p_idx]->done_callback(Status::OK());
      } else {
        // Propagate the error, but note that if other members of the
        // collective did launch their kernels, then they are hanging.
        collective->participants[p_idx]->done_callback(errors::Unknown(
            "Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
      }

      // TODO(cwhipkey): use RefCounted after figuring out how to use in a
      // custom op library.
      // See tensorflow/core/lib/core/refcount.h for details on this locking.
      if (collective->remaining_participants.load(std::memory_order_acquire) ==
              1 ||
          collective->remaining_participants.fetch_sub(1) == 1) {
        delete collective;
      }
    };
    p->event_mgr->ThenExecute(comm_stream, done_callback);
  }
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA