# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.training.device_setter."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from tensorflow.contrib.training.python.training import device_setter as device_setter_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
from tensorflow.python.training import server_lib

# Cluster used by every test: two "ps" tasks and three "worker" tasks.
_CLUSTER_SPEC = server_lib.ClusterSpec({
    "ps": ["ps0:2222", "ps1:2222"],
    "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})

# Lightweight stand-in for an Operation exposing only a `name` field.
MockOperation = collections.namedtuple("MockOperation", "name")

# Shapes of the four variables (u, v, w, x) built by the placement tests, in
# creation order. The expected placements in the tests below assume this
# exact order.
_PLACEMENT_TEST_SHAPES = ([2, 2], [2, 1], [2, 2], [1, 3])


def _create_placed_variables(ps_strategy):
  """Builds the test variables and an add op under a replica device setter.

  Args:
    ps_strategy: Callable forwarded as `ps_strategy` to
      `device_setter.replica_device_setter`.

  Returns:
    A pair `(created, total)`. `created` is the list of zero-initialized
    `variables.Variable`s built from `_PLACEMENT_TEST_SHAPES`, in order.
    `total` is `created[1] + created[2]`, also built inside the device scope.
  """
  setter = device_setter.replica_device_setter(
      cluster=_CLUSTER_SPEC, ps_strategy=ps_strategy)
  with ops.device(setter):
    created = [
        variables.Variable(array_ops.zeros(shape))
        for shape in _PLACEMENT_TEST_SHAPES
    ]
    total = created[1] + created[2]
  return created, total


def _assert_placements(test_case, created, total, expected_ps_devices):
  """Checks the devices of each variable, its initializer, and `total`.

  Args:
    test_case: The `test.TestCase` that performs the assertions.
    created: Variables returned by `_create_placed_variables`.
    total: Add op returned by `_create_placed_variables`; it is expected on
      the worker job.
    expected_ps_devices: One expected device string per entry of `created`.
  """
  for var, expected in zip(created, expected_ps_devices):
    test_case.assertDeviceEqual(expected, var.device)
    test_case.assertDeviceEqual(expected, var.initializer.device)
  test_case.assertDeviceEqual("/job:worker", total.device)


class RandomStrategyTest(test.TestCase):

  def testBasic(self):
    created, total = _create_placed_variables(
        device_setter_lib.RandomStrategy(2, seed=0))
    # Randomly distributed with seed 0.
    _assert_placements(self, created, total,
                       ["/job:ps/task:1", "/job:ps/task:0",
                        "/job:ps/task:1", "/job:ps/task:1"])

  def testHandlesUnicode(self):
    strategy = device_setter_lib.RandomStrategy(2, seed=0)
    unicode_op = MockOperation(u"A unicode \u018e string \xf1")
    self.assertEqual(strategy(unicode_op), 1)


class GreedyLoadBalancingStrategyTest(test.TestCase):

  def testUniformLoadEqualsRoundRobin(self):
    # Every op reports the same load, so placements alternate between tasks.
    strategy = device_setter_lib.GreedyLoadBalancingStrategy(
        2, lambda unused_op: 1)
    created, total = _create_placed_variables(strategy)
    _assert_placements(self, created, total,
                       ["/job:ps/task:0", "/job:ps/task:1",
                        "/job:ps/task:0", "/job:ps/task:1"])

  def testByteSizeLoadFn(self):
    strategy = device_setter_lib.GreedyLoadBalancingStrategy(
        2, device_setter_lib.byte_size_load_fn)
    created, total = _create_placed_variables(strategy)
    _assert_placements(self, created, total,
                       ["/job:ps/task:0", "/job:ps/task:1",
                        "/job:ps/task:1", "/job:ps/task:0"])

  def testByteSizeLoadFnWithScalar(self):
    strategy = device_setter_lib.GreedyLoadBalancingStrategy(
        2, device_setter_lib.byte_size_load_fn)
    with ops.device(
        device_setter.replica_device_setter(
            cluster=_CLUSTER_SPEC, ps_strategy=strategy)):
      # Note: we must test the load function as part of the device function
      # instead of passing u.op to the function directly, because the only
      # time that the output Tensor has unknown shape for scalars is during
      # Variable construction.
      scalar_var = variables.Variable(0)
    self.assertDeviceEqual("/job:ps/task:0", scalar_var.device)
    self.assertDeviceEqual("/job:ps/task:0", scalar_var.initializer.device)


if __name__ == "__main__":
  test.main()