# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for sync_replicas_optimizer.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework.test_util import create_local_cluster
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import training


# Creates the workers and returns their sessions, graphs and train_ops.
def get_workers(num_workers, replicas_to_aggregate, workers):
  sessions = []
  graphs = []
  train_ops = []
  for worker_id in range(num_workers):
    graph = ops.Graph()
    is_chief = (worker_id == 0)
    with graph.as_default():
      with ops.device("/job:ps/task:0"):
        global_step = variables.Variable(0, name="global_step", trainable=False)
        var_0 = variables.Variable(0.0, name="v0")
      with ops.device("/job:ps/task:1"):
        var_1 = variables.Variable(1.0, name="v1")
        var_sparse = variables.Variable([[3.0], [4.0]], name="v_sparse")

      with ops.device("/job:worker/task:" + str(worker_id)):
        grads_0 = constant_op.constant(0.1 + worker_id * 0.2)
        grads_1 = constant_op.constant(0.9 + worker_id * 0.2)
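        # Per worker, the dense gradients are: worker 0 -> (0.1, 0.9),
        # worker 1 -> (0.3, 1.1), worker 2 -> (0.5, 1.3). The expected values
        # asserted in the tests below are derived from these.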
        # This is to test against sparse gradients.
        grads_sparse = ops.IndexedSlices(
            constant_op.constant(
                [0.1 + worker_id * 0.2], shape=[1, 1]),
            constant_op.constant([1]),
            constant_op.constant([2, 1]))
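        # The IndexedSlices gradient has values of shape [1, 1], index [1] and
        # dense shape [2, 1], so only row 1 of var_sparse gets updated.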
        sgd_opt = gradient_descent.GradientDescentOptimizer(2.0)
        sync_rep_opt = training.SyncReplicasOptimizer(
            sgd_opt,
            replicas_to_aggregate=replicas_to_aggregate,
            total_num_replicas=num_workers)
        train_op = [
            sync_rep_opt.apply_gradients(
                zip([grads_0, grads_1, grads_sparse],
                    [var_0, var_1, var_sparse]),
                global_step=global_step)
        ]
        sync_replicas_hook = sync_rep_opt.make_session_run_hook(
            is_chief, num_tokens=num_workers)
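        # For the chief, the hook also sets up the token queue initialization;
        # with num_tokens=num_workers the queue is seeded with one token per
        # worker, which lets the first train_op runs in the tests proceed one
        # by one.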

      # Creates a MonitoredSession.
      session = training.MonitoredTrainingSession(
          master=workers[worker_id].target,
          is_chief=is_chief,
          hooks=[sync_replicas_hook])
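      # Note: with is_chief=True the session initializes all variables and
      # runs the hook's init ops; the other sessions wait for the chief.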

    sessions.append(session)
    graphs.append(graph)
    train_ops.append(train_op)

  return sessions, graphs, train_ops


class SyncReplicasOptimizerTest(test.TestCase):

  def _run(self, train_op, sess):
    sess.run(train_op)

  def test2Workers(self):
    num_workers = 2
    replicas_to_aggregate = 2
    num_ps = 2
    workers, _ = create_local_cluster(num_workers=num_workers, num_ps=num_ps)

    # Creates and returns all the workers.
    sessions, graphs, train_ops = get_workers(num_workers,
                                              replicas_to_aggregate, workers)

    # Chief should have already initialized all the variables.
    var_0_g_0 = graphs[0].get_tensor_by_name("v0:0")
    var_1_g_0 = graphs[0].get_tensor_by_name("v1:0")
    local_step_0 = graphs[0].get_tensor_by_name("sync_rep_local_step:0")
    self.assertAllEqual(0.0, sessions[0].run(var_0_g_0))
    self.assertAllEqual(1.0, sessions[0].run(var_1_g_0))
    self.assertAllEqual(0, sessions[0].run(local_step_0))

    # Will just use session 1 to verify all the variables later.
    var_0_g_1 = graphs[1].get_tensor_by_name("v0:0")
    var_1_g_1 = graphs[1].get_tensor_by_name("v1:0")
    var_sparse_g_1 = graphs[1].get_tensor_by_name("v_sparse:0")
    local_step_1 = graphs[1].get_tensor_by_name("sync_rep_local_step:0")
    global_step = graphs[1].get_tensor_by_name("global_step:0")

    # The steps should also be initialized.
    self.assertAllEqual(0, sessions[1].run(global_step))
    self.assertAllEqual(0, sessions[1].run(local_step_1))
    self.assertAllClose([[3.0], [4.0]], sessions[1].run(var_sparse_g_1))

    # We have initial tokens in the queue so we can run these train ops one by
    # one. After the first step, this will no longer work as there will be no
    # more extra tokens in the queue.
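    # (num_tokens was set to num_workers in get_workers, so the token queue
    # starts with exactly one token per worker.)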
    sessions[0].run(train_ops[0])
    sessions[1].run(train_ops[1])

    # The global step should have been updated and the variables should now
    # have the new values after the average of the gradients is applied.
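    # With learning rate 2.0 and averaged gradients (0.1 + 0.3) / 2 = 0.2 and
    # (0.9 + 1.1) / 2 = 1.0, the expected values are v0 = -0.4 and v1 = -1.0.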
    while sessions[1].run(global_step) != 1:
      time.sleep(0.01)

    self.assertAllClose(0 - (0.1 + 0.3) / 2 * 2.0, sessions[1].run(var_0_g_1))
    self.assertAllClose(1 - (0.9 + 1.1) / 2 * 2.0, sessions[1].run(var_1_g_1))
    self.assertAllClose([[3.0], [4.0 - (0.1 + 0.3) / 2 * 2.0]],
                        sessions[1].run(var_sparse_g_1))

    # The local step for both workers should still be 0 because the initial
    # tokens in the token queue are 0s. This means that the following
    # computation of the gradients will be wasted as local_step is smaller than
    # the current global step. However, this only happens once when the system
    # just starts, and it is necessary to make the system robust to the case
    # where the chief gets restarted due to errors/preemption/...
    self.assertAllEqual(0, sessions[0].run(local_step_0))
    self.assertAllEqual(0, sessions[1].run(local_step_1))

    sessions[0].run(train_ops[0])
    sessions[1].run(train_ops[1])
    # Although the global step should still be 1 as explained above, the local
    # step should now be updated to 1. The variables are still the same.
    self.assertAllEqual(1, sessions[1].run(global_step))
    self.assertAllEqual(1, sessions[0].run(local_step_0))
    self.assertAllEqual(1, sessions[1].run(local_step_1))
    self.assertAllClose(0 - (0.1 + 0.3) / 2 * 2.0, sessions[1].run(var_0_g_1))
    self.assertAllClose(1 - (0.9 + 1.1) / 2 * 2.0, sessions[1].run(var_1_g_1))

    # At this step, the token queue is empty. So the 2 workers need to work
    # together to proceed.
    threads = []
    threads.append(
        self.checkedThread(
            target=self._run, args=(train_ops[0], sessions[0])))
    threads.append(
        self.checkedThread(
            target=self._run, args=(train_ops[1], sessions[1])))

    # The two workers start to execute the train op.
    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()

    # The global step should now be 2 and the gradients should have been
    # applied twice.
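    # Expected values: v0 = 0 - 2 * 2.0 * 0.2 = -0.8 and
    # v1 = 1 - 2 * 2.0 * 1.0 = -3.0.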
    self.assertAllEqual(2, sessions[1].run(global_step))
    self.assertAllClose(0 - 2 * (0.1 + 0.3) / 2 * 2.0,
                        sessions[1].run(var_0_g_1))
    self.assertAllClose(1 - 2 * (0.9 + 1.1) / 2 * 2.0,
                        sessions[1].run(var_1_g_1))

  # 3 workers and one of them is a backup.
  def test3Workers1Backup(self):
    num_workers = 3
    replicas_to_aggregate = 2
    num_ps = 2
    workers, _ = create_local_cluster(num_workers=num_workers, num_ps=num_ps)

    # Creates and returns all the workers.
    sessions, graphs, train_ops = get_workers(num_workers,
                                              replicas_to_aggregate, workers)

    # Chief should have already initialized all the variables.
    var_0_g_1 = graphs[1].get_tensor_by_name("v0:0")
    var_1_g_1 = graphs[1].get_tensor_by_name("v1:0")
    local_step_1 = graphs[1].get_tensor_by_name("sync_rep_local_step:0")
    global_step = graphs[1].get_tensor_by_name("global_step:0")

    # The steps should also be initialized.
    self.assertAllEqual(0, sessions[1].run(global_step))
    self.assertAllEqual(0, sessions[1].run(local_step_1))

    # We have initial tokens in the queue so we can run these one by one. After
    # the token queue becomes empty, the train ops should be run concurrently.
    # Here worker 0 and worker 2 finish first.
    sessions[0].run(train_ops[0])
    sessions[2].run(train_ops[2])

    # The global step should have been updated since we only need to collect 2
    # gradients. The variables should now have the new values after the average
    # of the gradients from workers 0/2 is applied.
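    # Workers 0 and 2 contribute gradients (0.1, 0.9) and (0.5, 1.3), so the
    # expected values are v0 = 0 - 2.0 * (0.1 + 0.5) / 2 = -0.6 and
    # v1 = 1 - 2.0 * (0.9 + 1.3) / 2 = -1.2.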
    while sessions[1].run(global_step) != 1:
      time.sleep(0.01)

    self.assertAllEqual(1, sessions[1].run(global_step))
    self.assertAllClose(0 - (0.1 + 0.5) / 2 * 2.0, sessions[1].run(var_0_g_1))
    self.assertAllClose(1 - (0.9 + 1.3) / 2 * 2.0, sessions[1].run(var_1_g_1))

    # Worker 1 finished later, so its gradients will now be dropped as they
    # are stale.
    sessions[1].run(train_ops[1])

    # As shown in the previous test, the local_step for all workers should
    # still be 0, so their next computation will also be dropped.
    sessions[0].run(train_ops[0])
    sessions[1].run(train_ops[1])
    sessions[2].run(train_ops[2])

    # Although the global step should still be 1 as explained above, the local
    # step should now be updated to 1. Just check worker 1 as an example.
    self.assertAllEqual(1, sessions[1].run(global_step))
    self.assertAllEqual(1, sessions[1].run(local_step_1))

    thread_0 = self.checkedThread(
        target=self._run, args=(train_ops[0], sessions[0]))
    thread_1 = self.checkedThread(
        target=self._run, args=(train_ops[1], sessions[1]))

    # Let worker 0 execute first.
    # It will wait as we need 2 workers to finish this step, and the global
    # step should still be 1.
    thread_0.start()
    self.assertAllEqual(1, sessions[1].run(global_step))

    # Starts worker 1.
    thread_1.start()
    thread_1.join()
    thread_0.join()

    # The global step should now be 2 and the gradients should have been
    # applied again.
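    # This step aggregated workers 0 and 1, so the expected values are
    # v0 = -0.6 - 2.0 * (0.1 + 0.3) / 2 = -1.0 and
    # v1 = -1.2 - 2.0 * (0.9 + 1.1) / 2 = -3.2.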
    self.assertAllEqual(2, sessions[1].run(global_step))
    self.assertAllClose(-0.6 - (0.1 + 0.3) / 2 * 2.0,
                        sessions[1].run(var_0_g_1))
    self.assertAllClose(-1.2 - (0.9 + 1.1) / 2 * 2.0,
                        sessions[1].run(var_1_g_1))


class SyncReplicasOptimizerHookTest(test.TestCase):

  def testErrorIfUsedBeforeMinimizeCalled(self):
    opt = training.SyncReplicasOptimizer(
        opt=gradient_descent.GradientDescentOptimizer(1.0),
        replicas_to_aggregate=1,
        total_num_replicas=1)
    hook = opt.make_session_run_hook(True)
    with self.assertRaisesRegexp(ValueError,
                                 "apply_gradient should be called"):
      hook.begin()

  def testCanCreatedBeforeMinimizeCalled(self):
    """This behavior is required for integration with Estimators."""
    opt = training.SyncReplicasOptimizer(
        opt=gradient_descent.GradientDescentOptimizer(1.0),
        replicas_to_aggregate=1,
        total_num_replicas=1)
    hook = opt.make_session_run_hook(True)
    v = variables.Variable([0.])
    global_step = variables.Variable(0, name="global_step", trainable=False)
    opt.minimize(v, global_step=global_step)
    hook.begin()

  def testFetchVariableList(self):
    opt = training.SyncReplicasOptimizer(
        opt=adam.AdamOptimizer(0.01),
        replicas_to_aggregate=1,
        total_num_replicas=1)
    v = variables.Variable([0.], name="fetch_variable_test")
    global_step = variables.Variable(0, name="global_step", trainable=False)
    opt.minimize(v, global_step=global_step)
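    # variables() on the wrapper should include the wrapped AdamOptimizer's
    # beta accumulators in addition to the wrapper's own variables.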
    opt_variables = opt.variables()
    beta1_power, beta2_power = opt._opt._get_beta_accumulators()
    self.assertIn(beta1_power, opt_variables)
    self.assertIn(beta2_power, opt_variables)


if __name__ == "__main__":
  test.main()