      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifdef TENSORFLOW_USE_MPI
     17 
     18 #include <queue>
     19 #include <thread>
     20 #include <unordered_map>
     21 
     22 #include "tensorflow/core/framework/op.h"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/types.pb.h"
     25 #include "tensorflow/core/platform/mutex.h"
     26 
     27 #define EIGEN_USE_THREADS
     28 
     29 #if GOOGLE_CUDA
     30 #include <cuda_runtime.h>
     31 #include "tensorflow/stream_executor/stream.h"
     32 #endif
     33 
     34 #include "tensorflow/stream_executor/lib/statusor.h"
     35 
     36 #define OMPI_SKIP_MPICXX
     37 #include "third_party/mpi/mpi.h"
     38 #include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
     39 #include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h"
     40 
     41 /*
     42  * MPI Allreduce and Allgather Ops for TensorFlow.
     43  *
     44  * TensorFlow natively provides inter-device communication through send and
     45  * receive ops and inter-node communication through Distributed TensorFlow,
     46  * based on the same send and receive abstractions. These end up being
     47  * insufficient for synchronous data-parallel training on HPC clusters where
     48  * Infiniband or other high-speed interconnects are available.  This module
     49  * implements MPI ops for allgather and allreduce, which do bandwidth-optimal
     50  * gathers and reductions and can take advantage of hardware-optimized
     51  * communication libraries through the MPI implementation.
     52  *
     53  * The primary logic of the allreduce and allgather is in RingAllreduce() and
     54  * RingAllgather(). The background thread which facilitates MPI operations is
     55  * run in BackgroundThreadLoop(). The provided MPI ops are:
     56  *       MPIInit:
     57  *          Initialize MPI on a given device (CPU or GPU).
     58  *          Should only be run on a single device in every process.
     59  *       MPISize:
     60  *          Get the number of MPI processes in the global communicator.
     61  *       MPIRank:
     62  *          Get the rank of the current MPI process in the global communicator.
     63  *       MPILocalRank:
     64  *          Get the local rank of the current MPI process within its node.
     65  *       MPIAllreduce:
     66  *          Perform an allreduce on a Tensor, returning the sum
     67  *          across all MPI processes in the global communicator.
     68  *       MPIAllgather:
     69  *          Perform an allgather on a Tensor, returning the concatenation of
     70  *          the tensor on the first dimension across all MPI processes in the
     71  *          global communicator.
     72  *
     73  */
     74 
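        /*
         * Informal overview of the control flow implemented below (all names
         * refer to functions and classes defined in this file):
         *
         *   MPIInitOp::Compute
         *     -> InitializeMPIOnce()         starts BackgroundThreadLoop()
         *   MPIAllreduceOp::ComputeAsync / MPIAllgatherOp::ComputeAsync
         *     -> EnqueueTensorCollective()   queues an MPIRequest and a
         *                                    CollectiveOpRecord
         *   BackgroundThreadLoop()           coordinator decides what is ready
         *     -> PerformCollectiveOp()       runs RingAllreduce()/RingAllgather()
         *        -> CommunicationDoneCallback  resumes the waiting op
         */
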
     75 template <class T>
     76 using StatusOr = perftools::gputools::port::StatusOr<T>;
     77 
     78 using CPUDevice = Eigen::ThreadPoolDevice;
     79 using GPUDevice = Eigen::GpuDevice;
     80 
     81 namespace tensorflow {
     82 namespace contrib {
     83 namespace mpi_collectives {
     84 
     85 // Make sure template specializations are generated in the ring.cu.cc and
     86 // ring.cc files, not in this file.
     87 extern template Status RingAllreduce<GPUDevice, int>(OpKernelContext*,
     88                                                      const Tensor*, Tensor*,
     89                                                      Tensor*);
     90 extern template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
     91                                                            const Tensor*,
     92                                                            Tensor*, Tensor*);
     93 extern template Status RingAllreduce<GPUDevice, float>(OpKernelContext*,
     94                                                        const Tensor*, Tensor*,
     95                                                        Tensor*);
     96 extern template Status RingAllgather<GPUDevice, int>(OpKernelContext*,
     97                                                      const Tensor*,
     98                                                      const std::vector<size_t>&,
     99                                                      Tensor*);
    100 extern template Status RingAllgather<GPUDevice, long long>(
    101     OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
    102 extern template Status RingAllgather<GPUDevice, float>(
    103     OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
    104 extern template Status RingAllreduce<CPUDevice, int>(OpKernelContext*,
    105                                                      const Tensor*, Tensor*,
    106                                                      Tensor*);
    107 extern template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
    108                                                            const Tensor*,
    109                                                            Tensor*, Tensor*);
    110 extern template Status RingAllreduce<CPUDevice, float>(OpKernelContext*,
    111                                                        const Tensor*, Tensor*,
    112                                                        Tensor*);
    113 extern template Status RingAllgather<CPUDevice, int>(OpKernelContext*,
    114                                                      const Tensor*,
    115                                                      const std::vector<size_t>&,
    116                                                      Tensor*);
    117 extern template Status RingAllgather<CPUDevice, long long>(
    118     OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
    119 extern template Status RingAllgather<CPUDevice, float>(
    120     OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
    121 
    122 namespace {
    123 
    124 // Return true if the templated type is GPUDevice, otherwise false.
    125 template <typename T>
    126 bool IsGPUDevice();
    127 template <>
    128 bool IsGPUDevice<GPUDevice>() {
    129   return true;
    130 }
    131 template <>
    132 bool IsGPUDevice<CPUDevice>() {
    133   return false;
    134 }
    135 
    136 // A callback to call after the MPI communication completes. Since the
    137 // allreduce and allgather ops are asynchronous, this callback is what resumes
    138 // computation after the reduction is completed.
    139 typedef std::function<void(StatusOr<Tensor>)> CommunicationDoneCallback;
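        // For example, MPIAllreduceOp::ComputeAsync (defined below) installs a
        // callback of roughly this form to propagate the result status and
        // resume the waiting op:
        //
        //   auto done_callback = [done, context](StatusOr<Tensor> status) {
        //     context->SetStatus(status.status());
        //     done();
        //   };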
    140 
    141 struct CollectiveOpRecord {
    142   // The rank performing this piece of the op
    143   int rank;
    144 
    145   // The name of the op/tensor to be reduced
    146   std::string name;
    147 
    148   // The op's kernel context
    149   OpKernelContext* context;
    150 
    151   // Data type of the op
    152   DataType dtype;
    153 
    154   // The input tensor
    155   const Tensor* in_t;
    156 
    157   // Allgather: Vector of per-rank first-dimension sizes
    158   std::vector<size_t> sizes_vec;
    159 
    160   // The temp tensor for intermediate results
    161   Tensor temp_t;
    162 
    163   // The output tensor
    164   Tensor* out_t;
    165 
    166   // Whether to run this op on the GPU
    167   bool on_gpu;
    168 
    169   // The callback to call after the op has completed
    170   CommunicationDoneCallback callback;
    171 };
    172 
    173 // Table storing Tensors to be reduced, keyed by unique name.
    174 // This table contains everything necessary to do the reduction
    175 typedef std::unordered_map<std::string, CollectiveOpRecord> TensorTable;
    176 
    177 // Table for storing Tensor metadata on rank zero. This is used for error
    178 // checking and size calculations, as well as determining when a reduction is
    179 // ready to be done (when all nodes are ready to do it).
    180 typedef std::unordered_map<std::string, std::vector<MPIRequest> > MessageTable;
    181 
    182 // The global state required for the MPI ops.
    183 //
    184 // MPI is a library that stores a lot of global per-program state and often
    185 // requires that all calls be made from a single thread. As a result, we run a
    186 // single background thread responsible for all MPI operations, and communicate
    187 // with that background thread through this global state.
    188 struct MPIGlobalState {
    189   // An atomic flag which is set once MPI initialization has begun.
    190   // This ensures that MPI_Init is never called twice.
    191   std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;
    192 
    193   // Condition variable to wait for initialization
    194   condition_variable cv;
    195 
    196   // Whether MPI_Init has been completed on the background thread.
    197   bool initialization_done = false;
    198 
    199   // Whether MPI_Init succeeded on the background thread.
    200   Status init_status;
    201 
    202   // A mutex that needs to be used whenever MPI operations touch
    203   // shared structures.
    204   mutex mu;
    205 
    206   // Tensors waiting to be allreduced or allgathered.
    207   TensorTable tensor_table;
    208 
    209   // Queue of MPI requests waiting to be sent to the coordinator node.
    210   std::queue<MPIRequest> message_queue;
    211 
    212   // Background thread running MPI communication.
    213   std::thread background_thread;
    214 
    215   // Whether the background thread should shut down.
    216   bool shut_down = false;
    217 
    218   // Only exists on the coordinator node (rank zero). Maintains a count of
    219   // how many nodes are ready to allreduce every tensor (keyed by tensor
    220   // name).
    221   std::unique_ptr<MessageTable> message_table;
    222 
    223   // The MPI rank, local rank, and size.
    224   int rank = 0;
    225   int local_rank = 0;
    226   int size = 1;
    227 
    228   // The device that MPI was initialized on. (-1 for no GPU)
    229   int device = -1;
    230 
    231   // The CUDA stream used for data transfers and within-allreduce operations.
    232   // A naive implementation would use the TensorFlow StreamExecutor CUDA
    233   // stream. However, the allreduce and allgather require doing memory copies
    234   // and kernel executions (for accumulation of values on the GPU), and the
    235   // subsequent MPI operations must wait for those operations to complete;
    236   // otherwise MPI (which uses its own stream internally) will begin the data
    237   // transfers before the CUDA calls are complete. In order to wait for those
    238   // CUDA operations, if we were using the TensorFlow stream, we would have
    239   // to synchronize that stream; however, other TensorFlow threads may be
    240   // submitting more work to that stream, so synchronizing on it can cause
    241   // the allreduce to be delayed, waiting for compute totally unrelated to it
    242   // in other parts of the graph. Overlaying memory transfers and compute
    243   // during backpropagation is crucial for good performance, so we cannot use
    244   // the TensorFlow stream, and must use our own stream.
    245 #if GOOGLE_CUDA
    246   cudaStream_t stream;
    247   std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT;
    248 #endif
    249 
    250   ~MPIGlobalState() {
    251     // Make sure that the destructor of the background thread is safe to
    252     // call. Destroying a std::thread that is still joinable (neither
    253     // detached nor joined) terminates the program.
    254     if (background_thread.joinable()) {
    255       shut_down = true;
    256       background_thread.join();
    257     }
    258   }
    259 };
    260 
    261 // All the MPI state that must be stored globally per-process.
    262 static MPIGlobalState mpi_global;
    263 
    264 // For clarity in argument lists.
    265 #define RANK_ZERO 0
    266 
    267 // A tag used for all coordinator messaging.
    268 #define TAG_NOTIFY 1
    269 
    270 // Store the MPIRequest for a name, and return whether the total count of
    271 // MPIRequests for that tensor is now equal to the MPI size (and thus we are
    272 // ready to reduce the tensor).
    273 bool IncrementTensorCount(std::unique_ptr<MessageTable>& message_table,
    274                           MPIRequest msg, int mpi_size) {
    275   auto name = msg.tensor_name();
    276   auto table_iter = message_table->find(name);
    277   if (table_iter == message_table->end()) {
    278     message_table->emplace(name, std::vector<MPIRequest>({msg}));
    279     table_iter = message_table->find(name);
    280   } else {
    281     table_iter->second.push_back(msg);
    282   }
    283 
    284   int count = table_iter->second.size();
    285   return count == mpi_size;
    286 }
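
        // For example, with mpi_size == 2, the first MPIRequest recorded for a
        // given tensor name returns false (only one rank is ready); the second
        // request for that name returns true and the tensor can be scheduled.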
    287 
    288 // Once a tensor is ready to be reduced, the coordinator sends an MPIResponse
    289 // to all ranks instructing them to start the reduction. The MPIResponse
    290 // also contains error messages in case the submitted MPIRequests were not
    291 // valid (for example, contained mismatched shapes or types).
    292 //
    293 // Constructing the MPIResponse thus requires a fair amount of error checking.
    294 MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
    295                                  std::string name) {
    296   bool error = false;
    297   auto it = message_table->find(name);
    298   assert(it != message_table->end());
    299 
    300   std::vector<MPIRequest> requests = it->second;
    301   assert(requests.size() > 0);
    302 
    303   std::ostringstream error_message_stream;
    304 
    305   // Check that all data types being reduced or gathered are identical
    306   auto data_type = requests[0].tensor_type();
    307   for (unsigned int i = 1; i < requests.size(); i++) {
    308     auto request_type = requests[i].tensor_type();
    309     if (data_type != request_type) {
    310       error = true;
    311       error_message_stream << "Mismatched data types: One rank had type "
    312                            << DataType_Name(data_type)
    313                            << ", but another rank had type "
    314                            << DataType_Name(request_type) << ".";
    315       break;
    316     }
    317   }
    318 
    319   // Check that all requested operations are the same
    320   auto message_type = requests[0].request_type();
    321   for (unsigned int i = 1; i < requests.size(); i++) {
    322     if (error) {
    323       break;
    324     }
    325 
    326     auto request_type = requests[i].request_type();
    327     if (message_type != request_type) {
    328       error = true;
    329       error_message_stream << "Mismatched MPI operations: One rank did an "
    330                            << message_type << ", but another rank did an "
    331                            << request_type << ".";
    332       break;
    333     }
    334   }
    335 
    336   // If we are doing an allreduce, check that all tensor shapes
    337   // are identical
    338   if (message_type == MPIRequest::ALLREDUCE) {
    339     TensorShape tensor_shape = requests[0].tensor_shape();
    340     for (unsigned int i = 1; i < requests.size(); i++) {
    341       if (error) {
    342         break;
    343       }
    344 
    345       TensorShape request_shape = requests[i].tensor_shape();
    346       if (tensor_shape != request_shape) {
    347         error = true;
    348         error_message_stream << "Mismatched allreduce tensor shapes: "
    349                              << "One rank reduced a tensor of shape "
    350                              << tensor_shape.DebugString()
    351                              << ", but another rank sent a tensor of shape "
    352                              << request_shape.DebugString() << ".";
    353         break;
    354       }
    355     }
    356   }
    357 
    358   // If we are doing an allgather, make sure all but the first dimension are
    359   // the same. The first dimension may differ; the output tensor's first
    360   // dimension is the sum of those sizes. Collect the sizes by rank.
    361   if (message_type == MPIRequest::ALLGATHER) {
    362     TensorShape tensor_shape = requests[0].tensor_shape();
    363 
    364     if (tensor_shape.dims() == 0) {
    365       error = true;
    366       error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
    367     }
    368 
    369     for (unsigned int i = 1; i < requests.size(); i++) {
    370       if (error) {
    371         break;
    372       }
    373 
    374       TensorShape request_shape = requests[i].tensor_shape();
    375       if (tensor_shape.dims() != request_shape.dims()) {
    376         error = true;
    377         error_message_stream << "Mismatched allgather tensor shapes: "
    378                              << "One rank gathered a tensor of rank "
    379                              << tensor_shape.dims()
    380                              << ", but another rank sent a tensor of rank "
    381                              << request_shape.dims() << ".";
    382         break;
    383       }
    384 
    385       for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) {
    386         if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
    387           error = true;
    388           error_message_stream
    389               << "Mismatched allgather tensor shapes: "
    390               << "One rank gathered a tensor with dimension " << dim
    391               << " equal to " << tensor_shape.dim_size(dim)
    392               << ", but another rank sent a tensor with dimension " << dim
    393               << " equal to " << request_shape.dim_size(dim) << ".";
    394           break;
    395         }
    396       }
    397     }
    398   }
    399 
    400   MPIResponse response;
    401   response.set_tensor_name(name);
    402   if (error) {
    403     std::string error_message = error_message_stream.str();
    404     response.set_response_type(MPIResponse::ERROR);
    405     response.set_error_message(error_message);
    406   } else {
    407     auto response_type = MPIResponse::ERROR;
    408     if (message_type == MPIRequest::ALLREDUCE) {
    409       response_type = MPIResponse::ALLREDUCE;
    410     } else {
    411       response_type = MPIResponse::ALLGATHER;
    412     }
    413     response.set_response_type(response_type);
    414   }
    415 
    416   // Clear all queued up requests for this name. They are now taken care of
    417   // by the constructed MPI response.
    418   message_table->erase(it);
    419 
    420   return response;
    421 }
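
        // For example, if one rank submits a DT_FLOAT tensor under a given name
        // and another rank submits a DT_INT32 tensor under the same name, the
        // resulting MPIResponse has response_type ERROR and an error message
        // describing the dtype mismatch; otherwise the response_type mirrors
        // the request type (ALLREDUCE or ALLGATHER).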
    422 
    423 // Process an MPIResponse by doing a reduction, a gather, or raising an error.
    424 void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
    425   OpKernelContext* context;
    426   const Tensor* input_tensor;
    427   std::vector<size_t> sizes_vec;
    428   Tensor temp_tensor;
    429   Tensor* output_tensor;
    430   CommunicationDoneCallback callback;
    431   bool on_gpu;
    432   {
    433     // Lock on the tensor table.
    434     mutex_lock guard(mpi_global.mu);
    435 
    436     // We should never fail at finding this key in the tensor table.
    437     auto name = response.tensor_name();
    438     auto iter = tensor_table.find(name);
    439     assert(iter != tensor_table.end());
    440 
    441     assert(response.response_type() == MPIResponse::ALLREDUCE ||
    442            response.response_type() == MPIResponse::ALLGATHER ||
    443            response.response_type() == MPIResponse::ERROR);
    444 
    445     CollectiveOpRecord record = iter->second;
    446     context = record.context;
    447     input_tensor = record.in_t;
    448     sizes_vec = record.sizes_vec;
    449     temp_tensor = record.temp_t;
    450     output_tensor = record.out_t;
    451     on_gpu = record.on_gpu;
    452     callback = record.callback;
    453 
    454     // Clear the tensor table of this tensor and its callbacks; the rest of
    455     // this function takes care of it.
    456     tensor_table.erase(iter);
    457   }
    458 
    459   // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
    460   // link to non-existent symbols.
    461 #if GOOGLE_CUDA
    462 #define GPU_DEVICE_IF_CUDA GPUDevice
    463 #else
    464 #define GPU_DEVICE_IF_CUDA CPUDevice
    465 #endif
    466 
    467   Status status;
    468   auto dtype = input_tensor->dtype();
    469   if (response.response_type() == MPIResponse::ALLGATHER) {
    470     if (dtype == DT_FLOAT) {
    471       status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
    472                             context, input_tensor, sizes_vec, output_tensor)
    473                       : RingAllgather<CPUDevice, float>(
    474                             context, input_tensor, sizes_vec, output_tensor);
    475     } else if (dtype == DT_INT32) {
    476       status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
    477                             context, input_tensor, sizes_vec, output_tensor)
    478                       : RingAllgather<CPUDevice, int>(context, input_tensor,
    479                                                       sizes_vec, output_tensor);
    480     } else if (dtype == DT_INT64) {
    481       status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
    482                             context, input_tensor, sizes_vec, output_tensor)
    483                       : RingAllgather<CPUDevice, long long>(
    484                             context, input_tensor, sizes_vec, output_tensor);
    485     } else {
    486       status = errors::Unknown("Invalid tensor type for MPI allgather.");
    487     }
    488   } else if (response.response_type() == MPIResponse::ALLREDUCE) {
    489     if (dtype == DT_FLOAT) {
    490       status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
    491                             context, input_tensor, &temp_tensor, output_tensor)
    492                       : RingAllreduce<CPUDevice, float>(
    493                             context, input_tensor, &temp_tensor, output_tensor);
    494     } else if (dtype == DT_INT32) {
    495       status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
    496                             context, input_tensor, &temp_tensor, output_tensor)
    497                       : RingAllreduce<CPUDevice, int>(
    498                             context, input_tensor, &temp_tensor, output_tensor);
    499     } else if (dtype == DT_INT64) {
    500       status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
    501                             context, input_tensor, &temp_tensor, output_tensor)
    502                       : RingAllreduce<CPUDevice, long long>(
    503                             context, input_tensor, &temp_tensor, output_tensor);
    504     } else {
    505       status = errors::Unknown("Invalid tensor type for MPI allreduce.");
    506     }
    507   } else if (response.response_type() == MPIResponse::ERROR) {
    508     status = errors::FailedPrecondition(response.error_message());
    509   }
    510 
    511   if (status.ok()) {
    512     callback(StatusOr<Tensor>(*output_tensor));
    513   } else {
    514     callback(StatusOr<Tensor>(status));
    515   }
    516 }
    517 
    518 // The MPI background thread loop coordinates all the MPI processes and the
    519 // tensor reductions. The design of the communicator mechanism is limited by a
    520 // few considerations:
    521 //
    522 //      1. Some MPI implementations require all MPI calls to happen from a
    523 //      single thread. Since TensorFlow may use several threads for graph
    524 //      processing, this means we must have our own dedicated thread for
    525 //      dealing with MPI.
    526 //      2. We want to gracefully handle errors, when MPI processes do not
    527 //      properly agree upon what should happen (such as mismatched types or
    528 //      shapes). To do so requires the MPI processes to know about the shapes
    529 //      and types of the relevant tensors on the other processes.
    530 //      3. The MPI reductions and gathers should be able to happen in parallel
    531 //      with other ongoing operations. Since MPI uses an internal
    532 //      (inaccessible) GPU stream separate from the TF GPUDevice streams, we
    533 //      cannot explicitly synchronize memcpys or kernels with it. As a result,
    534 //      MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper
    535 //      ordering of memcpys and kernels with respect to TF streams.
    536 //      4. NOTE: We cannot guarantee that all the MPI processes reduce their
    537 //      tensors in the same order. Thus, there must be a way to ensure the
    538 //      reduction memcpys and kernels occur for correct tensors across all
    539 //      ranks at the same time. We choose to use a coordinator (rank ID 0) to
    540 //      gather and trigger the reduction operations that are ready to execute.
    541 //
    542 // The coordinator currently follows a master-worker paradigm. Rank zero acts
    543 // as the master (the "coordinator"), whereas all other ranks are simply
    544 // workers. Each rank runs its own background thread which progresses in ticks.
    545 // In each tick, the following actions happen:
    546 //
    547 //      a) The workers send any available MPIRequests to the coordinator. These
    548 //      MPIRequests indicate what the worker would like to do (i.e. which
    549 //      tensor they would like to gather or reduce, as well as their shape and
    550 //      type). They repeat this for every tensor that they would like to
    551 //      operate on after that tensor's collective op has executed ComputeAsync.
    552 //
    553 //      b) The workers send an empty "DONE" message to the coordinator to
    554 //      indicate that there are no more tensors they wish to operate on.
    555 //
    556 //      c) The coordinator receives the MPIRequests from the workers, as well
    557 //      as from its own TensorFlow ops, and stores them in a request table. The
    558 //      coordinator continues to receive MPIRequest messages until it has
    559 //      received MPI_SIZE empty "DONE" messages.
    560 //
    561 //      d) The coordinator finds all tensors that are ready to be reduced or
    562 //      gathered, as well as any operations that resulted in an error. For each
    563 //      of those, it sends an MPIResponse to all the workers. When no more
    564 //      MPIResponses are available, it sends a "DONE" response to the workers;
    565 //      if the process is shutting down, it instead sends a "SHUTDOWN" response.
    566 //
    567 //      e) The workers listen for MPIResponse messages, processing each one by
    568 //      doing the required reduce or gather, until they receive a "DONE"
    569 //      response from the coordinator. At that point, the tick ends.
    570 //      If instead of "DONE" they receive "SHUTDOWN", they exit their
    571 //      background loop.
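        //
        // For example, a single tick with three ranks that have each queued an
        // allreduce for tensor "t" proceeds roughly as follows (all messages
        // use TAG_NOTIFY):
        //
        //   ranks 1,2 -> rank 0 : MPIRequest("t")
        //   ranks 1,2 -> rank 0 : DONE (zero-length message)
        //   rank 0 -> ranks 1,2 : MPIResponse("t", ALLREDUCE)
        //                         (every rank, including 0, runs the ring op)
        //   rank 0 -> ranks 1,2 : MPIResponse(DONE)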
    572 // TODO: Use the global mpi_global state variable instead of a local one
    573 void BackgroundThreadLoop() {
    574 #if GOOGLE_CUDA
    575   // Set the device, so that this thread uses the same GPU context as the
    576   // calling thread.
    577   // TODO: Ensure that this is operating correctly. The background thread
    578   // needs to be able to control all GPUs that the rank has access to, and
    579   // might be more than 1 GPU. Tensors could be resident in any of the
    580   // GPUs, so the background thread's accumulate and copy kernels might need
    581   // to correctly set the device and it might be necessary for the background
    582   // thread to manage multiple streams.
    583   cudaSetDevice(mpi_global.device);
    584   cudaStreamCreate(&mpi_global.stream);
    585 #endif
    586 
    587   // Initialize MPI. This must happen on the background thread, since not all
    588   // MPI implementations support being called from multiple threads.
    589   auto init_result = MPI_Init(NULL, NULL);
    590   if (init_result != MPI_SUCCESS) {
    591     mpi_global.init_status =
    592         errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
    593     mpi_global.initialization_done = true;
    594     mpi_global.cv.notify_all();
    595     return;
    596   } else {
    597     mpi_global.init_status = Status::OK();
    598   }
    599 
    600   // Get MPI rank to determine if we are rank zero.
    601   int rank;
    602   MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    603   bool is_coordinator = rank == 0;
    604 
    605   // Get MPI size to determine how many tensors to wait for before reducing.
    606   int size;
    607   MPI_Comm_size(MPI_COMM_WORLD, &size);
    608 
    609   // Determine local rank by querying the local communicator.
    610   MPI_Comm local_comm;
    611   MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
    612                       &local_comm);
    613   int local_rank;
    614   MPI_Comm_rank(local_comm, &local_rank);
    615 
    616   mpi_global.rank = rank;
    617   mpi_global.local_rank = local_rank;
    618   mpi_global.size = size;
    619   mpi_global.initialization_done = true;
    620 
    621   // Notify calling thread that initialization is complete
    622   mpi_global.cv.notify_all();
    623 
    624   // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
    625   // Initialize the tensor count table. No tensors are available yet.
    626   if (is_coordinator) {
    627     mpi_global.message_table =
    628         std::unique_ptr<MessageTable>(new MessageTable());
    629   }
    630 
    631   // The coordinator sends a SHUTDOWN message to trigger shutdown.
    632   bool should_shut_down = false;
    633   do {
    634     // TODO: Eliminate the need for thread sleep by making all activity
    635     // depend on other activity (e.g. condition or MPI waits).
    636     std::this_thread::sleep_for(std::chrono::milliseconds(1));
    637 
    638     // Copy the data structures from global state under this lock.
    639     // However, don't keep the lock for the rest of the loop, so that
    640     // enqueued stream callbacks can continue.
    641     std::queue<MPIRequest> message_queue;
    642     {
    643       mutex_lock guard(mpi_global.mu);
    644       while (!mpi_global.message_queue.empty()) {
    645         MPIRequest message = mpi_global.message_queue.front();
    646         mpi_global.message_queue.pop();
    647         message_queue.push(message);
    648       }
    649     }
    650 
    651     // Collect all tensors that are ready to be reduced. Record them in the
    652     // tensor count table (rank zero) or send them to rank zero to be
    653     // recorded (everyone else).
    654     std::vector<std::string> ready_to_reduce;
    655     while (!message_queue.empty()) {
    656       // Pop the first available message.
    657       MPIRequest message = message_queue.front();
    658       message_queue.pop();
    659 
    660       if (is_coordinator) {
    661         bool reduce =
    662             IncrementTensorCount(mpi_global.message_table, message, size);
    663         if (reduce) {
    664           ready_to_reduce.push_back(message.tensor_name());
    665         }
    666       } else {
    667         std::string encoded_message;
    668         message.SerializeToString(&encoded_message);
    669         MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
    670                  MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
    671       }
    672     }
    673 
    674     // Rank zero has put all its own tensors in the tensor count table.
    675     // Now, it should count all the tensors that are coming from other
    676     // ranks at this tick. It should keep getting tensors until it gets a
    677     // DONE message from all the other ranks.
    678     if (is_coordinator) {
    679       // Count of DONE messages. Keep receiving messages until the number
    680       // of DONE messages equals the number of processes. Initialize to
    681       // one since the coordinator is effectively done.
    682       int completed_ranks = 1;
    683       while (completed_ranks != size) {
    684         MPI_Status status;
    685         MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);
    686 
    687         // Find number of characters in message (including zero byte).
    688         int source_rank = status.MPI_SOURCE;
    689         int msg_length;
    690         MPI_Get_count(&status, MPI_BYTE, &msg_length);
    691 
    692         // If the length is zero, this is a DONE message.
    693         if (msg_length == 0) {
    694           completed_ranks++;
    695           MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
    696                    &status);
    697           continue;
    698         }
    699 
    700         // Receive the serialized MPIRequest into an std::string.
    701         char* buffer = new char[msg_length];
    702         MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
    703                  MPI_COMM_WORLD, &status);
    704         std::string received_data(buffer);
    705         delete[] buffer;
    706 
    707         MPIRequest received_message;
    708         received_message.ParseFromString(received_data);
    709         auto received_name = received_message.tensor_name();
    710 
    711         bool reduce = IncrementTensorCount(mpi_global.message_table,
    712                                            received_message, size);
    713         if (reduce) {
    714           ready_to_reduce.push_back(received_name);
    715         }
    716       }
    717 
    718       // At this point, rank zero should have a fully updated tensor
    719       // count table and should know all the tensors that need to be
    720       // reduced or gathered, and everyone else should have sent all
    721       // their information to rank zero. We can now do reductions and
    722       // gathers; rank zero will choose which ones and in what order,
    723       // and will notify the other ranks before doing each reduction.
    724       for (int i = 0; i < ready_to_reduce.size(); i++) {
    725         // Notify all nodes which tensor we'd like to reduce now
    726         auto name = ready_to_reduce[i];
    727         MPIResponse response =
    728             ConstructMPIResponse(mpi_global.message_table, name);
    729 
    730         std::string encoded_response;
    731         response.SerializeToString(&encoded_response);
    732         for (int r = 1; r < size; r++) {
    733           MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
    734                    MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
    735         }
    736 
    737         // Perform the reduction. All nodes should end up performing
    738         // the same reduction.
    739         PerformCollectiveOp(mpi_global.tensor_table, response);
    740       }
    741 
    742       // Notify all nodes that we are done with the reductions for this
    743       // tick.
    744       MPIResponse done_response;
    745       should_shut_down = mpi_global.shut_down;
    746       done_response.set_response_type(
    747           mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
    748       std::string encoded_response;
    749       done_response.SerializeToString(&encoded_response);
    750       for (int r = 1; r < size; r++) {
    751         MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
    752                  MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
    753       }
    754     } else {
    755       // Notify the coordinator that this node is done sending messages.
    756       // A DONE message is encoded as a zero-length message.
    757       MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
    758 
    759       // Receive responses for tensors to reduce from rank zero. Once we
    760       // receive a DONE response, stop waiting for more messages.
    761       while (true) {
    762         MPI_Status status;
    763         MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
    764 
    765         // Find number of characters in message (including zero byte).
    766         int msg_length;
    767         MPI_Get_count(&status, MPI_BYTE, &msg_length);
    768 
    769         // Receive the serialized MPIResponse into an std::string.
    770         char* buffer = new char[msg_length];
    771         MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
    772                  &status);
    773         std::string received_message(buffer);
    774         delete[] buffer;
    775 
    776         MPIResponse response;
    777         response.ParseFromString(received_message);
    778         if (response.response_type() == MPIResponse::DONE) {
    779           // No more messages this tick
    780           break;
    781         } else if (response.response_type() == MPIResponse::SHUTDOWN) {
    782           // No more messages this tick, and the background thread
    783           // should shut down
    784           should_shut_down = true;
    785           break;
    786         } else {
    787           // Process the current message
    788           PerformCollectiveOp(mpi_global.tensor_table, response);
    789         }
    790       }
    791     }
    792   } while (!should_shut_down);
    793 
    794   MPI_Finalize();
    795 }
    796 
    797 // Initialize MPI and start the MPI background thread. Ensure that this is
    798 // only done once no matter how many times this function is called.
    799 Status InitializeMPIOnce(bool gpu) {
    800   // Ensure MPI is only initialized once.
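          // (std::atomic_flag::test_and_set() returns the previously held
          // value, so only the first caller falls through and starts the
          // background thread; later callers return the stored init_status.)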
    801   if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status;
    802 
    803   mpi_global.device = -1;
    804 #if GOOGLE_CUDA
    805   if (gpu) {
    806     cudaGetDevice(&mpi_global.device);
    807   }
    808 #endif
    809 
    810   // Start the MPI background thread, which initializes MPI and handles
    811   // all MPI communication. TODO: Change this to a TensorFlow thread
    812   mpi_global.background_thread = std::thread(BackgroundThreadLoop);
    813 
    814   // Wait to ensure that the background thread has finished initializing MPI
    815   mutex_lock guard(mpi_global.mu);
    816   mpi_global.cv.wait(guard);
    817   if (!mpi_global.initialization_done) {
    818     mpi_global.init_status =
    819         errors::Unknown("Failed to wait for MPI initialization.");
    820   }
    821 
    822   return mpi_global.init_status;
    823 }
    824 
    825 // Check that MPI is initialized.
    826 Status IsMPIInitialized() {
    827   if (!mpi_global.initialization_done) {
    828     return errors::FailedPrecondition(
    829         "MPI has not been initialized; use tf.contrib.mpi.Session.");
    830   }
    831   return Status::OK();
    832 }
    833 
    834 // This function (called from the callback set up in MPIAll*Op::ComputeAsync)
    835 // only adds the op's record into the local op queue (to track the op's
    836 // progress) and queues a message for the coordinator indicating that this
    837 // rank is ready to begin. The MPI background thread sends the queued message.
    838 void EnqueueTensorCollective(CollectiveOpRecord record,
    839                              MPIRequest::RequestType rtype) {
    840   const Tensor* input_tensor = record.in_t;
    841   MPIRequest message;
    842   message.set_request_rank(record.rank);
    843   message.set_tensor_name(record.name);
    844   message.set_tensor_type(record.dtype);
    845   message.set_request_type(rtype);
    846   input_tensor->shape().AsProto(message.mutable_tensor_shape());
    847 
    848   mutex_lock guard(mpi_global.mu);
    849   mpi_global.tensor_table.emplace(record.name, record);
    850   mpi_global.message_queue.push(message);
    851 }
    852 
    853 }  // namespace
    854 
    855 #if GOOGLE_CUDA
    856 cudaStream_t CudaStreamForMPI() { return mpi_global.stream; }
    857 #endif
    858 
    859 // Op to initialize MPI in the current process. The settings used in this
    860 // configuration must also be used for all future MPI ops.
    861 template <typename Device>
    862 class MPIInitOp : public OpKernel {
    863  public:
    864   explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {}
    865 
    866   void Compute(OpKernelContext* context) override {
    867     bool on_gpu = IsGPUDevice<Device>();
    868     OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu));
    869   }
    870 };
    871 
    872 REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU),
    873                         MPIInitOp<CPUDevice>);
    874 #if GOOGLE_CUDA
    875 REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU),
    876                         MPIInitOp<GPUDevice>);
    877 #endif
    878 
    879 // Op to get the current MPI Size.
    880 template <typename Device>
    881 class MPISizeOp : public OpKernel {
    882  public:
    883   explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {}
    884 
    885   void Compute(OpKernelContext* context) override {
    886     OP_REQUIRES_OK(context, IsMPIInitialized());
    887 
    888     // Write integer to output tensor
    889     Tensor* output;
    890     OP_REQUIRES_OK(context,
    891                    context->allocate_output(0, TensorShape({}), &output));
    892 
    893     auto flat = output->flat<int>();
    894     flat(0) = mpi_global.size;
    895   }
    896 };
    897 
    898 REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU),
    899                         MPISizeOp<CPUDevice>);
    900 #if GOOGLE_CUDA
    901 REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"),
    902                         MPISizeOp<GPUDevice>);
    903 #endif
    904 
    905 // Op to get the current MPI Rank.
    906 template <typename Device>
    907 class MPIRankOp : public OpKernel {
    908  public:
    909   explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {}
    910 
    911   void Compute(OpKernelContext* context) override {
    912     OP_REQUIRES_OK(context, IsMPIInitialized());
    913 
    914     // Write integer to output tensor
    915     Tensor* output;
    916     OP_REQUIRES_OK(context,
    917                    context->allocate_output(0, TensorShape({}), &output));
    918 
    919     auto flat = output->flat<int>();
    920     flat(0) = mpi_global.rank;
    921   }
    922 };
    923 
    924 REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU),
    925                         MPIRankOp<CPUDevice>);
    926 #if GOOGLE_CUDA
    927 REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"),
    928                         MPIRankOp<GPUDevice>);
    929 #endif
    930 
    931 // Op to get the current local MPI Rank.
    932 template <typename Device>
    933 class MPILocalRankOp : public OpKernel {
    934  public:
    935   explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {}
    936 
    937   void Compute(OpKernelContext* context) override {
    938     OP_REQUIRES_OK(context, IsMPIInitialized());
    939 
    940     // Write integer to output tensor
    941     Tensor* output;
    942     OP_REQUIRES_OK(context,
    943                    context->allocate_output(0, TensorShape({}), &output));
    944 
    945     auto flat = output->flat<int>();
    946     flat(0) = mpi_global.local_rank;
    947   }
    948 };
    949 
    950 REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU),
    951                         MPILocalRankOp<CPUDevice>);
    952 #if GOOGLE_CUDA
    953 REGISTER_KERNEL_BUILDER(
    954     Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"),
    955     MPILocalRankOp<GPUDevice>);
    956 #endif
    957 
    958 template <typename Device>
    959 class MPIAllreduceOp : public AsyncOpKernel {
    960  public:
    961   explicit MPIAllreduceOp(OpKernelConstruction* context)
    962       : AsyncOpKernel(context) {}
    963 
    964   // Although this op is handled asynchronously, the ComputeAsync call is
    965   // very inexpensive. It only sets up a CollectiveOpRecord and places it
    966   // in the table for the background thread to handle. Thus, we do not need
    967   // a TF pool thread to perform the op.
    968   bool IsExpensive() override { return false; }
    969 
    970   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    971     OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
    972     const Tensor* input_tensor = &context->input(0);
    973     Tensor* output_tensor;
    974     OP_REQUIRES_OK_ASYNC(
    975         context,
    976         context->allocate_output(0, input_tensor->shape(), &output_tensor),
    977         done);
    978 
    979     // Record allocated on stack so op can fail without memory leak
    980     CollectiveOpRecord record;
    981     record.name = name();
    982     record.context = context;
    983     record.in_t = input_tensor;
    984     record.out_t = output_tensor;
    985     record.on_gpu = IsGPUDevice<Device>();
    986     record.dtype = input_tensor->dtype();
    987 
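            // Scratch space for RingAllreduce's intermediate results: the ring
            // algorithm works on ceil(NumElements / size) elements at a time.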
    988     const size_t temp_size =
    989         (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size;
    990     TensorShape temp_shape;
    991     temp_shape.AddDim(temp_size);
    992     OP_REQUIRES_OK_ASYNC(context,
    993                          context->allocate_temp(input_tensor->dtype(),
    994                                                 temp_shape, &record.temp_t),
    995                          done);
    996 
    997     auto allreduce_done_callback = [done, context](StatusOr<Tensor> status) {
    998       context->SetStatus(status.status());
    999       done();
   1000     };
   1001     record.callback = allreduce_done_callback;
   1002 
   1003     auto allreduce_launch_callback = [record] {
   1004       EnqueueTensorCollective(record, MPIRequest::ALLREDUCE);
   1005     };
   1006 
   1007     // If we are on a CPU, our device context will be null and we can't
   1008     // get a stream to enqueue this on. On a CPU this op is called when the
   1009     // data is already available, so we can just immediately do the
   1010     // allreduce; we don't have to wait for the data to get populated.
   1011 #if GOOGLE_CUDA
   1012     auto device_context = context->op_device_context();
   1013     if (device_context == nullptr) {
   1014       allreduce_launch_callback();
   1015     } else {
   1016       auto stream = device_context->stream();
   1017       stream->ThenDoHostCallback(allreduce_launch_callback);
   1018     }
   1019 #else
   1020     allreduce_launch_callback();
   1021 #endif
   1022   }
   1023 };
   1024 
   1025 REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU),
   1026                         MPIAllreduceOp<CPUDevice>);
   1027 #if GOOGLE_CUDA
   1028 REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU),
   1029                         MPIAllreduceOp<GPUDevice>);
   1030 #endif
   1031 
   1032 template <typename Device>
   1033 class MPIAllgatherOp : public AsyncOpKernel {
   1034  public:
   1035   explicit MPIAllgatherOp(OpKernelConstruction* context)
   1036       : AsyncOpKernel(context) {}
   1037 
   1038   // Although this op is handled asynchronously, the ComputeAsync call is
   1039   // very inexpensive. It only sets up a CollectiveOpRecord and places it
   1040   // in the table for the background thread to handle. Thus, we do not need
   1041   // a TF pool thread to perform the op.
   1042   bool IsExpensive() override { return false; }
   1043 
   1044   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
   1045     OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
   1046     const Tensor* input_tensor = &context->input(0);
   1047     const Tensor* sizing_tensor = &context->input(1);
   1048 
   1049     // Record allocated on stack so op can fail without memory leak
   1050     CollectiveOpRecord record;
   1051     record.name = name();
   1052     record.context = context;
   1053     record.in_t = input_tensor;
   1054     record.on_gpu = IsGPUDevice<Device>();
   1055 
   1056     // Construct the output size from the sizing tensor
   1057     size_t output_first_dim = 0;
   1058     if (sizing_tensor->shape().dims() == 0) {
   1059       // 0-dim sizing_tensor implies that the op is just gathering
   1060       // a single element from each rank
   1061       output_first_dim = mpi_global.size;
   1062       for (int i = 0; i < mpi_global.size; i++) {
   1063         record.sizes_vec.push_back(1);
   1064       }
   1065     } else {
   1066       // Collect the total output tensor sizing from the sizing tensor
   1067       // NOTE: The sizing tensor is forced to be placed on the CPU by
   1068       // declaring the input as HostMemory, so it is valid to read it here.
   1069       const int64* sizing_array =
   1070           reinterpret_cast<const int64*>(sizing_tensor->tensor_data().data());
   1071       for (int i = 0; i < mpi_global.size; i++) {
   1072         record.sizes_vec.push_back(sizing_array[i]);
   1073         output_first_dim += sizing_array[i];
   1074       }
   1075     }
   1076 
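            // For example, with two ranks contributing tensors of shape [5, k]
            // and [3, k], sizes_vec becomes {5, 3} and the gathered output has
            // shape [8, k].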
   1077     TensorShape output_shape;
   1078     output_shape.AddDim(output_first_dim);
   1079     for (int i = 1; i < input_tensor->shape().dims(); i++) {
   1080       output_shape.AddDim(input_tensor->shape().dim_size(i));
   1081     }
   1082 
   1083     Tensor* output_tensor;
   1084     OP_REQUIRES_OK_ASYNC(
   1085         context, context->allocate_output(0, output_shape, &output_tensor),
   1086         done);
   1087 
   1088     record.out_t = output_tensor;
   1089     record.dtype = input_tensor->dtype();
   1090 
   1091     auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
   1092       context->SetStatus(status.status());
   1093       done();
   1094     };
   1095     record.callback = allgather_done_callback;
   1096 
   1097     auto allgather_launch_callback = [record] {
   1098       EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
   1099     };
   1100 
   1101     // If we are on a CPU, our device context will be null and we can't
   1102     // get a stream to enqueue this on. On a CPU this op is called when the
   1103     // data is already available, so we can just immediately do the
   1104     // allgather; we don't have to wait for the data to get populated.
   1105 #if GOOGLE_CUDA
   1106     auto device_context = context->op_device_context();
   1107     if (device_context == nullptr) {
   1108       allgather_launch_callback();
   1109     } else {
   1110       auto stream = device_context->stream();
   1111       stream->ThenDoHostCallback(allgather_launch_callback);
   1112     }
   1113 #else
   1114     allgather_launch_callback();
   1115 #endif
   1116   }
   1117 };
   1118 
   1119 REGISTER_KERNEL_BUILDER(
   1120     Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"),
   1121     MPIAllgatherOp<CPUDevice>);
   1122 #if GOOGLE_CUDA
   1123 REGISTER_KERNEL_BUILDER(
   1124     Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"),
   1125     MPIAllgatherOp<GPUDevice>);
   1126 #endif
   1127 
   1128 }  // namespace mpi_collectives
   1129 }  // namespace contrib
   1130 }  // namespace tensorflow
   1131 
   1132 #endif  // TENSORFLOW_USE_MPI
   1133