# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ModelAverageOptimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import portpicker

from tensorflow.contrib.opt.python.training import model_average_optimizer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import server_lib
from tensorflow.python.training import training
from tensorflow.python.training import training_util


def create_local_cluster(num_workers, num_ps, protocol="grpc"):
  """Creates local gRPC servers and returns them with the cluster dict."""
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
  cluster_dict = {
      "worker": ["localhost:%s" % port for port in worker_ports],
      "ps": ["localhost:%s" % port for port in ps_ports]
  }
  cs = server_lib.ClusterSpec(cluster_dict)

  workers = [
      server_lib.Server(
          cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
      for ix in range(num_workers)
  ]
  ps_servers = [
      server_lib.Server(
          cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
      for ix in range(num_ps)
  ]

  return cluster_dict, workers, ps_servers


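# Note on the wiring exercised below: ModelAverageCustomGetter intercepts
# variable creation so each worker trains on a local copy of every trainable
# variable on its own device, while a global copy lives on the ps job under
# GLOBAL_VARIABLE_NAME. Every `interval_steps` steps, ModelAverageOptimizer
# averages the local copies into the global ones and resets the locals to
# that average. _get_workers builds one graph, session, and train op per
# worker to drive this cycle.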
def _get_workers(num_workers, steps, workers):
  """Creates the workers and returns their sessions, graphs and train_ops.

  The chief worker will update last.
  """
  sessions = []
  graphs = []
  train_ops = []
  for worker_id in range(num_workers):
    graph = ops.Graph()
    is_chief = (worker_id == 0)
    with graph.as_default():
      worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
      ma_custom = model_average_optimizer.ModelAverageCustomGetter(
          worker_device=worker_device)
      with variable_scope.variable_scope(
          "", custom_getter=ma_custom), ops.device(
              device_setter.replica_device_setter(
                  worker_device=worker_device,
                  ps_device="/job:ps/task:0/cpu:0",
                  ps_tasks=1)):

        global_step = variables.Variable(0, name="global_step", trainable=False)
        var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
        var_1 = variable_scope.get_variable(initializer=1.0, name="v1")

      with ops.device("/job:worker/task:" + str(worker_id)):
        # Worker 0 applies a gradient of -1.0 per step; the others apply -2.0.
        if worker_id == 0:
          grads_0 = constant_op.constant(-1.0)
          grads_1 = constant_op.constant(-1.0)
        else:
          grads_0 = constant_op.constant(-2.0)
          grads_1 = constant_op.constant(-2.0)
        sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
        opt = model_average_optimizer.ModelAverageOptimizer(
            opt=sgd_opt,
            num_worker=num_workers,
            ma_custom_getter=ma_custom,
            is_chief=is_chief,
            interval_steps=steps)
        train_op = [
            opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
                                global_step)
        ]
      ma_hook = opt.make_session_run_hook()
      # Creates a MonitoredSession wired to this worker's in-process server.
      sess = training.MonitoredTrainingSession(
          workers[worker_id].target, hooks=[ma_hook])

    sessions.append(sess)
    graphs.append(graph)
    train_ops.append(train_op)
  return sessions, graphs, train_ops


class ModelAverageOptimizerTest(test.TestCase):

  def _run(self, train_op, sess):
    sess.run(train_op)

  def test1Workers2Period(self):
    num_workers = 2
    steps = 2
    num_ps = 1
    _, workers, _ = create_local_cluster(
        num_workers=num_workers, num_ps=num_ps)

    sessions, graphs, train_ops = _get_workers(num_workers, steps, workers)

    var_0 = graphs[0].get_tensor_by_name("v0:0")
    var_1 = graphs[0].get_tensor_by_name("v1:0")
    global_step = training_util.get_global_step(graphs[0])
    global_var_0 = graphs[0].get_tensor_by_name(
        model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v0:0")
    global_var_1 = graphs[0].get_tensor_by_name(
        model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v1:0")

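    # Expected arithmetic: the learning rate is 1.0, so each train step adds
    # the negated gradient to the local variables (+1.0 on worker 0, +2.0 on
    # worker 1). With interval_steps=2, the second step triggers a model
    # average: the global variables become the mean of the two local copies,
    # the local copies are reset to that mean, and global_step increments.
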
    # Verify the initialized values.
    self.assertAllEqual(0.0, sessions[0].run(var_0))
    self.assertAllEqual(1.0, sessions[0].run(var_1))
    self.assertAllEqual(0.0, sessions[0].run(global_var_0))
    self.assertAllEqual(1.0, sessions[0].run(global_var_1))
    self.assertAllEqual(0, sessions[0].run(global_step))

    # Iteration 1: only the local variables change.
    sessions[0].run(train_ops[0])
    sessions[1].run(train_ops[1])

    self.assertAllEqual(1.0, sessions[0].run(var_0))
    self.assertAllEqual(2.0, sessions[0].run(var_1))
    self.assertAllEqual(0.0, sessions[0].run(global_var_0))
    self.assertAllEqual(1.0, sessions[0].run(global_var_1))
    self.assertAllEqual(0, sessions[0].run(global_step))

    # Iteration 2: global variable update. Both workers step concurrently,
    # since the averaging waits for every worker's local variables.
    thread_0 = self.checkedThread(
        target=self._run, args=(train_ops[0], sessions[0]))
    thread_1 = self.checkedThread(
        target=self._run, args=(train_ops[1], sessions[1]))
    thread_0.start()
    thread_1.start()
    thread_0.join()
    thread_1.join()

    # Worker 0 local: 1.0 + 1.0 = 2.0; worker 1 local: 2.0 + 2.0 = 4.0;
    # average: (2.0 + 4.0) / 2 = 3.0 (and likewise 4.0 for v1).
    self.assertAllEqual(3.0, sessions[0].run(var_0))
    self.assertAllEqual(4.0, sessions[0].run(var_1))
    self.assertAllEqual(3.0, sessions[0].run(global_var_0))
    self.assertAllEqual(4.0, sessions[0].run(global_var_1))
    self.assertAllEqual(1, sessions[0].run(global_step))

    # Iteration 3: only worker 0 steps, so just its local variables move.
    sessions[0].run(train_ops[0])

    self.assertAllEqual(4.0, sessions[0].run(var_0))
    self.assertAllEqual(5.0, sessions[0].run(var_1))
    self.assertAllEqual(3.0, sessions[0].run(global_var_0))
    self.assertAllEqual(4.0, sessions[0].run(global_var_1))
    self.assertAllEqual(1, sessions[0].run(global_step))

  def testPS2TasksWithClusterSpecClass(self):
    cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })
    worker_device = "/job:worker/task:0"
    ma_custom = model_average_optimizer.ModelAverageCustomGetter(
        worker_device=worker_device)
    with ops.device(
        device_setter.replica_device_setter(cluster=cluster_spec,
                                            worker_device=worker_device,
                                            ps_device="/job:ps")), \
        variable_scope.variable_scope("", custom_getter=ma_custom):
      v = variable_scope.get_variable(initializer=[1, 2], name="v")
      w = variable_scope.get_variable(initializer=[2, 1], name="w")
      v_g, w_g = ma_custom._local_2_global[v], ma_custom._local_2_global[w]
      # Local variables live on the worker; their global copies are placed
      # round-robin across the ps tasks by the replica device setter.
      self.assertDeviceEqual("/job:worker/task:0", v.device)
      self.assertDeviceEqual("/job:ps/task:0", v_g.device)
      self.assertDeviceEqual("/job:worker/task:0", w.device)
      self.assertDeviceEqual("/job:ps/task:1", w_g.device)


if __name__ == "__main__":
  test.main()