/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_MPI

#include <queue>
#include <thread>
#include <unordered_map>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/mutex.h"

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#include <cuda_runtime.h>
#include "tensorflow/stream_executor/stream.h"
#endif

#include "tensorflow/stream_executor/lib/statusor.h"

#define OMPI_SKIP_MPICXX
#include "third_party/mpi/mpi.h"
#include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h"

/*
 * MPI Allreduce and Allgather Ops for TensorFlow.
 *
 * TensorFlow natively provides inter-device communication through send and
 * receive ops and inter-node communication through Distributed TensorFlow,
 * based on the same send and receive abstractions. These end up being
 * insufficient for synchronous data-parallel training on HPC clusters where
 * InfiniBand or other high-speed interconnects are available. This module
 * implements MPI ops for allgather and allreduce, which do bandwidth-optimal
 * gathers and reductions and can take advantage of hardware-optimized
 * communication libraries through the MPI implementation.
 *
 * The primary logic of the allreduce and allgather is in RingAllgather() and
 * RingAllreduce(). The background thread which facilitates MPI operations is
 * run in BackgroundThreadLoop(). The provided MPI ops are:
 *      MPIInit:
 *          Initialize MPI on a given device (CPU or GPU).
 *          Should only be run on a single device in every process.
 *      MPISize:
 *          Get the number of MPI processes in the global communicator.
 *      MPIRank:
 *          Get the rank of the current MPI process in the global communicator.
 *      MPILocalRank:
 *          Get the local rank of the current MPI process within its node.
 *      MPIAllreduce:
 *          Perform an allreduce on a Tensor, returning the sum
 *          across all MPI processes in the global communicator.
 *      MPIAllgather:
 *          Perform an allgather on a Tensor, returning the concatenation of
 *          the tensor on the first dimension across all MPI processes in the
 *          global communicator.
 *
 */

template <class T>
using StatusOr = perftools::gputools::port::StatusOr<T>;

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace tensorflow {
namespace contrib {
namespace mpi_collectives {
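// As an illustrative, hedged sketch (pseudocode only, not part of this
// translation unit), a typical graph built against the ops listed above
// initializes MPI once per process and then reuses the global communicator
// for every collective:
//
//   // MPIInit()                          -- once per process, on CPU or GPU
//   // n = MPISize(); r = MPIRank()       -- shard input data by rank
//   // grad_sum = MPIAllreduce(grad)      -- sum across all ranks
//   // rows_all = MPIAllgather(rows, sizes)  -- concatenate along dim 0
//
// The names mirror the registered op names; the exact client-side API is the
// Python wrapper layer of this contrib package and is assumed rather than
// shown here.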
// Make sure template specializations are generated in the ring.cu.cc and the
// ring.cc file, not in this file.
extern template Status RingAllreduce<GPUDevice, int>(OpKernelContext*,
                                                     const Tensor*, Tensor*,
                                                     Tensor*);
extern template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
                                                           const Tensor*,
                                                           Tensor*, Tensor*);
extern template Status RingAllreduce<GPUDevice, float>(OpKernelContext*,
                                                       const Tensor*, Tensor*,
                                                       Tensor*);
extern template Status RingAllgather<GPUDevice, int>(OpKernelContext*,
                                                     const Tensor*,
                                                     const std::vector<size_t>&,
                                                     Tensor*);
extern template Status RingAllgather<GPUDevice, long long>(
    OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
extern template Status RingAllgather<GPUDevice, float>(
    OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
extern template Status RingAllreduce<CPUDevice, int>(OpKernelContext*,
                                                     const Tensor*, Tensor*,
                                                     Tensor*);
extern template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
                                                           const Tensor*,
                                                           Tensor*, Tensor*);
extern template Status RingAllreduce<CPUDevice, float>(OpKernelContext*,
                                                       const Tensor*, Tensor*,
                                                       Tensor*);
extern template Status RingAllgather<CPUDevice, int>(OpKernelContext*,
                                                     const Tensor*,
                                                     const std::vector<size_t>&,
                                                     Tensor*);
extern template Status RingAllgather<CPUDevice, long long>(
    OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
extern template Status RingAllgather<CPUDevice, float>(
    OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);

namespace {

// Return true if the templated type is GPUDevice, otherwise false.
template <typename T>
bool IsGPUDevice();
template <>
bool IsGPUDevice<GPUDevice>() {
  return true;
}
template <>
bool IsGPUDevice<CPUDevice>() {
  return false;
}

// A callback to call after the MPI communication completes. Since the
// allreduce and allgather ops are asynchronous, this callback is what resumes
// computation after the reduction is completed.
typedef std::function<void(StatusOr<Tensor>)> CommunicationDoneCallback;

struct CollectiveOpRecord {
  // The rank performing this piece of the op.
  int rank;

  // The name of the op/tensor to be reduced.
  std::string name;

  // The op's kernel context.
  OpKernelContext* context;

  // Data type of the op.
  DataType dtype;

  // The input tensor.
  const Tensor* in_t;

  // Allgather: Vector of per-rank first-dimension sizes.
  std::vector<size_t> sizes_vec;

  // The temp tensor for intermediate results.
  Tensor temp_t;

  // The output tensor.
  Tensor* out_t;

  // Whether to run this op on the GPU.
  bool on_gpu;

  // The callback to call after the op has completed.
  CommunicationDoneCallback callback;
};

// Table storing Tensors to be reduced, keyed by unique name.
// This table contains everything necessary to do the reduction.
typedef std::unordered_map<std::string, CollectiveOpRecord> TensorTable;

// Table for storing Tensor metadata on rank zero. This is used for error
// checking and size calculations, as well as determining when a reduction is
// ready to be done (when all nodes are ready to do it).
typedef std::unordered_map<std::string, std::vector<MPIRequest> > MessageTable;
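// A hedged illustration of how these tables are used (the authoritative flow
// is in EnqueueTensorCollective(), BackgroundThreadLoop(), and
// PerformCollectiveOp() below): each rank stores its CollectiveOpRecord in
// its local TensorTable when the op's ComputeAsync runs, e.g.
//
//   // CollectiveOpRecord record;
//   // record.name = "gradients/dense/MatMul_grad";  // hypothetical name
//   // record.in_t = &input;  record.out_t = output;
//   // tensor_table.emplace(record.name, record);    // under mpi_global.mu
//
// while rank zero additionally appends every received MPIRequest for that
// name to its MessageTable entry; once the entry holds one request per rank,
// the tensor is ready to be reduced or gathered.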
// The global state required for the MPI ops.
//
// MPI is a library that stores a lot of global per-program state and often
// requires running on a single thread. As a result, we have to have a single
// background thread responsible for all MPI operations, and communicate with
// that background thread through global state.
struct MPIGlobalState {
  // An atomic boolean which is set to true when MPI is initialized.
  // This ensures that MPI_Init is never called twice.
  std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;

  // Condition variable to wait for initialization.
  condition_variable cv;

  // Whether MPI_Init has been completed on the background thread.
  bool initialization_done = false;

  // Whether MPI_Init succeeded on the background thread.
  Status init_status;

  // A mutex that needs to be used whenever MPI operations touch
  // shared structures.
  mutex mu;

  // Tensors waiting to be allreduced or allgathered.
  TensorTable tensor_table;

  // Queue of MPI requests waiting to be sent to the coordinator node.
  std::queue<MPIRequest> message_queue;

  // Background thread running MPI communication.
  std::thread background_thread;

  // Whether the background thread should shut down.
  bool shut_down = false;

  // Only exists on the coordinator node (rank zero). Maintains a count of
  // how many nodes are ready to allreduce every tensor (keyed by tensor
  // name).
  std::unique_ptr<MessageTable> message_table;

  // The MPI rank, local rank, and size.
  int rank = 0;
  int local_rank = 0;
  int size = 1;

  // The device that MPI was initialized on. (-1 for no GPU.)
  int device = -1;

  // The CUDA stream used for data transfers and within-allreduce operations.
  // A naive implementation would use the TensorFlow StreamExecutor CUDA
  // stream. However, the allreduce and allgather require doing memory copies
  // and kernel executions (for accumulation of values on the GPU), and the
  // subsequent MPI operations must wait for those to complete; otherwise MPI
  // (which uses its own stream internally) will begin the data transfers
  // before the CUDA calls are complete. In order to wait for those CUDA
  // operations, if we were using the TensorFlow stream, we would have to
  // synchronize that stream; however, other TensorFlow threads may be
  // submitting more work to that stream, so synchronizing on it can cause
  // the allreduce to be delayed, waiting for compute totally unrelated to it
  // in other parts of the graph. Overlapping memory transfers and compute
  // during backpropagation is crucial for good performance, so we cannot use
  // the TensorFlow stream, and must use our own stream.
#if GOOGLE_CUDA
  cudaStream_t stream;
  std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT;
#endif

  ~MPIGlobalState() {
    // Make sure that the destructor of the background thread is safe to
    // call. If a thread is still joinable (not detached or complete) its
    // destructor cannot be called.
    if (background_thread.joinable()) {
      shut_down = true;
      background_thread.join();
    }
  }
};

// All the MPI state that must be stored globally per-process.
static MPIGlobalState mpi_global;
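// A hedged sketch of the initialization handshake on this global state (the
// real logic lives in InitializeMPIOnce() and BackgroundThreadLoop() below):
// the first caller wins the atomic flag, spawns the background thread, and
// then blocks on the condition variable until MPI_Init has finished there:
//
//   // if (!mpi_global.initialized_flag.test_and_set()) {
//   //   mpi_global.background_thread = std::thread(BackgroundThreadLoop);
//   // }
//   // mutex_lock guard(mpi_global.mu);
//   // mpi_global.cv.wait(guard);       // woken once MPI_Init completes
//   // return mpi_global.init_status;
//
// Every later caller sees the flag already set and simply returns the cached
// init_status.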
// For clarity in argument lists.
#define RANK_ZERO 0

// A tag used for all coordinator messaging.
#define TAG_NOTIFY 1

// Store the MPIRequest for a name, and return whether the total count of
// MPIRequests for that tensor is now equal to the MPI size (and thus we are
// ready to reduce the tensor).
bool IncrementTensorCount(std::unique_ptr<MessageTable>& message_table,
                          MPIRequest msg, int mpi_size) {
  auto name = msg.tensor_name();
  auto table_iter = message_table->find(name);
  if (table_iter == message_table->end()) {
    message_table->emplace(name, std::vector<MPIRequest>({msg}));
    table_iter = message_table->find(name);
  } else {
    table_iter->second.push_back(msg);
  }

  int count = table_iter->second.size();
  return count == mpi_size;
}
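// Worked example (hedged; the request names are hypothetical): with an MPI
// size of 3, the coordinator calls IncrementTensorCount once per rank for a
// given tensor name, and only the final call reports readiness:
//
//   // IncrementTensorCount(table, request_from_rank_0, 3);  // false (1/3)
//   // IncrementTensorCount(table, request_from_rank_1, 3);  // false (2/3)
//   // IncrementTensorCount(table, request_from_rank_2, 3);  // true  (3/3)
//
// At that point ConstructMPIResponse() below validates the queued requests
// and erases the entry from the table.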
// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse
// to all ranks instructing them to start the reduction. The MPIResponse also
// contains error messages in case the submitted MPIRequests were not valid
// (for example, contained mismatched shapes or types).
//
// Constructing the MPIResponse thus requires a significant amount of error
// checking.
MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
                                 std::string name) {
  bool error = false;
  auto it = message_table->find(name);
  assert(it != message_table->end());

  std::vector<MPIRequest> requests = it->second;
  assert(requests.size() > 0);

  std::ostringstream error_message_stream;

  // Check that all data types being reduced or gathered are identical.
  auto data_type = requests[0].tensor_type();
  for (unsigned int i = 1; i < requests.size(); i++) {
    auto request_type = requests[i].tensor_type();
    if (data_type != request_type) {
      error = true;
      error_message_stream << "Mismatched data types: One rank had type "
                           << DataType_Name(data_type)
                           << ", but another rank had type "
                           << DataType_Name(request_type) << ".";
      break;
    }
  }

  // Check that all requested operations are the same.
  auto message_type = requests[0].request_type();
  for (unsigned int i = 1; i < requests.size(); i++) {
    if (error) {
      break;
    }

    auto request_type = requests[i].request_type();
    if (message_type != request_type) {
      error = true;
      error_message_stream << "Mismatched MPI operations: One rank did an "
                           << message_type << ", but another rank did an "
                           << request_type << ".";
      break;
    }
  }

  // If we are doing an allreduce, check that all tensor shapes are
  // identical.
  if (message_type == MPIRequest::ALLREDUCE) {
    TensorShape tensor_shape = requests[0].tensor_shape();
    for (unsigned int i = 1; i < requests.size(); i++) {
      if (error) {
        break;
      }

      TensorShape request_shape = requests[i].tensor_shape();
      if (tensor_shape != request_shape) {
        error = true;
        error_message_stream << "Mismatched allreduce tensor shapes: "
                             << "One rank reduced a tensor of shape "
                             << tensor_shape.DebugString()
                             << ", but another rank sent a tensor of shape "
                             << request_shape.DebugString() << ".";
        break;
      }
    }
  }

  // If we are doing an allgather, make sure all but the first dimension are
  // the same. The first dimension may differ; the output tensor's first
  // dimension is the sum of the per-rank first dimensions. Collect the sizes
  // by rank.
  if (message_type == MPIRequest::ALLGATHER) {
    TensorShape tensor_shape = requests[0].tensor_shape();

    if (tensor_shape.dims() == 0) {
      error = true;
      error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
    }

    for (unsigned int i = 1; i < requests.size(); i++) {
      if (error) {
        break;
      }

      TensorShape request_shape = requests[i].tensor_shape();
      if (tensor_shape.dims() != request_shape.dims()) {
        error = true;
        error_message_stream << "Mismatched allgather tensor shapes: "
                             << "One rank gathered a tensor of rank "
                             << tensor_shape.dims()
                             << ", but another rank sent a tensor of rank "
                             << request_shape.dims() << ".";
        break;
      }

      for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) {
        if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
          error = true;
          error_message_stream
              << "Mismatched allgather tensor shapes: "
              << "One rank gathered a tensor with dimension " << dim
              << " equal to " << tensor_shape.dim_size(dim)
              << ", but another rank sent a tensor with dimension " << dim
              << " equal to " << request_shape.dim_size(dim) << ".";
          break;
        }
      }
    }
  }

  MPIResponse response;
  response.set_tensor_name(name);
  if (error) {
    std::string error_message = error_message_stream.str();
    response.set_response_type(MPIResponse::ERROR);
    response.set_error_message(error_message);
  } else {
    auto response_type = MPIResponse::ERROR;
    if (message_type == MPIRequest::ALLREDUCE) {
      response_type = MPIResponse::ALLREDUCE;
    } else {
      response_type = MPIResponse::ALLGATHER;
    }
    response.set_response_type(response_type);
  }

  // Clear all queued up requests for this name. They are now taken care of
  // by the constructed MPI response.
  message_table->erase(it);

  return response;
}

// Process an MPIResponse by doing a reduction, a gather, or raising an error.
void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
  OpKernelContext* context;
  const Tensor* input_tensor;
  std::vector<size_t> sizes_vec;
  Tensor temp_tensor;
  Tensor* output_tensor;
  CommunicationDoneCallback callback;
  bool on_gpu;
  {
    // Lock on the tensor table.
    mutex_lock guard(mpi_global.mu);

    // We should never fail at finding this key in the tensor table.
    auto name = response.tensor_name();
    auto iter = tensor_table.find(name);
    assert(iter != tensor_table.end());

    assert(response.response_type() == MPIResponse::ALLREDUCE ||
           response.response_type() == MPIResponse::ALLGATHER ||
           response.response_type() == MPIResponse::ERROR);

    CollectiveOpRecord record = iter->second;
    context = record.context;
    input_tensor = record.in_t;
    sizes_vec = record.sizes_vec;
    temp_tensor = record.temp_t;
    output_tensor = record.out_t;
    on_gpu = record.on_gpu;
    callback = record.callback;

    // Clear the tensor table of this tensor and its callbacks; the rest of
    // this function takes care of it.
    tensor_table.erase(iter);
  }

  // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
  // link to non-existent symbols.
#if GOOGLE_CUDA
#define GPU_DEVICE_IF_CUDA GPUDevice
#else
#define GPU_DEVICE_IF_CUDA CPUDevice
#endif

  Status status;
  auto dtype = input_tensor->dtype();
  if (response.response_type() == MPIResponse::ALLGATHER) {
    if (dtype == DT_FLOAT) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, float>(
                            context, input_tensor, sizes_vec, output_tensor);
    } else if (dtype == DT_INT32) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, int>(context, input_tensor,
                                                      sizes_vec, output_tensor);
    } else if (dtype == DT_INT64) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, long long>(
                            context, input_tensor, sizes_vec, output_tensor);
    } else {
      status = errors::Unknown("Invalid tensor type for MPI allgather.");
    }
  } else if (response.response_type() == MPIResponse::ALLREDUCE) {
    if (dtype == DT_FLOAT) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, float>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else if (dtype == DT_INT32) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, int>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else if (dtype == DT_INT64) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, long long>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else {
      status = errors::Unknown("Invalid tensor type for MPI allreduce.");
    }
  } else if (response.response_type() == MPIResponse::ERROR) {
    status = errors::FailedPrecondition(response.error_message());
  }

  if (status.ok()) {
    callback(StatusOr<Tensor>(*output_tensor));
  } else {
    callback(StatusOr<Tensor>(status));
  }
}

// The MPI background thread loop coordinates all the MPI processes and the
// tensor reductions. The design of the communicator mechanism is limited by a
// few considerations:
//
//      1. Some MPI implementations require all MPI calls to happen from a
//      single thread. Since TensorFlow may use several threads for graph
//      processing, this means we must have our own dedicated thread for
//      dealing with MPI.
//      2. We want to gracefully handle errors when the MPI processes do not
//      agree on what should happen (such as mismatched types or shapes).
//      Doing so requires the MPI processes to know about the shapes and
//      types of the relevant tensors on the other processes.
//      3. The MPI reductions and gathers should be able to happen in parallel
//      with other ongoing operations. Since MPI uses an internal
//      (inaccessible) GPU stream separate from the TF GPUDevice streams, we
//      cannot explicitly synchronize memcpys or kernels with it. As a result,
//      MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper
//      ordering of memcpys and kernels with respect to TF streams.
//      4. NOTE: We cannot guarantee that all the MPI processes reduce their
//      tensors in the same order. Thus, there must be a way to ensure the
//      reduction memcpys and kernels occur for the correct tensors across all
//      ranks at the same time. We choose to use a coordinator (rank ID 0) to
//      gather and trigger the reduction operations that are ready to execute.
//
// The coordinator currently follows a master-worker paradigm. Rank zero acts
// as the master (the "coordinator"), whereas all other ranks are simply
// workers. Each rank runs its own background thread which progresses in
// ticks. In each tick, the following actions happen:
//
//      a) The workers send any available MPIRequests to the coordinator.
//      These MPIRequests indicate what the worker would like to do (i.e.
//      which tensor they would like to gather or reduce, as well as its shape
//      and type). They repeat this for every tensor that they would like to
//      operate on after that tensor's collective op has executed
//      ComputeAsync.
//
//      b) The workers send an empty "DONE" message to the coordinator to
//      indicate that there are no more tensors they wish to operate on.
//
//      c) The coordinator receives the MPIRequests from the workers, as well
//      as from its own TensorFlow ops, and stores them in a request table.
//      The coordinator continues to receive MPIRequest messages until it has
//      received MPI_SIZE empty "DONE" messages.
//
//      d) The coordinator finds all tensors that are ready to be reduced or
//      gathered, as well as any operations that resulted in an error. For
//      each of those, it sends an MPIResponse to all the workers. When no
//      more MPIResponses are available, it sends a "DONE" response to the
//      workers. If the process is being shut down, it instead sends a
//      "SHUTDOWN" response.
//
//      e) The workers listen for MPIResponse messages, processing each one by
//      doing the required reduce or gather, until they receive a "DONE"
//      response from the coordinator. At that point, the tick ends.
//      If instead of "DONE" they receive "SHUTDOWN", they exit their
//      background loop.
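// A hedged sketch of the wire convention used below (the authoritative code
// is in BackgroundThreadLoop()): both MPIRequest and MPIResponse protos are
// sent as serialized strings on TAG_NOTIFY, and a zero-length message is the
// "DONE" marker. A worker's side of one tick therefore looks roughly like:
//
//   // MPI_Send(encoded_request.c_str(), encoded_request.length() + 1,
//   //          MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);  // requests
//   // MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
//   // while (true) {                     // then drain responses
//   //   MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
//   //   MPI_Get_count(&status, MPI_BYTE, &msg_length);
//   //   ...  // DONE/SHUTDOWN ends the tick; otherwise PerformCollectiveOp
//   // }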
// TODO: Use the global mpi_global state variable instead of a local one
void BackgroundThreadLoop() {
#if GOOGLE_CUDA
  // Set the device, so that this thread uses the same GPU context as the
  // calling thread.
  // TODO: Ensure that this is operating correctly. The background thread
  // needs to be able to control all GPUs that the rank has access to, which
  // might be more than one GPU. Tensors could be resident in any of the
  // GPUs, so the background thread's accumulate and copy kernels might need
  // to correctly set the device, and it might be necessary for the background
  // thread to manage multiple streams.
  cudaSetDevice(mpi_global.device);
  cudaStreamCreate(&mpi_global.stream);
#endif

  // Initialize MPI. This must happen on the background thread, since not all
  // MPI implementations support being called from multiple threads.
  auto init_result = MPI_Init(NULL, NULL);
  if (init_result != MPI_SUCCESS) {
    mpi_global.init_status =
        errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
    mpi_global.initialization_done = true;
    mpi_global.cv.notify_all();
    return;
  } else {
    mpi_global.init_status = Status::OK();
  }

  // Get MPI rank to determine if we are rank zero.
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  bool is_coordinator = rank == 0;

  // Get MPI size to determine how many tensors to wait for before reducing.
  int size;
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  // Determine local rank by querying the local communicator.
  MPI_Comm local_comm;
  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
                      &local_comm);
  int local_rank;
  MPI_Comm_rank(local_comm, &local_rank);

  mpi_global.rank = rank;
  mpi_global.local_rank = local_rank;
  mpi_global.size = size;
  mpi_global.initialization_done = true;

  // Notify the calling thread that initialization is complete.
  mpi_global.cv.notify_all();

  // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
  // Initialize the tensor count table. No tensors are available yet.
  if (is_coordinator) {
    mpi_global.message_table =
        std::unique_ptr<MessageTable>(new MessageTable());
  }

  // The coordinator sends a SHUTDOWN message to trigger shutdown.
  bool should_shut_down = false;
  do {
    // TODO: Eliminate the need for thread sleep by making all activity
    // depend on other activity (e.g. condition or MPI waits).
    std::this_thread::sleep_for(std::chrono::milliseconds(1));

    // Copy the data structures from global state under this lock.
    // However, don't keep the lock for the rest of the loop, so that
    // enqueued stream callbacks can continue.
    std::queue<MPIRequest> message_queue;
    {
      mutex_lock guard(mpi_global.mu);
      while (!mpi_global.message_queue.empty()) {
        MPIRequest message = mpi_global.message_queue.front();
        mpi_global.message_queue.pop();
        message_queue.push(message);
      }
    }

    // Collect all tensors that are ready to be reduced. Record them in the
    // tensor count table (rank zero) or send them to rank zero to be
    // recorded (everyone else).
    std::vector<std::string> ready_to_reduce;
    while (!message_queue.empty()) {
      // Pop the first available message.
      MPIRequest message = message_queue.front();
      message_queue.pop();

      if (is_coordinator) {
        bool reduce =
            IncrementTensorCount(mpi_global.message_table, message, size);
        if (reduce) {
          ready_to_reduce.push_back(message.tensor_name());
        }
      } else {
        std::string encoded_message;
        message.SerializeToString(&encoded_message);
        MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
                 MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
      }
    }

    // Rank zero has put all its own tensors in the tensor count table.
    // Now, it should count all the tensors that are coming from other
    // ranks at this tick. It should keep getting tensors until it gets a
    // DONE message from all the other ranks.
    if (is_coordinator) {
      // Count of DONE messages. Keep receiving messages until the number
      // of messages is equal to the number of processes. Initialize to
      // one since the coordinator is effectively done.
      int completed_ranks = 1;
      while (completed_ranks != size) {
        MPI_Status status;
        MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);

        // Find number of characters in message (including zero byte).
        int source_rank = status.MPI_SOURCE;
        int msg_length;
        MPI_Get_count(&status, MPI_BYTE, &msg_length);

        // If the length is zero, this is a DONE message.
        if (msg_length == 0) {
          completed_ranks++;
          MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
                   &status);
          continue;
        }

        // Get tensor name from MPI into an std::string.
        char* buffer = new char[msg_length];
        MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
                 MPI_COMM_WORLD, &status);
        std::string received_data(buffer);
        delete[] buffer;

        MPIRequest received_message;
        received_message.ParseFromString(received_data);
        auto received_name = received_message.tensor_name();

        bool reduce = IncrementTensorCount(mpi_global.message_table,
                                           received_message, size);
        if (reduce) {
          ready_to_reduce.push_back(received_name);
        }
      }

      // At this point, rank zero should have a fully updated tensor
      // count table and should know all the tensors that need to be
      // reduced or gathered, and everyone else should have sent all
      // their information to rank zero. We can now do reductions and
      // gathers; rank zero will choose which ones and in what order,
      // and will notify the other ranks before doing each reduction.
      for (int i = 0; i < ready_to_reduce.size(); i++) {
        // Notify all nodes which tensor we'd like to reduce now.
        auto name = ready_to_reduce[i];
        MPIResponse response =
            ConstructMPIResponse(mpi_global.message_table, name);

        std::string encoded_response;
        response.SerializeToString(&encoded_response);
        for (int r = 1; r < size; r++) {
          MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
                   MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
        }

        // Perform the reduction. All nodes should end up performing
        // the same reduction.
        PerformCollectiveOp(mpi_global.tensor_table, response);
      }

      // Notify all nodes that we are done with the reductions for this
      // tick.
      MPIResponse done_response;
      should_shut_down = mpi_global.shut_down;
      done_response.set_response_type(
          mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
      std::string encoded_response;
      done_response.SerializeToString(&encoded_response);
      for (int r = 1; r < size; r++) {
        MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
                 MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
      }
    } else {
      // Notify the coordinator that this node is done sending messages.
      // A DONE message is encoded as a zero-length message.
      MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);

      // Receive names for tensors to reduce from rank zero. Once we
      // receive an empty DONE message, stop waiting for more names.
      while (true) {
        MPI_Status status;
        MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);

        // Find number of characters in message (including zero byte).
        int msg_length;
        MPI_Get_count(&status, MPI_BYTE, &msg_length);

        // Get tensor name from MPI into an std::string.
        char* buffer = new char[msg_length];
        MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
                 &status);
        std::string received_message(buffer);
        delete[] buffer;

        MPIResponse response;
        response.ParseFromString(received_message);
        if (response.response_type() == MPIResponse::DONE) {
          // No more messages this tick.
          break;
        } else if (response.response_type() == MPIResponse::SHUTDOWN) {
          // No more messages this tick, and the background thread
          // should shut down.
          should_shut_down = true;
          break;
        } else {
          // Process the current message.
          PerformCollectiveOp(mpi_global.tensor_table, response);
        }
      }
    }
  } while (!should_shut_down);

  MPI_Finalize();
}

// Initialize MPI and start the MPI background thread. Ensure that this is
// only done once no matter how many times this function is called.
Status InitializeMPIOnce(bool gpu) {
  // Ensure MPI is only initialized once.
  if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status;

  mpi_global.device = -1;
#if GOOGLE_CUDA
  if (gpu) {
    cudaGetDevice(&mpi_global.device);
  }
#endif

  // Start the MPI background thread, which assumes MPI is initialized.
  // TODO: Change this to a TensorFlow thread.
  mpi_global.background_thread = std::thread(BackgroundThreadLoop);

  // Wait to ensure that the background thread has finished initializing MPI.
  mutex_lock guard(mpi_global.mu);
  mpi_global.cv.wait(guard);
  if (!mpi_global.initialization_done) {
    mpi_global.init_status =
        errors::Unknown("Failed to wait for MPI initialization.");
  }

  return mpi_global.init_status;
}

// Check that MPI is initialized.
Status IsMPIInitialized() {
  if (!mpi_global.initialization_done) {
    return errors::FailedPrecondition(
        "MPI has not been initialized; use tf.contrib.mpi.Session.");
  }
  return Status::OK();
}

// This function (called from the callback set up in MPIAll*Op::ComputeAsync)
// only adds the op's record into the local op queue (to track the op's
// progress) and enqueues a message for the coordinator indicating that this
// rank is ready to begin. The MPI background thread handles actually sending
// the MPI message.
void EnqueueTensorCollective(CollectiveOpRecord record,
                             MPIRequest::RequestType rtype) {
  const Tensor* input_tensor = record.in_t;
  MPIRequest message;
  message.set_request_rank(record.rank);
  message.set_tensor_name(record.name);
  message.set_tensor_type(record.dtype);
  message.set_request_type(rtype);
  input_tensor->shape().AsProto(message.mutable_tensor_shape());

  mutex_lock guard(mpi_global.mu);
  mpi_global.tensor_table.emplace(record.name, record);
  mpi_global.message_queue.push(message);
}

}  // namespace

#if GOOGLE_CUDA
cudaStream_t CudaStreamForMPI() { return mpi_global.stream; }
#endif
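// CudaStreamForMPI() exposes the dedicated stream discussed in
// MPIGlobalState; the ring implementation is expected to launch its copy and
// accumulation work on it so MPI transfers can be ordered behind that work
// without touching the TensorFlow compute stream. A hedged sketch of a
// consumer (kernel name hypothetical; the actual consumers are the ring
// kernels built from ring.h):
//
//   // cudaStream_t s = CudaStreamForMPI();
//   // accumulate_kernel<<<blocks, threads, 0, s>>>(dst, src, n);
//   // cudaStreamSynchronize(s);  // safe to hand the buffer to MPI now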
// Op to initialize MPI in the current process. The configuration used here
// (such as the device) must match the configuration used by all future MPI
// ops.
template <typename Device>
class MPIInitOp : public OpKernel {
 public:
  explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    bool on_gpu = IsGPUDevice<Device>();
    OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu));
  }
};

REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU),
                        MPIInitOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU),
                        MPIInitOp<GPUDevice>);
#endif

// Op to get the current MPI Size.
template <typename Device>
class MPISizeOp : public OpKernel {
 public:
  explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES_OK(context, IsMPIInitialized());

    // Write integer to output tensor.
    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));

    auto flat = output->flat<int>();
    flat(0) = mpi_global.size;
  }
};

REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU),
                        MPISizeOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"),
                        MPISizeOp<GPUDevice>);
#endif

// Op to get the current MPI Rank.
template <typename Device>
class MPIRankOp : public OpKernel {
 public:
  explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES_OK(context, IsMPIInitialized());

    // Write integer to output tensor.
    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));

    auto flat = output->flat<int>();
    flat(0) = mpi_global.rank;
  }
};

REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU),
                        MPIRankOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"),
                        MPIRankOp<GPUDevice>);
#endif

// Op to get the current local MPI Rank.
template <typename Device>
class MPILocalRankOp : public OpKernel {
 public:
  explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES_OK(context, IsMPIInitialized());

    // Write integer to output tensor.
    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));

    auto flat = output->flat<int>();
    flat(0) = mpi_global.local_rank;
  }
};

REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU),
                        MPILocalRankOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"),
    MPILocalRankOp<GPUDevice>);
#endif

template <typename Device>
class MPIAllreduceOp : public AsyncOpKernel {
 public:
  explicit MPIAllreduceOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  // Although this op is handled asynchronously, the ComputeAsync call is
  // very inexpensive. It only sets up a CollectiveOpRecord and places it
  // in the table for the background thread to handle. Thus, we do not need
  // a TF pool thread to perform the op.
  bool IsExpensive() override { return false; }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
    const Tensor* input_tensor = &context->input(0);
    Tensor* output_tensor;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(0, input_tensor->shape(), &output_tensor),
        done);

    // The record is allocated on the stack so the op can fail without
    // leaking memory.
    CollectiveOpRecord record;
    record.name = name();
    record.context = context;
    record.in_t = input_tensor;
    record.out_t = output_tensor;
    record.on_gpu = IsGPUDevice<Device>();
    record.dtype = input_tensor->dtype();

    // The ring allreduce needs a temporary buffer large enough to hold one
    // chunk of the input, i.e. ceil(NumElements / size) elements.
    const size_t temp_size =
        (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size;
    TensorShape temp_shape;
    temp_shape.AddDim(temp_size);
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_temp(input_tensor->dtype(),
                                                temp_shape, &record.temp_t),
                         done);

    auto allreduce_done_callback = [done, context](StatusOr<Tensor> status) {
      context->SetStatus(status.status());
      done();
    };
    record.callback = allreduce_done_callback;

    auto allreduce_launch_callback = [record] {
      EnqueueTensorCollective(record, MPIRequest::ALLREDUCE);
    };

    // If we are on a CPU, our device context will be null and we can't
    // get a stream to enqueue this on. On a CPU this op is called when the
    // data is already available, so we can just immediately do the
    // allreduce; we don't have to wait for the data to get populated.
#if GOOGLE_CUDA
    auto device_context = context->op_device_context();
    if (device_context == nullptr) {
      allreduce_launch_callback();
    } else {
      auto stream = device_context->stream();
      stream->ThenDoHostCallback(allreduce_launch_callback);
    }
#else
    allreduce_launch_callback();
#endif
  }
};

REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU),
                        MPIAllreduceOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU),
                        MPIAllreduceOp<GPUDevice>);
#endif
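// Worked example of the temporary-buffer sizing in MPIAllreduceOp above
// (hedged; the numbers are illustrative only): with NumElements() == 10 and
// an MPI size of 4, temp_size = (10 + 4 - 1) / 4 = 3, i.e. one ring chunk of
// ceil(10 / 4) elements, the largest block any single rank accumulates at a
// time.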
template <typename Device>
class MPIAllgatherOp : public AsyncOpKernel {
 public:
  explicit MPIAllgatherOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  // Although this op is handled asynchronously, the ComputeAsync call is
  // very inexpensive. It only sets up a CollectiveOpRecord and places it
  // in the table for the background thread to handle. Thus, we do not need
  // a TF pool thread to perform the op.
  bool IsExpensive() override { return false; }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
    const Tensor* input_tensor = &context->input(0);
    const Tensor* sizing_tensor = &context->input(1);

    // The record is allocated on the stack so the op can fail without
    // leaking memory.
    CollectiveOpRecord record;
    record.name = name();
    record.context = context;
    record.in_t = input_tensor;
    record.on_gpu = IsGPUDevice<Device>();

    // Construct the output size from the sizing tensor.
    size_t output_first_dim = 0;
    if (sizing_tensor->shape().dims() == 0) {
      // A 0-dim sizing_tensor implies that the op is just gathering
      // a single element from each rank.
      output_first_dim = mpi_global.size;
      for (int i = 0; i < mpi_global.size; i++) {
        record.sizes_vec.push_back(1);
      }
    } else {
      // Collect the total output tensor sizing from the sizing tensor.
      // NOTE: The sizing tensor is forced to be placed on the CPU by
      // declaring the input as HostMemory, so it is valid to read it here.
      const int64* sizing_array =
          (const int64*)sizing_tensor->tensor_data().data();
      for (int i = 0; i < mpi_global.size; i++) {
        record.sizes_vec.push_back(sizing_array[i]);
        output_first_dim += sizing_array[i];
      }
    }

    TensorShape output_shape;
    output_shape.AddDim(output_first_dim);
    for (int i = 1; i < input_tensor->shape().dims(); i++) {
      output_shape.AddDim(input_tensor->shape().dim_size(i));
    }

    Tensor* output_tensor;
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(0, output_shape, &output_tensor),
        done);

    record.out_t = output_tensor;
    record.dtype = input_tensor->dtype();

    auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
      context->SetStatus(status.status());
      done();
    };
    record.callback = allgather_done_callback;

    auto allgather_launch_callback = [record] {
      EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
    };

    // If we are on a CPU, our device context will be null and we can't
    // get a stream to enqueue this on. On a CPU this op is called when the
    // data is already available, so we can just immediately do the
    // allgather; we don't have to wait for the data to get populated.
#if GOOGLE_CUDA
    auto device_context = context->op_device_context();
    if (device_context == nullptr) {
      allgather_launch_callback();
    } else {
      auto stream = device_context->stream();
      stream->ThenDoHostCallback(allgather_launch_callback);
    }
#else
    allgather_launch_callback();
#endif
  }
};

REGISTER_KERNEL_BUILDER(
    Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"),
    MPIAllgatherOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"),
    MPIAllgatherOp<GPUDevice>);
#endif

}  // namespace mpi_collectives
}  // namespace contrib
}  // namespace tensorflow

#endif  // TENSORFLOW_USE_MPI