# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Inter-process communication using MPI."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.contrib.mpi_collectives.ops import gen_mpi_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader

_mpi_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile('_mpi_ops.so'))


def size(name=None):
  """An op which returns the number of MPI processes.

  This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the
  size of the global communicator.

  Returns:
    An integer scalar containing the number of MPI processes.
  """
  return gen_mpi_ops.mpi_size(name=name)


ops.NotDifferentiable('MPISize')
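

# Illustrative usage sketch (hypothetical helper, not part of the public API):
# scale a base learning rate by the number of MPI processes, a common pattern
# when the effective batch size grows with the worker count. Assumes the
# program is launched under an MPI launcher such as `mpirun -np <N>`.
def _example_scaled_learning_rate(base_lr):
  return base_lr * tf.cast(size(), tf.float32)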


def rank(name=None):
  """An op which returns the MPI rank of the calling process.

  This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the
  rank of the current process in the global communicator.

  Returns:
    An integer scalar with the MPI rank of the calling process.
  """
  return gen_mpi_ops.mpi_rank(name=name)


ops.NotDifferentiable('MPIRank')
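

# Illustrative usage sketch (hypothetical helper, not part of the public API):
# gate side effects such as logging or checkpointing on the root process. The
# comparison yields a boolean scalar tensor that can feed `tf.cond` or be
# evaluated directly.
def _example_is_root():
  return tf.equal(rank(), 0)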


def init(name=None):
  """An op which initializes MPI on the device on which it is run.

  All future MPI ops must be run on the same device that the `init` op was run
  on.
  """
  return gen_mpi_ops.mpi_init(name=name)


ops.NotDifferentiable('MPIInit')
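

# Illustrative usage sketch (hypothetical helper, not part of the public API):
# run `init` once, pinned to the device that every subsequent MPI op will use.
# '/cpu:0' is an assumed choice here; a GPU device string works the same way.
def _example_init_on_cpu():
  with tf.device('/cpu:0'):
    return init()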


def local_rank(name=None):
  """An op which returns the local MPI rank of the calling process, within the
  node that it is running on. For example, if there are seven processes running
  on a node, their local ranks will be zero through six, inclusive.

  This is equivalent to running `MPI_Comm_rank(...)` on a new communicator
  which only includes processes on the same node.

  Returns:
    An integer scalar with the local MPI rank of the calling process.
  """
  return gen_mpi_ops.mpi_local_rank(name=name)


ops.NotDifferentiable('MPILocalRank')
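

# Illustrative usage sketch (hypothetical helper, not part of the public API):
# map each process on a node to its own GPU by using the local rank as the GPU
# index. `local_rank()` returns a tensor, so it must be evaluated before it can
# appear in a device string; MPI is assumed to have been initialized already.
def _example_gpu_device_for_process(session):
  return '/gpu:%d' % session.run(local_rank())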


def _allreduce(tensor, name=None):
  """An op which sums an input tensor over all the MPI processes.

  The reduction operation is keyed by the name of the op. The tensor type and
  shape must be the same on all MPI processes for a given name. The reduction
  will not start until all processes are ready to send and receive the tensor.

  Returns:
    A tensor of the same shape and type as `tensor`, summed across all
    processes.
  """
  return gen_mpi_ops.mpi_allreduce(tensor, name=name)


ops.NotDifferentiable('MPIAllreduce')
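

# Illustrative usage sketch (hypothetical helper, not part of the public API):
# turn the sum produced by `_allreduce` into an average over all processes,
# e.g. for averaging gradients computed on different shards of data.
def _example_average_across_processes(tensor):
  summed = _allreduce(tensor)
  return summed / tf.cast(size(), dtype=tensor.dtype)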


def allgather(tensor, name=None):
  """An op which concatenates the input tensor with the same input tensor on
  all other MPI processes.

  The concatenation is done on the first dimension, so the input tensors on the
  different processes must have the same rank and shape, except for the first
  dimension, which is allowed to be different.

  Returns:
    A tensor of the same type as `tensor`, concatenated on dimension zero
    across all processes. The shape is identical to the input shape, except for
    the first dimension, which may be greater and is the sum of all first
    dimensions of the tensors in different MPI processes.
  """
  # The first allgather collects the first-dimension sizes from every process;
  # passing a scalar (0-D tensor) with value 0 as the sizes argument marks it
  # as this sizing pass.
  sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const')
  my_size = tf.slice(
      tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice')
  if name is None:
    name = 'allgather'
  sizing_name = '{}_sizing'.format(name)
  sizes = gen_mpi_ops.mpi_allgather(my_size, sizes_flag, name=sizing_name)
  return gen_mpi_ops.mpi_allgather(tensor, sizes, name=name)


ops.NotDifferentiable('MPIAllgather')
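

# Illustrative usage sketch (hypothetical helper, not part of the public API):
# gather per-process evaluation outputs whose batch (first) dimensions may
# differ across workers; every process receives the full concatenation along
# dimension zero.
def _example_gather_eval_outputs(local_predictions, local_labels):
  return allgather(local_predictions), allgather(local_labels)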