# NOTE(review): the original line here was code-browser navigation residue
# ("Home | History | Annotate | Download | only in mpi_collectives") — not
# valid Python; kept only as a provenance hint: tensorflow/contrib/mpi_collectives.
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import numpy as np
     23 import tensorflow as tf
     24 import tensorflow.contrib.mpi_collectives as mpi
     25 from tensorflow.python.platform import test
     26 
     27 
# When True, each gathered value is expected to have been divided by the
# number of participating ranks; checkAllgather mirrors this flag when
# computing the expected output.
average_allgather = False
     29 
     30 
     31 class AllgatherTest(test.TestCase):
     32   def checkAllgather(self, num_ranks, all_gathered, local_gathered):
     33     # Ensure that indices match.
     34     all_gat_ind = np.sort(all_gathered.indices)
     35     loc_gat_ind = np.sort(local_gathered.indices)
     36     assert(len(loc_gat_ind) == len(all_gat_ind))
     37     for i in range(len(loc_gat_ind)):
     38       assert(loc_gat_ind[i] == all_gat_ind[i])
     39 
     40     # For each index, verify same values.
     41     local_checked = []
     42     for i in range(len(local_gathered.indices)):
     43       local_checked.append(False)
     44     for i in range(len(all_gathered.indices)):
     45       all_index = all_gathered.indices[i]
     46       # TODO(jthestness): Make this lookup quicker using sorting.
     47       loc_index = -1
     48       for j in range(len(local_gathered.indices)):
     49         if local_gathered.indices[j] == all_index and not local_checked[j]:
     50           loc_index = j
     51           local_checked[j] = True
     52           break
     53       assert(loc_index >= 0)
     54       correct_output = local_gathered.values[loc_index][0]
     55       if average_allgather:
     56         correct_output = correct_output / float(num_ranks)
     57       assert(all_gathered.values[i][0] == correct_output)
     58 
     59 
     60   def test_mpi_allgather(self):
     61     # Get MPI rank
     62     my_rank = int(os.environ['PMI_RANK'])
     63     num_ranks = int(os.environ['PMI_SIZE'])
     64 
     65     indices_per_rank = 100
     66     tensor_width = 10
     67 
     68     # Create IndexedSlices for each rank, some with overlapping indices.
     69     to_gather_indices = []
     70     to_gather_values = []
     71     to_gather = []
     72     for rank_id in range(num_ranks):
     73       indices = []
     74       values = []
     75       my_multiple = rank_id + 1
     76       current_index = my_multiple
     77       for i in range(indices_per_rank):
     78         indices.append(current_index)
     79         ones_tensor = tf.ones([tensor_width])
     80         values.append(tf.multiply(ones_tensor,
     81                                   tf.fill(ones_tensor.get_shape(),
     82                                           float(current_index))))
     83         current_index += my_multiple
     84       concat_ind = tf.stack(indices)
     85       concat_vals = tf.stack(values)
     86       to_gather_indices.append(concat_ind)
     87       to_gather_values.append(concat_vals)
     88       to_gather.append(tf.IndexedSlices(concat_vals, concat_ind))
     89 
     90     # Collect the local IndexedSlices (indices and values) to create
     91     # correct IndexedSlices output.
     92     correct_gather_indices = tf.concat(to_gather_indices, 0)
     93     correct_gather_values = tf.concat(to_gather_values, 0)
     94     correct_gather = tf.IndexedSlices(correct_gather_values,
     95                                       correct_gather_indices)
     96 
     97     all_gather = mpi.allreduce(to_gather[my_rank], average_allgather)
     98 
     99     # NOTE: This assumes that device IDs are numbered the same as ranks.
    100     gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
    101     config = tf.ConfigProto(gpu_options=gpu_options)
    102 
    103     # MPI Session to test allgather.
    104     with mpi.Session(config=config) as sess:
    105       sess.run(tf.global_variables_initializer())
    106 
    107       all_gathered, local_gathered = sess.run([all_gather, correct_gather])
    108 
    109       # Compare all_gathered with local_gathered.
    110       self.checkAllgather(num_ranks, all_gathered, local_gathered)
    111 
    112 
# Entry point: run under the TensorFlow test runner (intended to be
# launched via mpirun so PMI_RANK / PMI_SIZE are set).
if __name__ == '__main__':
  test.main()
    115