# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Inter-process communication using MPI."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.contrib.mpi_collectives.ops import gen_mpi_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader

# Load the `_mpi_ops.so` custom-op library located via resource_loader
# relative to this package.
_mpi_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile('_mpi_ops.so'))


def size(name=None):
  """An op which returns the number of MPI processes.

  This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the
  size of the global communicator.

  Args:
    name: Optional name for the created op.

  Returns:
    An integer scalar containing the number of MPI processes.
  """
  return gen_mpi_ops.mpi_size(name=name)


# Register the op type as having no gradient.
ops.NotDifferentiable('MPISize')


def rank(name=None):
  """An op which returns the MPI rank of the calling process.

  This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the
  rank of the current process in the global communicator.

  Args:
    name: Optional name for the created op.

  Returns:
    An integer scalar with the MPI rank of the calling process.
  """
  return gen_mpi_ops.mpi_rank(name=name)


# Register the op type as having no gradient.
ops.NotDifferentiable('MPIRank')


def init(name=None):
  """An op which initializes MPI on the device on which it is run.

  All future MPI ops must be run on the same device that the `init` op was run
  on.

  Args:
    name: Optional name for the created op.

  Returns:
    The op returned by `gen_mpi_ops.mpi_init`.
  """
  return gen_mpi_ops.mpi_init(name=name)


# Register the op type as having no gradient.
ops.NotDifferentiable('MPIInit')


def local_rank(name=None):
  """An op which returns the local MPI rank of the calling process, within the
  node that it is running on. For example, if there are seven processes running
  on a node, their local ranks will be zero through six, inclusive.

  This is equivalent to running `MPI_Comm_rank(...)` on a new communicator
  which only includes processes on the same node.

  Args:
    name: Optional name for the created op.

  Returns:
    An integer scalar with the local MPI rank of the calling process.
  """
  return gen_mpi_ops.mpi_local_rank(name=name)


# Register the op type as having no gradient.
ops.NotDifferentiable('MPILocalRank')


def _allreduce(tensor, name=None):
  """An op which sums an input tensor over all the MPI processes.

  The reduction operation is keyed by the name of the op. The tensor type and
  shape must be the same on all MPI processes for a given name. The reduction
  will not start until all processes are ready to send and receive the tensor.

  Args:
    tensor: The tensor to sum across processes.
    name: Optional name for the created op; it is the key that matches this
      reduction across processes.

  Returns:
    A tensor of the same shape and type as `tensor`, summed across all
    processes.
  """
  return gen_mpi_ops.mpi_allreduce(tensor, name=name)


# Register the op type as having no gradient.
ops.NotDifferentiable('MPIAllreduce')


def allgather(tensor, name=None):
  """An op which concatenates the input tensor with the same input tensor on
  all other MPI processes.

  The concatenation is done on the first dimension, so the input tensors on the
  different processes must have the same rank and shape, except for the first
  dimension, which is allowed to be different.

  Two MPI allgather ops are created: one named `'<name>_sizing'` that exchanges
  the first-dimension size of each process's tensor, and one named `name` that
  gathers the data itself.

  Args:
    tensor: The tensor (or a value convertible to a tensor) to gather. It must
      have rank of at least 1.
    name: Optional name for the gather op. Defaults to `'allgather'`.

  Returns:
    A tensor of the same type as `tensor`, concatenated on dimension zero
    across all processes. The shape is identical to the input shape, except for
    the first dimension, which may be greater and is the sum of all first
    dimensions of the tensors in different MPI processes.

  Raises:
    ValueError: If `tensor` is statically known to be a scalar, since there is
      no first dimension to concatenate along.
  """
  # Convert once so a non-Tensor input is not materialized twice (once by
  # `tf.shape` and again by the gather op).
  tensor = tf.convert_to_tensor(tensor)
  if tensor.shape.ndims == 0:
    raise ValueError(
        'allgather requires a tensor of rank >= 1, got a scalar: %s' % tensor)

  # Specify that first allgather is to collect the tensor gather sizes,
  # indicated by passing in a scalar (0-D tensor) of value 0
  sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const')
  # 1-element vector holding this process's first-dimension size.
  my_size = tf.slice(
      tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice')
  if name is None:
    name = 'allgather'
  sizing_name = '{}_sizing'.format(name)
  sizes = gen_mpi_ops.mpi_allgather(my_size, sizes_flag, name=sizing_name)
  return gen_mpi_ops.mpi_allgather(tensor, sizes, name=name)


# Register the op type as having no gradient.
ops.NotDifferentiable('MPIAllgather')