# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DropStaleGradientOptimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import portpicker

from tensorflow.contrib.opt.python.training import drop_stale_gradient_optimizer
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import server_lib
from tensorflow.python.training import training_util


# Creates the workers and returns their sessions, graphs, train_ops.
def _get_workers(num_workers, staleness):
  """Starts an in-process cluster and builds per-worker train ops.

  To simulate stale cases, two shared queues coordinate computing and
  applying gradients respectively. In the phase of computing gradients,
  all workers except the chief worker compute gradients together and the
  chief worker computes after all other workers' computing finished. In the
  phase of applying gradients, the chief worker first applies gradients,
  then all other workers apply gradients one by one. Therefore, the chief
  worker always has 0 staleness, and each of the other workers observes a
  unique staleness value from [1, num_workers).

  Args:
    num_workers: Number of worker servers to start (one ps server is also
      started).
    staleness: Max staleness passed to DropStaleGradientOptimizer.

  Returns:
    Tuple of lists (sessions, graphs, train_ops), one entry per worker.
  """
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  cluster_dict = {
      'worker': ['localhost:%s' % port for port in worker_ports],
      'ps': ['localhost:%s' % portpicker.pick_unused_port()]
  }
  cs = server_lib.ClusterSpec(cluster_dict)
  workers = [
      server_lib.Server(
          cs, job_name='worker', task_index=ix, start=True)
      for ix in range(num_workers)
  ]
  server_lib.Server(cs, job_name='ps', task_index=0, start=True)

  sessions = []
  graphs = []
  train_ops = []

  for worker_id in range(num_workers):
    graph = ops.Graph()
    with graph.as_default():
      global_step = training_util.create_global_step()
      var_0 = variables.Variable(0.0, name='v0')
      var_1 = variables.Variable(1.0, name='v1')
      # Capacity -1 means unbounded; shared_name makes the same queue
      # visible from every worker's graph.
      compute_gradients_queue = data_flow_ops.FIFOQueue(
          -1, global_step.dtype.base_dtype, shapes=(),
          name='compute_gradients_queue', shared_name='compute_gradients_queue')
      apply_gradients_queue = data_flow_ops.FIFOQueue(
          -1, global_step.dtype.base_dtype, shapes=(),
          name='apply_gradients_queue', shared_name='apply_gradients_queue')

      # Gradients for loss on var_0 and var_1 will be 1.0.
      loss = 0 - var_0 - var_1
      sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
      stale_check_opt = (
          drop_stale_gradient_optimizer.DropStaleGradientOptimizer(
              sgd_opt, staleness))

      # Compute gradients: the chief blocks until every other worker has
      # enqueued, i.e. finished computing its gradients.
      if worker_id == 0:
        with ops.control_dependencies(
            [compute_gradients_queue.dequeue_many(num_workers - 1)]):
          grad_and_vars = stale_check_opt.compute_gradients(loss)
      else:
        grad_and_vars = stale_check_opt.compute_gradients(loss)
        with ops.control_dependencies([t[0] for t in grad_and_vars]):
          worker_enqueue_op = compute_gradients_queue.enqueue(global_step)

      # Apply gradients: the chief applies first and enqueues a token; the
      # other workers each dequeue a token, apply, and enqueue the next one,
      # so they apply strictly one by one.
      if worker_id == 0:
        with ops.control_dependencies(
            [stale_check_opt.apply_gradients(grad_and_vars, global_step)]):
          train_op = apply_gradients_queue.enqueue(global_step)
      else:
        with ops.control_dependencies([worker_enqueue_op]):
          with ops.control_dependencies([apply_gradients_queue.dequeue()]):
            with ops.control_dependencies(
                [stale_check_opt.apply_gradients(
                    grad_and_vars, global_step)]):
              train_op = apply_gradients_queue.enqueue(global_step)

      # Session is created inside the `with graph.as_default()` block so it
      # is bound to this worker's graph.
      sess = session.Session(workers[worker_id].target)

    sessions.append(sess)
    graphs.append(graph)
    train_ops.append(train_op)

  return sessions, graphs, train_ops


class DropStaleGradientOptimizerTest(test.TestCase):

  def _run(self, train_op, sess):
    """Thread target: runs a single worker's train op in its session."""
    sess.run(train_op)

  def _init_and_verify(self, sessions, graphs):
    """Initializes variables and verifies their starting values.

    Args:
      sessions: Per-worker sessions returned by `_get_workers`.
      graphs: Per-worker graphs returned by `_get_workers`.

    Returns:
      Tuple (global_step, var_0, var_1, stale_counter) of tensors fetched
      from the chief worker's graph.
    """
    with graphs[0].as_default():
      sessions[0].run(variables.global_variables_initializer())
    global_step = training_util.get_global_step(graphs[0])
    var_0 = graphs[0].get_tensor_by_name('v0:0')
    var_1 = graphs[0].get_tensor_by_name('v1:0')
    stale_counter = graphs[0].get_tensor_by_name('stale_counter:0')
    # Verify the initialized values.
    self.assertAllEqual(0.0, sessions[0].run(var_0))
    self.assertAllEqual(1.0, sessions[0].run(var_1))
    self.assertAllEqual(0.0, sessions[0].run(stale_counter))
    self.assertAllEqual(0, sessions[0].run(global_step))
    return global_step, var_0, var_1, stale_counter

  def _run_in_threads(self, train_ops, sessions):
    """Runs every worker's train op concurrently and waits for all of them."""
    threads = [
        self.checkedThread(target=self._run, args=(train_ops[i], sessions[i]))
        for i in range(len(train_ops))
    ]
    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()

  def test1Worker(self):
    sessions, graphs, train_ops = _get_workers(1, 0)
    global_step, var_0, var_1, _ = self._init_and_verify(sessions, graphs)

    sessions[0].run(train_ops[0])

    # Verify the updated values after 1 step: the single (chief) worker has
    # staleness 0, so its gradient of 1.0 is applied to each variable.
    self.assertAllEqual(1, sessions[0].run(global_step))
    self.assertAllEqual(0.0 + 1.0, sessions[0].run(var_0))
    self.assertAllEqual(1.0 + 1.0, sessions[0].run(var_1))

  def test1WorkerNegativeStaleness(self):
    sessions, graphs, train_ops = _get_workers(1, -1)
    global_step, var_0, var_1, stale_counter = self._init_and_verify(
        sessions, graphs)

    sessions[0].run(train_ops[0])

    # Verify no updates because max staleness is negative: every gradient is
    # dropped and counted as stale.
    self.assertAllEqual(0, sessions[0].run(global_step))
    self.assertAllEqual(1.0, sessions[0].run(stale_counter))
    self.assertAllEqual(0.0, sessions[0].run(var_0))
    self.assertAllEqual(1.0, sessions[0].run(var_1))

  def test2WorkersStaleness0(self):
    sessions, graphs, train_ops = _get_workers(2, 0)
    global_step, var_0, var_1, stale_counter = self._init_and_verify(
        sessions, graphs)

    self._run_in_threads(train_ops, sessions)

    # With 2 workers and max staleness set to 0, only the chief worker
    # updates var_0 and var_1; the other worker's gradient is dropped.
    self.assertAllEqual(1, sessions[0].run(global_step))
    self.assertAllEqual(1.0, sessions[0].run(stale_counter))
    self.assertAllEqual(0.0 + 1.0, sessions[0].run(var_0))
    self.assertAllEqual(1.0 + 1.0, sessions[0].run(var_1))

  def test2WorkersStaleness1(self):
    sessions, graphs, train_ops = _get_workers(2, 1)
    global_step, var_0, var_1, stale_counter = self._init_and_verify(
        sessions, graphs)

    self._run_in_threads(train_ops, sessions)

    # With 2 workers and max staleness set to 1, both workers update
    # var_0 and var_1; no gradient is dropped.
    self.assertAllEqual(2, sessions[0].run(global_step))
    self.assertAllEqual(0.0, sessions[0].run(stale_counter))
    self.assertAllEqual(0.0 + 2.0, sessions[0].run(var_0))
    self.assertAllEqual(1.0 + 2.0, sessions[0].run(var_1))

  def test3WorkersStaleness0(self):
    sessions, graphs, train_ops = _get_workers(3, 0)
    global_step, var_0, var_1, stale_counter = self._init_and_verify(
        sessions, graphs)

    self._run_in_threads(train_ops, sessions)

    # With 3 workers and max staleness set to 0, only the chief worker
    # updates var_0 and var_1; both other workers' gradients are dropped.
    self.assertAllEqual(1, sessions[0].run(global_step))
    self.assertAllEqual(2.0, sessions[0].run(stale_counter))
    self.assertAllEqual(0.0 + 1.0, sessions[0].run(var_0))
    self.assertAllEqual(1.0 + 1.0, sessions[0].run(var_1))

  def test3WorkersStaleness1(self):
    sessions, graphs, train_ops = _get_workers(3, 1)
    global_step, var_0, var_1, stale_counter = self._init_and_verify(
        sessions, graphs)

    self._run_in_threads(train_ops, sessions)

    # With 3 workers and max staleness set to 1, the chief worker and only
    # one of the two other workers update var_0 and var_1.
    self.assertAllEqual(2, sessions[0].run(global_step))
    self.assertAllEqual(1.0, sessions[0].run(stale_counter))
    self.assertAllEqual(0.0 + 2.0, sessions[0].run(var_0))
    self.assertAllEqual(1.0 + 2.0, sessions[0].run(var_1))


if __name__ == '__main__':
  test.main()