Home | History | Annotate | Download | only in training
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tf.contrib.training.device_setter."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import collections
     21 from tensorflow.contrib.training.python.training import device_setter as device_setter_lib
     22 from tensorflow.python.framework import ops
     23 from tensorflow.python.ops import array_ops
     24 from tensorflow.python.ops import variables
     25 from tensorflow.python.platform import test
     26 from tensorflow.python.training import device_setter
     27 from tensorflow.python.training import server_lib
     28 
     29 _CLUSTER_SPEC = server_lib.ClusterSpec({
     30     "ps": ["ps0:2222", "ps1:2222"],
     31     "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
     32 })
     33 
     34 MockOperation = collections.namedtuple("MockOperation", "name")
     35 
     36 
     37 class RandomStrategyTest(test.TestCase):
     38 
     39   def testBasic(self):
     40     ps_strategy = device_setter_lib.RandomStrategy(2, seed=0)
     41     with ops.device(
     42         device_setter.replica_device_setter(
     43             cluster=_CLUSTER_SPEC,
     44             ps_strategy=ps_strategy)):
     45       u = variables.Variable(array_ops.zeros([2, 2]))
     46       v = variables.Variable(array_ops.zeros([2, 1]))
     47       w = variables.Variable(array_ops.zeros([2, 2]))
     48       x = variables.Variable(array_ops.zeros([1, 3]))
     49       a = v + w
     50       # Randomly distributed with seed 0.
     51       self.assertDeviceEqual("/job:ps/task:1", u.device)
     52       self.assertDeviceEqual("/job:ps/task:1", u.initializer.device)
     53       self.assertDeviceEqual("/job:ps/task:0", v.device)
     54       self.assertDeviceEqual("/job:ps/task:0", v.initializer.device)
     55       self.assertDeviceEqual("/job:ps/task:1", w.device)
     56       self.assertDeviceEqual("/job:ps/task:1", w.initializer.device)
     57       self.assertDeviceEqual("/job:ps/task:1", x.device)
     58       self.assertDeviceEqual("/job:ps/task:1", x.initializer.device)
     59       self.assertDeviceEqual("/job:worker", a.device)
     60 
     61   def testHandlesUnicode(self):
     62     op = MockOperation(u"A unicode \u018e string \xf1")
     63     ps_strategy = device_setter_lib.RandomStrategy(2, seed=0)
     64     ps_task = ps_strategy(op)
     65     self.assertEqual(ps_task, 1)
     66 
     67 
     68 class GreedyLoadBalancingStrategyTest(test.TestCase):
     69 
     70   def testUniformLoadEqualsRoundRobin(self):
     71 
     72     def _load_fn(unused_op):
     73       return 1
     74 
     75     with ops.device(
     76         device_setter.replica_device_setter(
     77             cluster=_CLUSTER_SPEC,
     78             ps_strategy=device_setter_lib.GreedyLoadBalancingStrategy(
     79                 2, _load_fn))):
     80       u = variables.Variable(array_ops.zeros([2, 2]))
     81       v = variables.Variable(array_ops.zeros([2, 1]))
     82       w = variables.Variable(array_ops.zeros([2, 2]))
     83       x = variables.Variable(array_ops.zeros([1, 3]))
     84       a = v + w
     85       self.assertDeviceEqual("/job:ps/task:0", u.device)
     86       self.assertDeviceEqual("/job:ps/task:0", u.initializer.device)
     87       self.assertDeviceEqual("/job:ps/task:1", v.device)
     88       self.assertDeviceEqual("/job:ps/task:1", v.initializer.device)
     89       self.assertDeviceEqual("/job:ps/task:0", w.device)
     90       self.assertDeviceEqual("/job:ps/task:0", w.initializer.device)
     91       self.assertDeviceEqual("/job:ps/task:1", x.device)
     92       self.assertDeviceEqual("/job:ps/task:1", x.initializer.device)
     93       self.assertDeviceEqual("/job:worker", a.device)
     94 
     95   def testByteSizeLoadFn(self):
     96     with ops.device(
     97         device_setter.replica_device_setter(
     98             cluster=_CLUSTER_SPEC,
     99             ps_strategy=device_setter_lib.GreedyLoadBalancingStrategy(
    100                 2, device_setter_lib.byte_size_load_fn))):
    101       u = variables.Variable(array_ops.zeros([2, 2]))
    102       v = variables.Variable(array_ops.zeros([2, 1]))
    103       w = variables.Variable(array_ops.zeros([2, 2]))
    104       x = variables.Variable(array_ops.zeros([1, 3]))
    105       a = v + w
    106       self.assertDeviceEqual("/job:ps/task:0", u.device)
    107       self.assertDeviceEqual("/job:ps/task:0", u.initializer.device)
    108       self.assertDeviceEqual("/job:ps/task:1", v.device)
    109       self.assertDeviceEqual("/job:ps/task:1", v.initializer.device)
    110       self.assertDeviceEqual("/job:ps/task:1", w.device)
    111       self.assertDeviceEqual("/job:ps/task:1", w.initializer.device)
    112       self.assertDeviceEqual("/job:ps/task:0", x.device)
    113       self.assertDeviceEqual("/job:ps/task:0", x.initializer.device)
    114       self.assertDeviceEqual("/job:worker", a.device)
    115 
    116   def testByteSizeLoadFnWithScalar(self):
    117     with ops.device(
    118         device_setter.replica_device_setter(
    119             cluster=_CLUSTER_SPEC,
    120             ps_strategy=device_setter_lib.GreedyLoadBalancingStrategy(
    121                 2, device_setter_lib.byte_size_load_fn))):
    122       # Note: we must test the load function as part of the device function
    123       # instead of passing u.op to the function directly, because the only
    124       # time that the output Tensor has unknown shape for scalars is during
    125       # Variable construction.
    126       u = variables.Variable(0)
    127       self.assertDeviceEqual("/job:ps/task:0", u.device)
    128       self.assertDeviceEqual("/job:ps/task:0", u.initializer.device)
    129 
    130 
# Run the test cases defined above when this file is executed as a script.
if __name__ == "__main__":
  test.main()
    133