# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Tests for ModelAverageOptimizer."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import portpicker
     21 
     22 from tensorflow.contrib.opt.python.training import model_average_optimizer
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.ops import variable_scope
     26 from tensorflow.python.ops import variables
     27 from tensorflow.python.platform import test
     28 from tensorflow.python.training import device_setter
     29 from tensorflow.python.training import gradient_descent
     30 from tensorflow.python.training import server_lib
     31 from tensorflow.python.training import training
     32 from tensorflow.python.training import training_util
     33 
     34 
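# The tests below exercise the following usage pattern. This is a rough
# sketch of how the pieces fit together (base_optimizer, grads_and_vars and
# the other lowercase names are placeholders, not library symbols):
#
#   ma_custom = model_average_optimizer.ModelAverageCustomGetter(
#       worker_device=worker_device)
#   with variable_scope.variable_scope("", custom_getter=ma_custom):
#     ...  # Variables created here also get global copies on the ps tasks.
#   opt = model_average_optimizer.ModelAverageOptimizer(
#       opt=base_optimizer,
#       num_worker=num_workers,
#       ma_custom_getter=ma_custom,
#       is_chief=is_chief,
#       interval_steps=interval_steps)
#   train_op = opt.apply_gradients(grads_and_vars, global_step)
#   ma_hook = opt.make_session_run_hook()  # Pass to MonitoredTrainingSession.

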
def create_local_cluster(num_workers, num_ps, protocol="grpc"):
  """Create local gRPC servers and return the cluster dict and servers."""
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
  cluster_dict = {
      "worker": ["localhost:%s" % port for port in worker_ports],
      "ps": ["localhost:%s" % port for port in ps_ports]
  }
  cs = server_lib.ClusterSpec(cluster_dict)

  workers = [
      server_lib.Server(
          cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
      for ix in range(num_workers)
  ]
  ps_servers = [
      server_lib.Server(
          cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
      for ix in range(num_ps)
  ]

  return cluster_dict, workers, ps_servers


# Creates the workers and returns their sessions, graphs, and train ops.
# The chief worker (task 0) applies the averaged update to the global
# variables.
def _get_workers(num_workers, steps, workers):
  sessions = []
  graphs = []
  train_ops = []
  for worker_id in range(num_workers):
    graph = ops.Graph()
    is_chief = (worker_id == 0)
    with graph.as_default():
      worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
      ma_custom = model_average_optimizer.ModelAverageCustomGetter(
          worker_device=worker_device)
      with variable_scope.variable_scope(
          "", custom_getter=ma_custom), ops.device(
              device_setter.replica_device_setter(
                  worker_device=worker_device,
                  ps_device="/job:ps/task:0/cpu:0",
                  ps_tasks=1)):

        global_step = variables.Variable(0, name="global_step", trainable=False)
        var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
        var_1 = variable_scope.get_variable(initializer=1.0, name="v1")

      with ops.device("/job:worker/task:" + str(worker_id)):
        # Use different constant gradients per worker so the test can tell
        # local updates and averaged global updates apart.
        if worker_id == 0:
          grads_0 = constant_op.constant(-1.0)
          grads_1 = constant_op.constant(-1.0)
        else:
          grads_0 = constant_op.constant(-2.0)
          grads_1 = constant_op.constant(-2.0)
        sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
        opt = model_average_optimizer.ModelAverageOptimizer(
            opt=sgd_opt,
            num_worker=num_workers,
            ma_custom_getter=ma_custom,
            is_chief=is_chief,
            interval_steps=steps)
        train_op = [
            opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
                                global_step)
        ]
      ma_hook = opt.make_session_run_hook()
      # Create a MonitoredSession wired to this worker's server.
      sess = training.MonitoredTrainingSession(
          workers[worker_id].target, hooks=[ma_hook])

    sessions.append(sess)
    graphs.append(graph)
    train_ops.append(train_op)
  return sessions, graphs, train_ops


class ModelAverageOptimizerTest(test.TestCase):

  def _run(self, train_op, sess):
    sess.run(train_op)

  def test2Workers2Period(self):
    num_workers = 2
    steps = 2
    num_ps = 1
    _, workers, _ = create_local_cluster(
        num_workers=num_workers, num_ps=num_ps)

    sessions, graphs, train_ops = _get_workers(num_workers, steps, workers)

    var_0 = graphs[0].get_tensor_by_name("v0:0")
    var_1 = graphs[0].get_tensor_by_name("v1:0")
    global_step = training_util.get_global_step(graphs[0])
    global_var_0 = graphs[0].get_tensor_by_name(
        model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v0:0")
    global_var_1 = graphs[0].get_tensor_by_name(
        model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v1:0")

    # Verify the initial values.
    self.assertAllEqual(0.0, sessions[0].run(var_0))
    self.assertAllEqual(1.0, sessions[0].run(var_1))
    self.assertAllEqual(0.0, sessions[0].run(global_var_0))
    self.assertAllEqual(1.0, sessions[0].run(global_var_1))
    self.assertAllEqual(0, sessions[0].run(global_step))

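    # Iteration 1: each worker takes one local SGD step. With learning rate
    # 1.0, worker 0's gradients of -1.0 raise its local variables by 1.0;
    # worker 1's gradients of -2.0 raise its own copies by 2.0. The global
    # variables are untouched until interval_steps local steps have run.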
    sessions[0].run(train_ops[0])
    sessions[1].run(train_ops[1])

    self.assertAllEqual(1.0, sessions[0].run(var_0))
    self.assertAllEqual(2.0, sessions[0].run(var_1))
    self.assertAllEqual(0.0, sessions[0].run(global_var_0))
    self.assertAllEqual(1.0, sessions[0].run(global_var_1))
    self.assertAllEqual(0, sessions[0].run(global_step))

    # Iteration 2: interval_steps is reached, so the global variables are
    # updated.
    thread_0 = self.checkedThread(
        target=self._run, args=(train_ops[0], sessions[0]))
    thread_1 = self.checkedThread(
        target=self._run, args=(train_ops[1], sessions[1]))
    thread_0.start()
    thread_1.start()
    thread_0.join()
    thread_1.join()

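    # After the second local step, worker 0's locals are (2.0, 3.0) and
    # worker 1's are (4.0, 5.0). The chief averages them into the global
    # variables: v0 = (2.0 + 4.0) / 2 = 3.0, v1 = (3.0 + 5.0) / 2 = 4.0,
    # and the workers' local variables are reset to these averages.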
    self.assertAllEqual(3.0, sessions[0].run(var_0))
    self.assertAllEqual(4.0, sessions[0].run(var_1))
    self.assertAllEqual(3.0, sessions[0].run(global_var_0))
    self.assertAllEqual(4.0, sessions[0].run(global_var_1))
    self.assertAllEqual(1, sessions[0].run(global_step))

    # Iteration 3: only worker 0 takes another local step. No averaging
    # happens until interval_steps more local steps complete, so the global
    # variables keep their previous values.
    sessions[0].run(train_ops[0])

    self.assertAllEqual(4.0, sessions[0].run(var_0))
    self.assertAllEqual(5.0, sessions[0].run(var_1))
    self.assertAllEqual(3.0, sessions[0].run(global_var_0))
    self.assertAllEqual(4.0, sessions[0].run(global_var_1))
    self.assertAllEqual(1, sessions[0].run(global_step))

  def testPS2TasksWithClusterSpecClass(self):
    # Local variables should live on the worker; the custom getter places
    # their global copies round-robin across the ps tasks.
    cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })
    worker_device = "/job:worker/task:0"
    ma_custom = model_average_optimizer.ModelAverageCustomGetter(
        worker_device=worker_device)
    with ops.device(
        device_setter.replica_device_setter(cluster=cluster_spec,
                                            worker_device=worker_device,
                                            ps_device="/job:ps")), \
         variable_scope.variable_scope("", custom_getter=ma_custom):
      v = variable_scope.get_variable(initializer=[1, 2], name="v")
      w = variable_scope.get_variable(initializer=[2, 1], name="w")
      v_g, w_g = ma_custom._local_2_global[v], ma_custom._local_2_global[w]
      self.assertDeviceEqual("/job:worker/task:0", v.device)
      self.assertDeviceEqual("/job:ps/task:0", v_g.device)
      self.assertDeviceEqual("/job:worker/task:0", w.device)
      self.assertDeviceEqual("/job:ps/task:1", w_g.device)


if __name__ == "__main__":
  test.main()