/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_MPI_H_
#define TENSORFLOW_CONTRIB_MPI_H_

#ifdef TENSORFLOW_USE_MPI

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

#if GOOGLE_CUDA
#include "cuda_runtime.h"
#endif

// Needed to avoid header issues with C++-supporting MPI implementations
#define OMPI_SKIP_MPICXX
#include "third_party/mpi/mpi.h"

#define TAG_TENSOR 12

namespace tensorflow {
namespace contrib {
namespace mpi_collectives {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

// Convert from templated types to values we can pass to MPI.
template <typename T>
MPI_Datatype MPIType();

// Convert from templated types to TensorFlow data types.
template <typename T>
DataType TensorFlowDataType();
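
// For example, the specializations for float are expected to look roughly
// like the sketch below (illustrative only; the authoritative definitions
// live in the accompanying .cc file):
//
//   template <>
//   MPI_Datatype MPIType<float>() { return MPI_FLOAT; }
//
//   template <>
//   DataType TensorFlowDataType<float>() { return DT_FLOAT; }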

#define MPI_REQUIRES_OK(MPI_STATUS)                               \
  if ((MPI_STATUS) != MPI_SUCCESS) {                              \
    return errors::Unknown("MPI operation failed unexpectedly."); \
  }

// Copy data from one tensor to another tensor.
// This uses a custom CUDA stream on GPU, which is necessary to overlap the
// backpropagation computations with the allreduce.
template <typename Device>
void CopyTensorData(void* destination, void* source, size_t size);

// Add a tensor into another tensor, accumulating in place.
// This uses a custom CUDA stream on GPU, which is necessary to overlap the
// backpropagation computations with the allreduce.
template <typename Device, typename T>
void AccumulateTensorData(T* destination, T* source, size_t size);

// We need to get the right stream for doing CUDA memory transfers and
// operations, which is possibly different from the standard TensorFlow stream.
#if GOOGLE_CUDA
cudaStream_t CudaStreamForMPI();
#endif
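
// To illustrate how these pieces fit together, a GPU specialization of
// CopyTensorData could look roughly like the sketch below (an assumption
// about its shape; the actual definition lives in the corresponding
// .cu.cc file):
//
//   template <>
//   void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) {
//     cudaStream_t stream = CudaStreamForMPI();
//     cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
//     cudaStreamSynchronize(stream);
//   }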

/* Perform a ring allreduce on the data, storing the result in the output
 * tensor.
 *
 * Assumes that all MPI processes are doing an allreduce of the same tensor,
 * with the same dimensions.
 *
 * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the
 * allreduce, the nodes involved are arranged in a ring:
 *
 *                   .--0--.
 *                  /       \
 *                 3         1
 *                  \       /
 *                   *--2--*
 *
 *  Each node always sends to the next clockwise node in the ring, and receives
 *  from the previous one.
 *
 *  The allreduce is done in two parts: a scatter-reduce and an allgather. In
 *  the scatter-reduce, a reduction is done, so that each node ends up with a
 *  chunk of the final output tensor which has contributions from all other
 *  nodes.  In the allgather, those chunks are distributed among all the nodes,
 *  so that all nodes have the entire output tensor.
 *
 *  Both of these operations are done by dividing the input tensor into N
 *  evenly sized chunks (where N is the number of nodes in the ring).
 *
 *  The scatter-reduce is done in N-1 steps. In the ith step, node j will send
 *  the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it to its
 *  existing data for that chunk. For example, in the first iteration with
 *  the ring depicted above, you will have the following transfers:
 *
 *      Segment 0:  Node 0 --> Node 1
 *      Segment 1:  Node 1 --> Node 2
 *      Segment 2:  Node 2 --> Node 3
 *      Segment 3:  Node 3 --> Node 0
 *
 *  In the second iteration, you'll have the following transfers:
 *
 *      Segment 0:  Node 1 --> Node 2
 *      Segment 1:  Node 2 --> Node 3
 *      Segment 2:  Node 3 --> Node 0
 *      Segment 3:  Node 0 --> Node 1
 *
 *  After this iteration, Node 2 has three of the four contributions to
 *  Segment 0. The last iteration has the following transfers:
 *
 *      Segment 0:  Node 2 --> Node 3
 *      Segment 1:  Node 3 --> Node 0
 *      Segment 2:  Node 0 --> Node 1
 *      Segment 3:  Node 1 --> Node 2
 *
 *  After this iteration, Node 3 has the fully accumulated Segment 0; Node 0
 *  has the fully accumulated Segment 1; and so on. The scatter-reduce is
 *  complete.
 *
 * Next, the allgather distributes these fully accumulated chunks across all
 * nodes. Communication proceeds in the same ring, once again in N-1 steps. At
 * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i).
 * For example, at the first iteration, the following transfers will occur:
 *
 *      Segment 0:  Node 3 --> Node 0
 *      Segment 1:  Node 0 --> Node 1
 *      Segment 2:  Node 1 --> Node 2
 *      Segment 3:  Node 2 --> Node 3
 *
 * After the first iteration, Node 0 will have a fully accumulated Segment 0
 * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its
 * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3.
 * After this has continued for N - 1 iterations, all nodes will have the
 * fully accumulated tensor.
 *
 * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the
 * allgather. Each send will contain K / N bytes, if there are K bytes in the
 * original tensor on every node. Thus, each node sends and receives 2K(N - 1)/N
 * bytes of data, and the performance of the allreduce (assuming no latency in
 * connections) is constrained by the slowest interconnect between the nodes.
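 *
 * As a concrete instance of this cost model: with N = 4 nodes and a
 * K = 400 MB tensor on every node, each node sends and receives
 * 2 * 400 * (4 - 1) / 4 = 600 MB over the whole allreduce, as six
 * 100 MB transfers in each direction.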
 *
 */
template <typename Device, typename T>
Status RingAllreduce(OpKernelContext* context, const Tensor* input,
                     Tensor* temp, Tensor* output) {
  // Acquire MPI size and rank
  int n, r;
  MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
  MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));

  T* buffer = (T*)output->tensor_data().data();

  CopyTensorData<Device>((void*)buffer, (void*)input->tensor_data().data(),
                         output->tensor_data().size());

  // Calculate segment sizes and segment starts
  const size_t elements_to_reduce = input->NumElements();
  const size_t segment_size = elements_to_reduce / n;
  std::vector<size_t> segment_sizes(n, segment_size);

  const size_t residual = elements_to_reduce % n;
  for (size_t i = 0; i < residual; ++i) {
    segment_sizes[i]++;
  }

  std::vector<size_t> segment_starts(n);
  segment_starts[0] = 0;
  for (size_t i = 1; i < segment_starts.size(); ++i) {
    segment_starts[i] = segment_starts[i - 1] + segment_sizes[i - 1];
  }

  assert(segment_starts[n - 1] + segment_sizes[n - 1] == elements_to_reduce);

  T* segment_recv = (T*)temp->tensor_data().data();

  // Receive from your left neighbor with wrap-around
  const size_t recv_from = ((r - 1) + n) % n;

  // Send to your right neighbor with wrap-around
  const size_t send_to = (r + 1) % n;

  MPI_Status recv_status;
  MPI_Request recv_req;

  // Now start the ring. At every step, for every rank, we iterate through
  // segments with wraparound and send and recv from our neighbors and reduce
  // locally. At the i'th iteration, rank r sends segment (r-i) and receives
  // segment (r-i-1).
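  // As a concrete example, with n = 4 and r = 2: at i = 0 this rank sends
  // segment 2 and receives segment 1; at i = 1 it sends segment 1 and
  // receives segment 0.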
  for (int i = 0; i < n - 1; i++) {
    const size_t send_seg_id = ((r - i) + n) % n;
    const size_t recv_seg_id = ((r - i - 1) + n) % n;

    T* segment_send = &(buffer[segment_starts[send_seg_id]]);

    MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id],
                              MPIType<T>(), recv_from, TAG_TENSOR,
                              MPI_COMM_WORLD, &recv_req));

    MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id],
                             MPIType<T>(), send_to, TAG_TENSOR,
                             MPI_COMM_WORLD));

    T* segment_update = &(buffer[segment_starts[recv_seg_id]]);

    // Wait for recv to complete before reduction
    MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status));

    const size_t recv_seg_size = segment_sizes[recv_seg_id];
    AccumulateTensorData<Device, T>(segment_update, segment_recv,
                                    recv_seg_size);
  }

  // Now start the pipelined ring allgather. At every step, for every rank, we
  // iterate through segments with wraparound and send and recv from our
  // neighbors. At the i'th iteration, rank r sends segment (r-i+1) and
  // receives segment (r-i).
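  // As a concrete example, with n = 4 and r = 2: at i = 0 this rank sends
  // segment 3 (the segment it finished accumulating in the scatter-reduce)
  // and receives segment 2.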
  for (size_t i = 0; i < n - 1; ++i) {
    const size_t send_seg_id = ((r - i + 1) + n) % n;
    const size_t recv_seg_id = ((r - i) + n) % n;

    // Segment to send - at every iteration we send segment (r-i+1)
    T* segment_send = &(buffer[segment_starts[send_seg_id]]);

    // Segment to recv - at every iteration we receive segment (r-i)
    T* segment_recv = &(buffer[segment_starts[recv_seg_id]]);

    MPI_REQUIRES_OK(MPI_Sendrecv(
        segment_send, segment_sizes[send_seg_id], MPIType<T>(), send_to,
        TAG_TENSOR, segment_recv, segment_sizes[recv_seg_id], MPIType<T>(),
        recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
  }

  return Status::OK();
}
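
// A minimal usage sketch (hypothetical; the real call sites are the op
// kernels in this module, and the variable names below are illustrative):
//
//   const Tensor& input = context->input(0);
//   Tensor temp;
//   OP_REQUIRES_OK(context,
//                  context->allocate_temp(DT_FLOAT, input.shape(), &temp));
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(context,
//                  context->allocate_output(0, input.shape(), &output));
//   OP_REQUIRES_OK(context, RingAllreduce<CPUDevice, float>(
//                               context, &input, &temp, output));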

// Perform a ring allgather on a Tensor. Other ranks may allgather with a
// tensor which differs in the first dimension only; all other dimensions must
// be the same.
//
// For more information on the ring allgather, read the documentation for the
// ring allreduce, which includes a ring allgather.
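//
// For example, rank 0 may contribute a tensor of shape [3, 5] and rank 1 a
// tensor of shape [7, 5]; after the allgather, every rank holds the same
// [10, 5] output.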
template <typename Device, typename T>
Status RingAllgather(OpKernelContext* context, const Tensor* input,
                     const std::vector<size_t>& sizes, Tensor* output) {
  // Acquire MPI size and rank
  int n, r;
  MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
  MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));

  assert(sizes.size() == n);
  assert(input->dim_size(0) == sizes[r]);

  // Compute the number of elements in every "row". We can't compute the
  // number of elements in each chunk, because chunks have variable length.
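  // For example, an input of shape [rows, 5, 6] yields elements_per_row = 30.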
  size_t elements_per_row = 1;
  for (int i = 1; i < input->shape().dims(); i++) {
    elements_per_row *= input->dim_size(i);
  }

  // Copy data from input tensor to correct place in output tensor.
  std::vector<size_t> segment_starts(n);
  segment_starts[0] = 0;
  for (int i = 1; i < n; i++) {
    segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1];
  }
  size_t offset = segment_starts[r];

  // Copy data to the right offset for this rank.
  T* buffer = (T*)output->tensor_data().data();
  CopyTensorData<Device>((void*)(buffer + offset),
                         (void*)input->tensor_data().data(),
                         elements_per_row * sizes[r] * sizeof(T));

  // Receive from your left neighbor with wrap-around
  const size_t recv_from = ((r - 1) + n) % n;

  // Send to your right neighbor with wrap-around
  const size_t send_to = (r + 1) % n;

  // Perform a ring allgather. At every step, for every rank, we iterate
  // through segments with wraparound and send and recv from our neighbors.
  // At the i'th iteration, rank r sends segment (r-i) and receives segment
  // (r-1-i).
  MPI_Status recv_status;
  for (size_t i = 0; i < n - 1; ++i) {
    const size_t send_seg_id = ((r - i) + n) % n;
    const size_t recv_seg_id = ((r - i - 1) + n) % n;

    // Segment to send - at every iteration we send segment (r-i)
    size_t offset_send = segment_starts[send_seg_id];
    size_t rows_send = sizes[send_seg_id];
    T* segment_send = &(buffer[offset_send]);

    // Segment to recv - at every iteration we receive segment (r-1-i)
    size_t offset_recv = segment_starts[recv_seg_id];
    size_t rows_recv = sizes[recv_seg_id];
    T* segment_recv = &(buffer[offset_recv]);

    MPI_REQUIRES_OK(MPI_Sendrecv(
        segment_send, elements_per_row * rows_send, MPIType<T>(), send_to,
        TAG_TENSOR, segment_recv, elements_per_row * rows_recv, MPIType<T>(),
        recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
  }

  return Status::OK();
}

}  // namespace mpi_collectives
}  // namespace contrib
}  // namespace tensorflow

#endif  // TENSORFLOW_USE_MPI

#endif  // TENSORFLOW_CONTRIB_MPI_H_