# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Collective Operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import test


class CollectiveOpTest(test.TestCase):
  """Exercises all_reduce, broadcast and all_gather across two CPU devices."""

  def _testCollectiveReduce(self, t0, t1, expected, set_graph_key):
    """Runs a two-member all_reduce and checks both outputs.

    Args:
      t0: value fed to the group member placed on '/CPU:0'.
      t1: value fed to the group member placed on '/CPU:1'.
      expected: value every member's output is compared against.
      set_graph_key: when true, `collective_graph_key` is set to 1 on the
        run options; otherwise the run options are left at their defaults.
    """
    group_key = 1
    instance_key = 1
    session_config = config_pb2.ConfigProto(device_count={'CPU': 2})
    with self.session(config=session_config) as sess:
      reductions = []
      # One group member per device, built in device order.
      for device, value in (('/CPU:0', t0), ('/CPU:1', t1)):
        with ops.device(device):
          tensor = constant_op.constant(value)
          reductions.append(collective_ops.all_reduce(
              tensor, 2, group_key, instance_key, 'Add', 'Div'))
      run_options = config_pb2.RunOptions()
      if set_graph_key:
        run_options.experimental.collective_graph_key = 1
      results = sess.run(reductions, options=run_options)
    for result in results:
      self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)

  def _testMultipleConcurrentCollectiveReduce(self, t0, t1, expected):
    """Runs several all_reduce instances in a single `sess.run` call.

    Each of the two CPU devices contributes one tensor (t0 on the first
    device, t1 on the other) to every instance; all outputs are compared
    against `expected`.

    Args:
      t0: value fed by the device with index 0.
      t1: value fed by every other device.
      expected: value every all_reduce output is compared against.
    """
    group_key = 1
    group_size = 2
    num_instances = 2
    config = config_pb2.ConfigProto(device_count={'CPU': group_size})
    config.experimental.collective_deterministic_sequential_execution = True
    all_reduces = []
    with self.session(config=config) as sess:
      for cpu in range(group_size):
        with ops.device('/CPU:%d' % cpu):
          source = t1 if cpu else t0
          in_tensor = constant_op.constant(source)
          # The instance index doubles as the instance key.
          all_reduces.extend(
              collective_ops.all_reduce(
                  in_tensor, group_size, group_key, instance, 'Add', 'Div')
              for instance in range(num_instances))
      results = sess.run(all_reduces)
    # One result per (device, instance) pair.
    for result in results:
      self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)

  # All-reduce of two vectors with the collective graph key set explicitly.
  @test_util.run_deprecated_v1
  def testCollectiveReduce(self):
    first = [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]
    second = [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]
    want = [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2]
    self._testCollectiveReduce(first, second, want, True)

  # Same reduction, but without setting a collective graph key on the run.
  @test_util.run_deprecated_v1
  def testCollectiveAutoGraphKey(self):
    first = [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]
    second = [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]
    want = [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2]
    self._testCollectiveReduce(first, second, want, False)

  # Two reduction instances per device, fetched together.
  @test_util.run_deprecated_v1
  def testCollectiveMultipleConcurrentReduce(self):
    first = [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]
    second = [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]
    want = [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2]
    self._testMultipleConcurrentCollectiveReduce(first, second, want)

  # Scalar inputs instead of vectors.
  @test_util.run_deprecated_v1
  def testCollectiveReduceScalar(self):
    self._testCollectiveReduce(0.1, 0.3, 0.2, True)

  def _testCollectiveBroadcast(self, t0):
    """Broadcasts `t0` from '/CPU:0' to '/CPU:1' and checks both outputs.

    Args:
      t0: value that is sent; both the sender's and the receiver's output are
        compared against it.
    """
    group_key = 1
    instance_key = 1
    session_config = config_pb2.ConfigProto(device_count={'CPU': 2})
    with self.session(config=session_config) as sess:
      with ops.device('/CPU:0'):
        payload = constant_op.constant(t0)
        sent = collective_ops.broadcast_send(
            payload, payload.shape, payload.dtype, 2, group_key, instance_key)
      with ops.device('/CPU:1'):
        # This constant only supplies the shape and dtype for the receiver;
        # its value is never passed to the recv op.
        template = constant_op.constant(t0)
        received = collective_ops.broadcast_recv(
            template.shape, template.dtype, 2, group_key, instance_key)
      run_options = config_pb2.RunOptions()
      run_options.experimental.collective_graph_key = 1
      results = sess.run([sent, received], options=run_options)
    for result in results:
      self.assertAllClose(result, t0, rtol=1e-5, atol=1e-5)

  # Broadcast of a single vector.
  @test_util.run_deprecated_v1
  def testCollectiveBroadcast(self):
    self._testCollectiveBroadcast([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1])

  def _testCollectiveGather(self, t0, t1, expected, set_graph_key):
    """Runs a two-member all_gather and checks both outputs.

    Args:
      t0: value fed to the group member placed on '/CPU:0'.
      t1: value fed to the group member placed on '/CPU:1'.
      expected: value every member's output is compared against.
      set_graph_key: when true, `collective_graph_key` is set to 1 on the
        run options; otherwise the run options are left at their defaults.
    """
    group_key = 1
    instance_key = 1
    session_config = config_pb2.ConfigProto(device_count={'CPU': 2})
    with self.session(config=session_config) as sess:
      gathers = []
      # One group member per device, built in device order.
      for device, value in (('/CPU:0', t0), ('/CPU:1', t1)):
        with ops.device(device):
          tensor = constant_op.constant(value)
          gathers.append(
              collective_ops.all_gather(tensor, 2, group_key, instance_key))
      run_options = config_pb2.RunOptions()
      if set_graph_key:
        run_options.experimental.collective_graph_key = 1
      results = sess.run(gathers, options=run_options)
    for result in results:
      self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)

  # Gathers rank-1, rank-2 and rank-3 inputs; each expected value is the
  # first input followed by the second along the leading dimension.
  @test_util.run_deprecated_v1
  def testCollectiveGather(self):
    cases = (
        ([0, 1, 2, 3, 4, 5, 6, 7],
         [10, 11, 12, 13, 14, 15, 16, 17],
         [0, 1, 2, 3, 4, 5, 6, 7,
          10, 11, 12, 13, 14, 15, 16, 17]),
        ([[0, 1, 2, 3], [4, 5, 6, 7]],
         [[10, 11, 12, 13], [14, 15, 16, 17]],
         [[0, 1, 2, 3], [4, 5, 6, 7],
          [10, 11, 12, 13], [14, 15, 16, 17]]),
        ([[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
         [[[10, 11], [12, 13]], [[14, 15], [16, 17]]],
         [[[0, 1], [2, 3]], [[4, 5], [6, 7]],
          [[10, 11], [12, 13]], [[14, 15], [16, 17]]]),
    )
    for first, second, want in cases:
      self._testCollectiveGather(first, second, want, True)


if __name__ == '__main__':
  test.main()