Home | History | Annotate | Download | only in ops
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #    http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for Collective Operations."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.core.protobuf import config_pb2
     22 from tensorflow.python.framework import constant_op
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.framework import test_util
     25 from tensorflow.python.ops import collective_ops
     26 from tensorflow.python.platform import test
     27 
     28 
     29 class CollectiveOpTest(test.TestCase):
     30 
     31   def _testCollectiveReduce(self, t0, t1, expected, set_graph_key):
     32     group_key = 1
     33     instance_key = 1
     34     with self.session(
     35         config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
     36       with ops.device('/CPU:0'):
     37         in0 = constant_op.constant(t0)
     38         colred0 = collective_ops.all_reduce(in0, 2, group_key, instance_key,
     39                                             'Add', 'Div')
     40       with ops.device('/CPU:1'):
     41         in1 = constant_op.constant(t1)
     42         colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key,
     43                                             'Add', 'Div')
     44       run_options = config_pb2.RunOptions()
     45       if set_graph_key:
     46         run_options.experimental.collective_graph_key = 1
     47       results = sess.run([colred0, colred1], options=run_options)
     48     self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
     49     self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
     50 
     51   def _testMultipleConcurrentCollectiveReduce(self, t0, t1, expected):
     52     group_key = 1
     53     group_size = 2
     54     num_instances = 2
     55     all_reduces = []
     56     config = config_pb2.ConfigProto(device_count={'CPU': group_size})
     57     config.experimental.collective_deterministic_sequential_execution = True
     58     with self.session(config=config) as sess:
     59       for cpu in range(group_size):
     60         with ops.device('/CPU:%d' % cpu):
     61           in_tensor = constant_op.constant(t0 if cpu == 0 else t1)
     62           for instance in range(num_instances):
     63             all_reduces.append(collective_ops.all_reduce(
     64                 in_tensor, group_size, group_key, instance, 'Add', 'Div'))
     65       results = sess.run(all_reduces)
     66     for i in range(group_size * num_instances):
     67       self.assertAllClose(results[i], expected, rtol=1e-5, atol=1e-5)
     68 
     69   @test_util.run_deprecated_v1
     70   def testCollectiveReduce(self):
     71     self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
     72                                [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
     73                                [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], True)
     74 
     75   @test_util.run_deprecated_v1
     76   def testCollectiveAutoGraphKey(self):
     77     self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
     78                                [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
     79                                [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], False)
     80 
     81   @test_util.run_deprecated_v1
     82   def testCollectiveMultipleConcurrentReduce(self):
     83     self._testMultipleConcurrentCollectiveReduce(
     84         [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
     85         [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
     86         [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
     87 
     88   @test_util.run_deprecated_v1
     89   def testCollectiveReduceScalar(self):
     90     self._testCollectiveReduce(0.1, 0.3, 0.2, True)
     91 
     92   def _testCollectiveBroadcast(self, t0):
     93     group_key = 1
     94     instance_key = 1
     95     with self.session(
     96         config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
     97       with ops.device('/CPU:0'):
     98         in0 = constant_op.constant(t0)
     99         out0 = collective_ops.broadcast_send(in0, in0.shape, in0.dtype,
    100                                              2, group_key, instance_key)
    101       with ops.device('/CPU:1'):
    102         c1 = constant_op.constant(t0)
    103         out1 = collective_ops.broadcast_recv(c1.shape, c1.dtype,
    104                                              2, group_key, instance_key)
    105       run_options = config_pb2.RunOptions()
    106       run_options.experimental.collective_graph_key = 1
    107       results = sess.run([out0, out1], options=run_options)
    108     self.assertAllClose(results[0], t0, rtol=1e-5, atol=1e-5)
    109     self.assertAllClose(results[1], t0, rtol=1e-5, atol=1e-5)
    110 
    111   @test_util.run_deprecated_v1
    112   def testCollectiveBroadcast(self):
    113     self._testCollectiveBroadcast([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1])
    114 
    115   def _testCollectiveGather(self, t0, t1, expected, set_graph_key):
    116     group_key = 1
    117     instance_key = 1
    118     with self.session(
    119         config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
    120       with ops.device('/CPU:0'):
    121         in0 = constant_op.constant(t0)
    122         colred0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
    123       with ops.device('/CPU:1'):
    124         in1 = constant_op.constant(t1)
    125         colred1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
    126       run_options = config_pb2.RunOptions()
    127       if set_graph_key:
    128         run_options.experimental.collective_graph_key = 1
    129       results = sess.run([colred0, colred1], options=run_options)
    130     self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
    131     self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
    132 
    133   @test_util.run_deprecated_v1
    134   def testCollectiveGather(self):
    135     self._testCollectiveGather([0, 1, 2, 3, 4, 5, 6, 7],
    136                                [10, 11, 12, 13, 14, 15, 16, 17],
    137                                [0, 1, 2, 3, 4, 5, 6, 7,
    138                                 10, 11, 12, 13, 14, 15, 16, 17],
    139                                True)
    140     self._testCollectiveGather([[0, 1, 2, 3], [4, 5, 6, 7]],
    141                                [[10, 11, 12, 13], [14, 15, 16, 17]],
    142                                [[0, 1, 2, 3], [4, 5, 6, 7],
    143                                 [10, 11, 12, 13], [14, 15, 16, 17]],
    144                                True)
    145     self._testCollectiveGather([[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
    146                                [[[10, 11], [12, 13]], [[14, 15], [16, 17]]],
    147                                [[[0, 1], [2, 3]], [[4, 5], [6, 7]],
    148                                 [[10, 11], [12, 13]], [[14, 15], [16, 17]]],
    149                                True)
    150 
    151 
# When this file is executed directly as a script, hand control to the `main`
# entry point of the imported tensorflow.python.platform `test` module.
if __name__ == '__main__':
  test.main()
    154