Home | History | Annotate | Download | only in training
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for DropStaleGradientOptimizer."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import portpicker
     21 
     22 from tensorflow.contrib.opt.python.training import drop_stale_gradient_optimizer
     23 from tensorflow.python.client import session
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.ops import data_flow_ops
     26 from tensorflow.python.ops import variables
     27 from tensorflow.python.platform import test
     28 from tensorflow.python.training import gradient_descent
     29 from tensorflow.python.training import server_lib
     30 from tensorflow.python.training import training_util
     31 
     32 
     33 # Creates the workers and return their sessions, graphs, train_ops.
     34 def _get_workers(num_workers, staleness):
     35   worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
     36   cluster_dict = {
     37       'worker': ['localhost:%s' % port for port in worker_ports],
     38       'ps': ['localhost:%s' % portpicker.pick_unused_port()]
     39   }
     40   cs = server_lib.ClusterSpec(cluster_dict)
     41   workers = [
     42       server_lib.Server(
     43           cs, job_name='worker', task_index=ix, start=True)
     44       for ix in range(num_workers)
     45   ]
     46   server_lib.Server(cs, job_name='ps', task_index=0, start=True)
     47 
     48   sessions = []
     49   graphs = []
     50   train_ops = []
     51 
     52   # To simulate stale cases, maintaining two queues for computing and
     53   # applying gradients respectively. In the phase of computing gradients,
     54   # all workers except chief worker compute gradients together and chief worker
     55   # computes after all other worers' computing finished. In the phase of
     56   # applying gradients, chief worker will first apply gradients, then all other
     57   # workers will apply gradients one by one. Therefore, the chief worker will
     58   # always have 0 staleness, each of all other workers will have a unique
     59   # staleness value from [1, num_workers).
     60   for worker_id in range(num_workers):
     61     graph = ops.Graph()
     62     with graph.as_default():
     63       global_step = training_util.create_global_step()
     64       var_0 = variables.Variable(0.0, name='v0')
     65       var_1 = variables.Variable(1.0, name='v1')
     66       compute_gradients_queue = data_flow_ops.FIFOQueue(
     67           -1, global_step.dtype.base_dtype, shapes=(),
     68           name='compute_gradients_queue', shared_name='compute_gradients_queue')
     69       apply_gradients_queue = data_flow_ops.FIFOQueue(
     70           -1, global_step.dtype.base_dtype, shapes=(),
     71           name='apply_gradients_queue', shared_name='apply_gradients_queue')
     72 
     73       # Gradients for loss on var_0 and var_1 will be 1.0.
     74       loss = 0 - var_0 - var_1
     75       sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
     76       stale_check_opt = (
     77           drop_stale_gradient_optimizer.DropStaleGradientOptimizer(
     78               sgd_opt, staleness))
     79 
     80       # Compute gradients.
     81       if worker_id == 0:
     82         with ops.control_dependencies(
     83             [compute_gradients_queue.dequeue_many(num_workers - 1)]):
     84           grad_and_vars = stale_check_opt.compute_gradients(loss)
     85       else:
     86         grad_and_vars = stale_check_opt.compute_gradients(loss)
     87         with ops.control_dependencies([t[0] for t in grad_and_vars]):
     88           worker_enqueue_op = compute_gradients_queue.enqueue(global_step)
     89 
     90       # Apply gradients.
     91       if worker_id == 0:
     92         with ops.control_dependencies(
     93             [stale_check_opt.apply_gradients(grad_and_vars, global_step)]):
     94           train_op = apply_gradients_queue.enqueue(global_step)
     95       else:
     96         with ops.control_dependencies([worker_enqueue_op]):
     97           with ops.control_dependencies([apply_gradients_queue.dequeue()]):
     98             with ops.control_dependencies(
     99                 [stale_check_opt.apply_gradients(
    100                     grad_and_vars, global_step)]):
    101               train_op = apply_gradients_queue.enqueue(global_step)
    102 
    103       sess = session.Session(workers[worker_id].target)
    104 
    105     sessions.append(sess)
    106     graphs.append(graph)
    107     train_ops.append(train_op)
    108 
    109   return sessions, graphs, train_ops
    110 
    111 
    112 class DropStaleGradientOptimizerTest(test.TestCase):
    113 
    114   def _run(self, train_op, sess):
    115     sess.run(train_op)
    116 
    117   def test1Worker(self):
    118     num_workers = 1
    119     sessions, graphs, train_ops = _get_workers(num_workers, 0)
    120     with graphs[0].as_default():
    121       sessions[0].run(variables.global_variables_initializer())
    122     global_step = training_util.get_global_step(graphs[0])
    123     var_0 = graphs[0].get_tensor_by_name('v0:0')
    124     var_1 = graphs[0].get_tensor_by_name('v1:0')
    125     stale_counter = graphs[0].get_tensor_by_name('stale_counter:0')
    126     # Verify the initialized value.
    127     self.assertAllEqual(0.0, sessions[0].run(var_0))
    128     self.assertAllEqual(1.0, sessions[0].run(var_1))
    129     self.assertAllEqual(0.0, sessions[0].run(stale_counter))
    130     self.assertAllEqual(0, sessions[0].run(global_step))
    131 
    132     sessions[0].run(train_ops[0])
    133 
    134     # Verify the updated value after 1 step.
    135     self.assertAllEqual(1, sessions[0].run(global_step))
    136     self.assertAllEqual(0.0 + 1.0, sessions[0].run(var_0))
    137     self.assertAllEqual(1.0 + 1.0, sessions[0].run(var_1))
    138     self.assertAllEqual(1, sessions[0].run(global_step))
    139 
    140   def test1WorkerNegativeStaleness(self):
    141     num_workers = 1
    142     sessions, graphs, train_ops = _get_workers(num_workers, -1)
    143     with graphs[0].as_default():
    144       sessions[0].run(variables.global_variables_initializer())
    145     global_step = training_util.get_global_step(graphs[0])
    146     var_0 = graphs[0].get_tensor_by_name('v0:0')
    147     var_1 = graphs[0].get_tensor_by_name('v1:0')
    148     stale_counter = graphs[0].get_tensor_by_name('stale_counter:0')
    149     # Verify the initialized value.
    150     self.assertAllEqual(0.0, sessions[0].run(var_0))
    151     self.assertAllEqual(1.0, sessions[0].run(var_1))
    152     self.assertAllEqual(0.0, sessions[0].run(stale_counter))
    153     self.assertAllEqual(0, sessions[0].run(global_step))
    154 
    155     sessions[0].run(train_ops[0])
    156 
    157     # Verify no updates because max staleness is negative.
    158     self.assertAllEqual(0, sessions[0].run(global_step))
    159     self.assertAllEqual(1.0, sessions[0].run(stale_counter))
    160     self.assertAllEqual(0.0, sessions[0].run(var_0))
    161     self.assertAllEqual(1.0, sessions[0].run(var_1))
    162 
    163   def test2WorkersStaleness0(self):
    164     num_workers = 2
    165     sessions, graphs, train_ops = _get_workers(num_workers, 0)
    166     with graphs[0].as_default():
    167       sessions[0].run(variables.global_variables_initializer())
    168     global_step = training_util.get_global_step(graphs[0])
    169     var_0 = graphs[0].get_tensor_by_name('v0:0')
    170     var_1 = graphs[0].get_tensor_by_name('v1:0')
    171     stale_counter = graphs[0].get_tensor_by_name('stale_counter:0')
    172     # Verify the initialized value.
    173     self.assertAllEqual(0.0, sessions[0].run(var_0))
    174     self.assertAllEqual(1.0, sessions[0].run(var_1))
    175     self.assertAllEqual(0.0, sessions[0].run(stale_counter))
    176     self.assertAllEqual(0, sessions[0].run(global_step))
    177 
    178     thread_0 = self.checkedThread(
    179         target=self._run, args=(train_ops[0], sessions[0]))
    180     thread_1 = self.checkedThread(
    181         target=self._run, args=(train_ops[1], sessions[1]))
    182     thread_0.start()
    183     thread_1.start()
    184     thread_0.join()
    185     thread_1.join()
    186 
    187     # With 2 workers and max staleness set to 0, only chief worker will update
    188     # var_0 and var_1.
    189     self.assertAllEqual(1, sessions[0].run(global_step))
    190     self.assertAllEqual(1.0, sessions[0].run(stale_counter))
    191     self.assertAllEqual(0.0 + 1.0, sessions[0].run(var_0))
    192     self.assertAllEqual(1.0 + 1.0, sessions[0].run(var_1))
    193 
    194   def test2WorkersStaleness1(self):
    195     num_workers = 2
    196     sessions, graphs, train_ops = _get_workers(num_workers, 1)
    197     with graphs[0].as_default():
    198       sessions[0].run(variables.global_variables_initializer())
    199     global_step = training_util.get_global_step(graphs[0])
    200     var_0 = graphs[0].get_tensor_by_name('v0:0')
    201     var_1 = graphs[0].get_tensor_by_name('v1:0')
    202     stale_counter = graphs[0].get_tensor_by_name('stale_counter:0')
    203     # Verify the initialized value.
    204     self.assertAllEqual(0.0, sessions[0].run(var_0))
    205     self.assertAllEqual(1.0, sessions[0].run(var_1))
    206     self.assertAllEqual(0.0, sessions[0].run(stale_counter))
    207     self.assertAllEqual(0, sessions[0].run(global_step))
    208 
    209     thread_0 = self.checkedThread(
    210         target=self._run, args=(train_ops[0], sessions[0]))
    211     thread_1 = self.checkedThread(
    212         target=self._run, args=(train_ops[1], sessions[1]))
    213     thread_0.start()
    214     thread_1.start()
    215     thread_0.join()
    216     thread_1.join()
    217 
    218     # With 2 workers and max staleness set to 1, both workers will update
    219     # var_0 and var_1.
    220     self.assertAllEqual(2, sessions[0].run(global_step))
    221     self.assertAllEqual(0.0, sessions[0].run(stale_counter))
    222     self.assertAllEqual(0.0 + 2.0, sessions[0].run(var_0))
    223     self.assertAllEqual(1.0 + 2.0, sessions[0].run(var_1))
    224 
    225   def test3WorkersStaleness0(self):
    226     num_workers = 3
    227     sessions, graphs, train_ops = _get_workers(num_workers, 0)
    228     with graphs[0].as_default():
    229       sessions[0].run(variables.global_variables_initializer())
    230     global_step = training_util.get_global_step(graphs[0])
    231     var_0 = graphs[0].get_tensor_by_name('v0:0')
    232     var_1 = graphs[0].get_tensor_by_name('v1:0')
    233     stale_counter = graphs[0].get_tensor_by_name('stale_counter:0')
    234     # Verify the initialized value.
    235     self.assertAllEqual(0.0, sessions[0].run(var_0))
    236     self.assertAllEqual(1.0, sessions[0].run(var_1))
    237     self.assertAllEqual(0.0, sessions[0].run(stale_counter))
    238     self.assertAllEqual(0, sessions[0].run(global_step))
    239 
    240     thread_0 = self.checkedThread(
    241         target=self._run, args=(train_ops[0], sessions[0]))
    242     thread_1 = self.checkedThread(
    243         target=self._run, args=(train_ops[1], sessions[1]))
    244     thread_2 = self.checkedThread(
    245         target=self._run, args=(train_ops[2], sessions[2]))
    246     thread_0.start()
    247     thread_1.start()
    248     thread_2.start()
    249     thread_0.join()
    250     thread_1.join()
    251     thread_2.join()
    252 
    253     # With 3 workers and max staleness set to 0, only chief worker will update
    254     # var_0 and var_1.
    255     self.assertAllEqual(1, sessions[0].run(global_step))
    256     self.assertAllEqual(2.0, sessions[0].run(stale_counter))
    257     self.assertAllEqual(0.0 + 1.0, sessions[0].run(var_0))
    258     self.assertAllEqual(1.0 + 1.0, sessions[0].run(var_1))
    259 
    260   def test3WorkersStaleness1(self):
    261     num_workers = 3
    262     sessions, graphs, train_ops = _get_workers(num_workers, 1)
    263     with graphs[0].as_default():
    264       sessions[0].run(variables.global_variables_initializer())
    265     global_step = training_util.get_global_step(graphs[0])
    266     var_0 = graphs[0].get_tensor_by_name('v0:0')
    267     var_1 = graphs[0].get_tensor_by_name('v1:0')
    268     stale_counter = graphs[0].get_tensor_by_name('stale_counter:0')
    269     # Verify the initialized value.
    270     self.assertAllEqual(0.0, sessions[0].run(var_0))
    271     self.assertAllEqual(1.0, sessions[0].run(var_1))
    272     self.assertAllEqual(0.0, sessions[0].run(stale_counter))
    273     self.assertAllEqual(0, sessions[0].run(global_step))
    274 
    275     thread_0 = self.checkedThread(
    276         target=self._run, args=(train_ops[0], sessions[0]))
    277     thread_1 = self.checkedThread(
    278         target=self._run, args=(train_ops[1], sessions[1]))
    279     thread_2 = self.checkedThread(
    280         target=self._run, args=(train_ops[2], sessions[2]))
    281     thread_0.start()
    282     thread_1.start()
    283     thread_2.start()
    284     thread_0.join()
    285     thread_1.join()
    286     thread_2.join()
    287 
    288     # With 3 workers and max staleness set to 1, chief worker and only one of
    289     # the two other workers will update var_0 and var_1.
    290     self.assertAllEqual(2, sessions[0].run(global_step))
    291     self.assertAllEqual(1.0, sessions[0].run(stale_counter))
    292     self.assertAllEqual(0.0 + 2.0, sessions[0].run(var_0))
    293     self.assertAllEqual(1.0 + 2.0, sessions[0].run(var_1))
    294 
    295 
    296 if __name__ == '__main__':
    297   test.main()
    298