/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_MPI_H_
#define TENSORFLOW_CONTRIB_MPI_H_

#ifdef TENSORFLOW_USE_MPI

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

#if GOOGLE_CUDA
#include "cuda_runtime.h"
#endif

// Needed to avoid header issues with C++-supporting MPI implementations.
#define OMPI_SKIP_MPICXX
#include "third_party/mpi/mpi.h"

#define TAG_TENSOR 12

namespace tensorflow {
namespace contrib {
namespace mpi_collectives {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

// Convert from templated types to values we can pass to MPI.
template <typename T>
MPI_Datatype MPIType();

// Convert from templated types to TensorFlow data types.
template <typename T>
DataType TensorFlowDataType();

#define MPI_REQUIRES_OK(MPI_STATUS)                               \
  if ((MPI_STATUS) != MPI_SUCCESS) {                              \
    return errors::Unknown("MPI operation failed unexpectedly."); \
  }

// Copy data from one tensor to another tensor.
// This uses a custom CUDA stream on GPU, which is necessary to overlap the
// backpropagation computations with the allreduce.
template <typename Device>
void CopyTensorData(void* destination, void* source, size_t size);

// Add a tensor into another tensor, accumulating in place.
// This uses a custom CUDA stream on GPU, which is necessary to overlap the
// backpropagation computations with the allreduce.
template <typename Device, typename T>
void AccumulateTensorData(T* destination, T* source, size_t size);

// We need to get the right stream for doing CUDA memory transfers and
// operations, which is possibly different from the standard TensorFlow stream.
#if GOOGLE_CUDA
cudaStream_t CudaStreamForMPI();
#endif
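
// What follows is a minimal sketch (illustrative only, not part of this
// header) of what CPU specializations of the copy/accumulate helpers declared
// above could look like; the real definitions live in an accompanying
// implementation file, and the GPU versions would additionally use the custom
// CUDA stream returned by CudaStreamForMPI(). Judging from the call sites
// below, the copy size is in bytes while the accumulate size is in elements.
//
//   #include <cstring>
//
//   template <>
//   void CopyTensorData<CPUDevice>(void* destination, void* source,
//                                  size_t size) {
//     std::memcpy(destination, source, size);  // size in bytes
//   }
//
//   template <>
//   void AccumulateTensorData<CPUDevice, float>(float* destination,
//                                               float* source, size_t size) {
//     for (size_t i = 0; i < size; ++i) {  // size in elements
//       destination[i] += source[i];
//     }
//   }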

/* Perform a ring allreduce on the data. The output tensor is assumed to be
 * preallocated by the caller; the reduced result is stored in the output
 * parameter. The temp tensor provides scratch space for receiving segments
 * during the scatter-reduce phase.
 *
 * Assumes that all MPI processes are doing an allreduce of the same tensor,
 * with the same dimensions.
 *
 * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the
 * allreduce, the nodes involved are arranged in a ring:
 *
 *     .--0--.
 *    /       \
 *   3         1
 *    \       /
 *     *--2--*
 *
 * Each node always sends to the next clockwise node in the ring, and receives
 * from the previous one.
 *
 * The allreduce is done in two parts: a scatter-reduce and an allgather. In
 * the scatter-reduce, a reduction is done, so that each node ends up with a
 * chunk of the final output tensor which has contributions from all other
 * nodes. In the allgather, those chunks are distributed among all the nodes,
 * so that all nodes have the entire output tensor.
 *
 * Both of these operations are done by dividing the input tensor into N
 * evenly sized chunks (where N is the number of nodes in the ring).
 *
 * The scatter-reduce is done in N-1 steps. In the ith step, node j will send
 * the (j - i)th chunk and receive the (j - i - 1)th chunk, accumulating it
 * into its existing data for that chunk. For example, in the first iteration
 * with the ring depicted above, you will have the following transfers:
 *
 *     Segment 0: Node 0 --> Node 1
 *     Segment 1: Node 1 --> Node 2
 *     Segment 2: Node 2 --> Node 3
 *     Segment 3: Node 3 --> Node 0
 *
 * In the second iteration, you'll have the following transfers:
 *
 *     Segment 0: Node 1 --> Node 2
 *     Segment 1: Node 2 --> Node 3
 *     Segment 2: Node 3 --> Node 0
 *     Segment 3: Node 0 --> Node 1
 *
 * After this iteration, Node 2 has three of the four contributions to
 * Segment 0. The last iteration has the following transfers:
 *
 *     Segment 0: Node 2 --> Node 3
 *     Segment 1: Node 3 --> Node 0
 *     Segment 2: Node 0 --> Node 1
 *     Segment 3: Node 1 --> Node 2
 *
 * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0
 * has the fully accumulated Segment 1; and so on. The scatter-reduce is
 * complete.
 *
 * Next, the allgather distributes these fully accumulated chunks across all
 * nodes. Communication proceeds in the same ring, once again in N-1 steps. At
 * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i).
 * For example, at the first iteration, the following transfers will occur:
 *
 *     Segment 0: Node 3 --> Node 0
 *     Segment 1: Node 0 --> Node 1
 *     Segment 2: Node 1 --> Node 2
 *     Segment 3: Node 2 --> Node 3
 *
 * After the first iteration, Node 0 will have a fully accumulated Segment 0
 * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its
 * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3.
 * After this has continued for N - 1 iterations, all nodes will have the
 * fully accumulated tensor.
 *
 * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the
 * allgather. Each send will contain roughly K / N bytes, if there are K bytes
 * in the original tensor on every node. Thus, each node sends and receives
 * about 2K(N-1)/N bytes of data, and the performance of the allreduce
 * (assuming no latency in connections) is constrained by the slowest
 * interconnect between the nodes.
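 *
 * As a concrete example (illustrative numbers, not from any benchmark): with
 * N = 4 nodes and a K = 400 MB tensor, each node sends and receives
 * 2 * 400 MB * (4 - 1) / 4 = 600 MB over the course of the allreduce; for
 * large N this approaches 2K per node, independent of the ring size.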
 *
 */
template <typename Device, typename T>
Status RingAllreduce(OpKernelContext* context, const Tensor* input,
                     Tensor* temp, Tensor* output) {
  // Acquire MPI size and rank.
  int n, r;
  MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
  MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));

  T* buffer = (T*)output->tensor_data().data();

  CopyTensorData<Device>((void*)buffer, (void*)input->tensor_data().data(),
                         output->tensor_data().size());

  // Calculate segment sizes and segment start offsets.
  const size_t elements_to_reduce = input->NumElements();
  const size_t segment_size = elements_to_reduce / n;
  std::vector<size_t> segment_sizes(n, segment_size);

  const size_t residual = elements_to_reduce % n;
  for (size_t i = 0; i < residual; ++i) {
    segment_sizes[i]++;
  }

  std::vector<size_t> segment_starts(n);
  segment_starts[0] = 0;
  for (size_t i = 1; i < segment_starts.size(); ++i) {
    segment_starts[i] = segment_starts[i - 1] + segment_sizes[i - 1];
  }

  assert(segment_starts[n - 1] + segment_sizes[n - 1] == elements_to_reduce);

  T* segment_recv = (T*)temp->tensor_data().data();

  // Receive from your left neighbor with wrap-around.
  const size_t recv_from = ((r - 1) + n) % n;

  // Send to your right neighbor with wrap-around.
  const size_t send_to = (r + 1) % n;

  MPI_Status recv_status;
  MPI_Request recv_req;

  // Now start the ring. At every step, for every rank, we iterate through
  // segments with wraparound and send and recv from our neighbors and reduce
  // locally. At the i'th iteration, rank r sends segment (r-i) and receives
  // segment (r-i-1).
  for (int i = 0; i < n - 1; i++) {
    const size_t send_seg_id = ((r - i) + n) % n;
    const size_t recv_seg_id = ((r - i - 1) + n) % n;

    T* segment_send = &(buffer[segment_starts[send_seg_id]]);

    MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id],
                              MPIType<T>(), recv_from, TAG_TENSOR,
                              MPI_COMM_WORLD, &recv_req));

    MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id],
                             MPIType<T>(), send_to, TAG_TENSOR,
                             MPI_COMM_WORLD));

    T* segment_update = &(buffer[segment_starts[recv_seg_id]]);

    // Wait for the recv to complete before reducing.
    MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status));

    const size_t recv_seg_size = segment_sizes[recv_seg_id];
    AccumulateTensorData<Device, T>(segment_update, segment_recv,
                                    recv_seg_size);
  }

  // Now start the pipelined ring allgather. At every step, for every rank, we
  // iterate through segments with wraparound and send and recv from our
  // neighbors. At the i'th iteration, rank r sends segment (r-i+1) and
  // receives segment (r-i).
  for (int i = 0; i < n - 1; ++i) {
    const size_t send_seg_id = ((r - i + 1) + n) % n;
    const size_t recv_seg_id = ((r - i) + n) % n;

    // Segment to send - at every iteration we send segment (r-i+1).
    T* segment_send = &(buffer[segment_starts[send_seg_id]]);

    // Segment to recv - at every iteration we receive segment (r-i).
    T* segment_recv = &(buffer[segment_starts[recv_seg_id]]);

    MPI_REQUIRES_OK(MPI_Sendrecv(
        segment_send, segment_sizes[send_seg_id], MPIType<T>(), send_to,
        TAG_TENSOR, segment_recv, segment_sizes[recv_seg_id], MPIType<T>(),
        recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
  }

  return Status::OK();
}
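
// Worked example (illustrative numbers) for the segment bookkeeping in
// RingAllreduce() above: with n = 4 ranks and a tensor of 10 elements,
// segment_size = 10 / 4 = 2 and residual = 10 % 4 = 2, so segment_sizes is
// {3, 3, 2, 2} and segment_starts is {0, 3, 6, 8}. In scatter-reduce
// iteration i = 0, rank r = 1 sends segment ((1 - 0) + 4) % 4 = 1 to rank 2
// and receives segment ((1 - 0 - 1) + 4) % 4 = 0 from rank 0, matching the
// transfer tables in the documentation comment above.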

// Perform a ring allgather on a Tensor. Other ranks may allgather with a
// tensor which differs in the first dimension only; all other dimensions must
// be the same. The sizes vector gives the first-dimension size of every
// rank's input tensor, and the output tensor is assumed to be preallocated
// with the combined first dimension.
//
// For more information on the ring allgather, read the documentation for the
// ring allreduce, which includes a ring allgather.
template <typename Device, typename T>
Status RingAllgather(OpKernelContext* context, const Tensor* input,
                     const std::vector<size_t>& sizes, Tensor* output) {
  // Acquire MPI size and rank.
  int n, r;
  MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
  MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));

  assert(sizes.size() == n);
  assert(input->dim_size(0) == sizes[r]);

  // Compute the number of elements in every "row". We can't compute the
  // number of elements in each chunk, because the chunks have variable
  // lengths along the first dimension.
  size_t elements_per_row = 1;
  for (int i = 1; i < input->shape().dims(); i++) {
    elements_per_row *= input->dim_size(i);
  }

  // Compute the start offset of every rank's segment in the output tensor.
  std::vector<size_t> segment_starts(n);
  segment_starts[0] = 0;
  for (int i = 1; i < n; i++) {
    segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1];
  }
  size_t offset = segment_starts[r];

  // Copy the input data to the right offset for this rank.
  T* buffer = (T*)output->tensor_data().data();
  CopyTensorData<Device>((void*)(buffer + offset),
                         (void*)input->tensor_data().data(),
                         elements_per_row * sizes[r] * sizeof(T));

  // Receive from your left neighbor with wrap-around.
  const size_t recv_from = ((r - 1) + n) % n;

  // Send to your right neighbor with wrap-around.
  const size_t send_to = (r + 1) % n;

  // Perform a ring allgather. At every step, for every rank, we iterate
  // through segments with wraparound and send and recv from our neighbors.
  // At the i'th iteration, rank r sends segment (r-i) and receives segment
  // (r-i-1).
  MPI_Status recv_status;
  for (int i = 0; i < n - 1; ++i) {
    const size_t send_seg_id = ((r - i) + n) % n;
    const size_t recv_seg_id = ((r - i - 1) + n) % n;

    // Segment to send - at every iteration we send segment (r-i).
    size_t offset_send = segment_starts[send_seg_id];
    size_t rows_send = sizes[send_seg_id];
    T* segment_send = &(buffer[offset_send]);

    // Segment to recv - at every iteration we receive segment (r-i-1).
    size_t offset_recv = segment_starts[recv_seg_id];
    size_t rows_recv = sizes[recv_seg_id];
    T* segment_recv = &(buffer[offset_recv]);

    MPI_REQUIRES_OK(MPI_Sendrecv(
        segment_send, elements_per_row * rows_send, MPIType<T>(), send_to,
        TAG_TENSOR, segment_recv, elements_per_row * rows_recv, MPIType<T>(),
        recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
  }

  return Status::OK();
}

}  // namespace mpi_collectives
}  // namespace contrib
}  // namespace tensorflow

#endif  // TENSORFLOW_USE_MPI

#endif  // TENSORFLOW_CONTRIB_MPI_H_
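
// A minimal sketch (illustrative, not part of this header) of how the
// MPIType<T>() and TensorFlowDataType<T>() templates declared above might be
// specialized in a single translation unit; the actual set of supported types
// is determined by the accompanying implementation in this contrib directory.
//
//   namespace tensorflow {
//   namespace contrib {
//   namespace mpi_collectives {
//
//   template <>
//   MPI_Datatype MPIType<float>() { return MPI_FLOAT; }
//   template <>
//   MPI_Datatype MPIType<int>() { return MPI_INT; }
//   template <>
//   MPI_Datatype MPIType<long long>() { return MPI_LONG_LONG; }
//
//   template <>
//   DataType TensorFlowDataType<float>() { return DT_FLOAT; }
//   template <>
//   DataType TensorFlowDataType<int>() { return DT_INT32; }
//   template <>
//   DataType TensorFlowDataType<long long>() { return DT_INT64; }
//
//   }  // namespace mpi_collectives
//   }  // namespace contrib
//   }  // namespace tensorflow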