1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """run_config.py tests.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import copy 22 import json 23 24 from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib 25 from tensorflow.core.protobuf import config_pb2 26 from tensorflow.python.estimator import run_config as core_run_config 27 from tensorflow.python.platform import test 28 from tensorflow.python.training import server_lib 29 30 TEST_DIR = "test_dir" 31 ANOTHER_TEST_DIR = "another_test_dir" 32 MASTER = "master_" 33 RANDOM_SEED = 123 34 35 patch = test.mock.patch 36 37 38 def _create_run_config_with_cluster_spec(tf_config_str): 39 with patch.dict("os.environ", {"TF_CONFIG": tf_config_str}): 40 return run_config_lib.RunConfig( 41 tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR) 42 43 44 class RunConfigTest(test.TestCase): 45 46 def test_instance_of_core_run_config(self): 47 config = run_config_lib.RunConfig() 48 self.assertTrue(isinstance(config, core_run_config.RunConfig)) 49 50 def test_defaults_with_no_tf_config(self): 51 config = run_config_lib.RunConfig() 52 self.assertEqual(config.master, "") 53 self.assertEqual(config.task_id, 0) 54 self.assertEqual(config.num_ps_replicas, 0) 55 self.assertEqual(config.cluster_spec, {}) 56 self.assertIsNone(config.task_type) 57 self.assertTrue(config.is_chief) 58 self.assertEqual(config.evaluation_master, "") 59 60 def test_values_from_tf_config(self): 61 tf_config = { 62 "cluster": { 63 run_config_lib.TaskType.PS: ["host1:1", "host2:2"], 64 run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"] 65 }, 66 "task": { 67 "type": run_config_lib.TaskType.WORKER, 68 "index": 1 69 } 70 } 71 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 72 config = run_config_lib.RunConfig() 73 74 self.assertEqual(config.master, "grpc://host4:4") 75 self.assertEqual(config.task_id, 1) 76 self.assertEqual(config.num_ps_replicas, 2) 77 self.assertEqual(config.num_worker_replicas, 3) 78 self.assertEqual(config.cluster_spec.as_dict(), tf_config["cluster"]) 79 self.assertEqual(config.task_type, run_config_lib.TaskType.WORKER) 80 self.assertFalse(config.is_chief) 81 self.assertEqual(config.evaluation_master, "") 82 83 def test_explicitly_specified_values(self): 84 cluster_spec = { 85 run_config_lib.TaskType.PS: ["localhost:9990"], 86 "my_job_name": ["localhost:9991", "localhost:9992", "localhost:0"] 87 } 88 tf_config = { 89 "cluster": cluster_spec, 90 "task": { 91 "type": run_config_lib.TaskType.WORKER, 92 "index": 2 93 } 94 } 95 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 96 config = run_config_lib.RunConfig( 97 master="localhost:0", evaluation_master="localhost:9991") 98 99 self.assertEqual(config.master, "localhost:0") 100 self.assertEqual(config.task_id, 2) 101 self.assertEqual(config.num_ps_replicas, 1) 102 self.assertEqual(config.num_worker_replicas, 0) 103 self.assertEqual(config.cluster_spec, server_lib.ClusterSpec(cluster_spec)) 104 self.assertEqual(config.task_type, run_config_lib.TaskType.WORKER) 105 self.assertFalse(config.is_chief) 106 self.assertEqual(config.evaluation_master, "localhost:9991") 107 108 def test_single_node_in_cluster_spec_produces_empty_master(self): 109 tf_config = {"cluster": {run_config_lib.TaskType.WORKER: ["host1:1"]}} 110 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 111 config = run_config_lib.RunConfig() 112 self.assertEqual(config.master, "") 113 114 def test_no_task_type_produces_empty_master(self): 115 tf_config = { 116 "cluster": { 117 run_config_lib.TaskType.PS: ["host1:1", "host2:2"], 118 run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"] 119 }, 120 # Omits "task": {"type": "worker} 121 } 122 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 123 config = run_config_lib.RunConfig() 124 self.assertEqual(config.master, "") 125 126 def test_invalid_job_name_raises(self): 127 tf_config = { 128 "cluster": { 129 run_config_lib.TaskType.PS: ["host1:1", "host2:2"], 130 run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"] 131 }, 132 "task": { 133 "type": "not_in_cluster_spec" 134 } 135 } 136 expected_msg_regexp = "not_in_cluster_spec is not a valid task" 137 with patch.dict( 138 "os.environ", 139 {"TF_CONFIG": json.dumps(tf_config)}), self.assertRaisesRegexp( 140 ValueError, expected_msg_regexp): 141 run_config_lib.RunConfig() 142 143 def test_illegal_task_index_raises(self): 144 tf_config = { 145 "cluster": { 146 run_config_lib.TaskType.PS: ["host1:1", "host2:2"], 147 run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"] 148 }, 149 "task": { 150 "type": run_config_lib.TaskType.WORKER, 151 "index": 3 152 } 153 } 154 expected_msg_regexp = "3 is not a valid task_id" 155 with patch.dict( 156 "os.environ", 157 {"TF_CONFIG": json.dumps(tf_config)}), self.assertRaisesRegexp( 158 ValueError, expected_msg_regexp): 159 run_config_lib.RunConfig() 160 161 def test_is_chief_from_cloud_tf_config(self): 162 # is_chief should be true when ["task"]["type"] == "master" and 163 # index == 0 and ["task"]["environment"] == "cloud". Note that 164 # test_values_from_tf_config covers the non-master case. 165 tf_config = { 166 "cluster": { 167 run_config_lib.TaskType.PS: ["host1:1", "host2:2"], 168 run_config_lib.TaskType.MASTER: ["host3:3"], 169 run_config_lib.TaskType.WORKER: ["host4:4", "host5:5", "host6:6"] 170 }, 171 "task": { 172 "type": run_config_lib.TaskType.MASTER, 173 "index": 0 174 }, 175 "environment": "cloud" 176 } 177 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 178 config = run_config_lib.RunConfig() 179 180 self.assertTrue(config.is_chief) 181 182 def test_is_chief_from_noncloud_tf_config(self): 183 # is_chief should be true when ["task"]["type"] == "worker" and 184 # index == 0 if ["task"]["environment"] != "cloud". 185 tf_config = { 186 "cluster": { 187 run_config_lib.TaskType.PS: ["host1:1", "host2:2"], 188 run_config_lib.TaskType.MASTER: ["host3:3"], 189 run_config_lib.TaskType.WORKER: ["host4:4", "host5:5", "host6:6"] 190 }, 191 "task": { 192 "type": run_config_lib.TaskType.WORKER, 193 "index": 0 194 }, 195 "environment": "random" 196 } 197 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 198 config = run_config_lib.RunConfig() 199 200 self.assertTrue(config.is_chief) 201 202 # But task 0 for a job named "master" should not be. 203 tf_config = { 204 "cluster": { 205 run_config_lib.TaskType.PS: ["host1:1", "host2:2"], 206 run_config_lib.TaskType.MASTER: ["host3:3"], 207 run_config_lib.TaskType.WORKER: ["host4:4", "host5:5", "host6:6"] 208 }, 209 "task": { 210 "type": run_config_lib.TaskType.MASTER, 211 "index": 0 212 }, 213 "environment": "random" 214 } 215 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 216 config = run_config_lib.RunConfig() 217 218 self.assertFalse(config.is_chief) 219 220 def test_default_is_chief_from_tf_config_without_job_name(self): 221 tf_config = {"cluster": {}, "task": {}} 222 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 223 config = run_config_lib.RunConfig() 224 225 self.assertTrue(config.is_chief) 226 227 def test_model_dir(self): 228 empty_config = run_config_lib.RunConfig() 229 self.assertIsNone(empty_config.model_dir) 230 231 config = run_config_lib.RunConfig(model_dir=TEST_DIR) 232 self.assertEqual(TEST_DIR, config.model_dir) 233 234 def test_model_dir_in_tf_config(self): 235 tf_config = {"model_dir": TEST_DIR} 236 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 237 run_config = run_config_lib.RunConfig() 238 self.assertEqual(TEST_DIR, run_config.model_dir) 239 240 def test_model_dir_both_in_tf_config_and_constructor(self): 241 tf_config = {"model_dir": TEST_DIR} 242 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 243 run_config = run_config_lib.RunConfig(model_dir=TEST_DIR) 244 self.assertEqual(TEST_DIR, run_config.model_dir) 245 246 def test_model_dir_fail_if_constructor_value_mismatch_tf_config(self): 247 tf_config = {"model_dir": TEST_DIR} 248 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 249 with self.assertRaisesRegexp( 250 ValueError, 251 "`model_dir` provided in RunConfig .* must have " 252 "the same value .* in TF_CONFIG"): 253 run_config_lib.RunConfig(model_dir=TEST_DIR + "/sub_dir") 254 255 def test_replace(self): 256 config = run_config_lib.RunConfig( 257 tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR) 258 self.assertEqual(TEST_DIR, config.model_dir) 259 self.assertEqual(RANDOM_SEED, config.tf_random_seed) 260 261 new_config = config.replace(model_dir=ANOTHER_TEST_DIR) 262 self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir) 263 self.assertEqual(RANDOM_SEED, new_config.tf_random_seed) 264 self.assertEqual(RANDOM_SEED, config.tf_random_seed) 265 266 def test_uid_for_different_configs(self): 267 config = run_config_lib.RunConfig( 268 tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR) 269 270 expected_uid = config.uid() 271 # Check for 10 times, which should prove something. 272 for _ in range(10): 273 self.assertEqual(expected_uid, config.uid()) 274 275 new_config = config.replace(model_dir=ANOTHER_TEST_DIR) 276 self.assertEqual(TEST_DIR, config.model_dir) 277 self.assertNotEqual(expected_uid, new_config.uid()) 278 self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir) 279 280 def test_uid_for_whitelist(self): 281 whitelist = ["model_dir"] 282 config = run_config_lib.RunConfig( 283 tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR) 284 285 expected_uid = config.uid(whitelist) 286 self.assertEqual(expected_uid, config.uid(whitelist)) 287 288 new_config = config.replace(model_dir=ANOTHER_TEST_DIR) 289 self.assertEqual(TEST_DIR, config.model_dir) 290 self.assertEqual(expected_uid, new_config.uid(whitelist)) 291 self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir) 292 293 def test_uid_for_default_whitelist(self): 294 config = run_config_lib.RunConfig( 295 tf_random_seed=11, 296 save_summary_steps=12, 297 save_checkpoints_steps=13, 298 save_checkpoints_secs=14, 299 session_config=config_pb2.ConfigProto(allow_soft_placement=True), 300 keep_checkpoint_max=16, 301 keep_checkpoint_every_n_hours=17) 302 self.assertEqual(11, config.tf_random_seed) 303 self.assertEqual(12, config.save_summary_steps) 304 self.assertEqual(13, config.save_checkpoints_steps) 305 self.assertEqual(14, config.save_checkpoints_secs) 306 self.assertEqual(config_pb2.ConfigProto(allow_soft_placement=True), 307 config.session_config) 308 self.assertEqual(16, config.keep_checkpoint_max) 309 self.assertEqual(17, config.keep_checkpoint_every_n_hours) 310 311 new_config = run_config_lib.RunConfig( 312 tf_random_seed=21, 313 save_summary_steps=22, 314 save_checkpoints_steps=23, 315 save_checkpoints_secs=24, 316 session_config=config_pb2.ConfigProto(allow_soft_placement=False), 317 keep_checkpoint_max=26, 318 keep_checkpoint_every_n_hours=27) 319 self.assertEqual(config.uid(), new_config.uid()) 320 # model_dir is not on the default whitelist. 321 self.assertNotEqual(config.uid(whitelist=[]), 322 new_config.uid(whitelist=[])) 323 new_config = new_config.replace(model_dir=ANOTHER_TEST_DIR) 324 self.assertNotEqual(config.uid(), new_config.uid()) 325 326 def test_uid_for_deepcopy(self): 327 tf_config = { 328 "cluster": { 329 run_config_lib.TaskType.PS: ["host1:1", "host2:2"], 330 run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"] 331 }, 332 "task": { 333 "type": run_config_lib.TaskType.WORKER, 334 "index": 1 335 } 336 } 337 338 config = _create_run_config_with_cluster_spec(json.dumps(tf_config)) 339 expected_uid = config.uid() 340 self.assertEqual(tf_config["cluster"], config.cluster_spec.as_dict()) 341 342 new_config = copy.deepcopy(config) 343 self.assertEqual(tf_config["cluster"], new_config.cluster_spec.as_dict()) 344 self.assertEqual(expected_uid, new_config.uid()) 345 346 def test_uid_for_different_cluster_spec_order(self): 347 tf_config_1_str = ( 348 "{\"cluster\": {\"ps\": [\"host1:1\", \"host2:2\"], " 349 "\"worker\": [\"host3:3\", \"host4:4\", \"host5:5\"]}}") 350 351 tf_config_2_str = ( 352 "{\"cluster\": {\"worker\": [\"host3:3\", \"host4:4\", \"host5:5\"]," 353 "\"ps\": [\"host1:1\", \"host2:2\"]}}") 354 355 # Wraps in a loop to check flakiness. 356 for _ in range(100): 357 uid_1 = _create_run_config_with_cluster_spec(tf_config_1_str).uid() 358 uid_2 = _create_run_config_with_cluster_spec(tf_config_2_str).uid() 359 self.assertEqual(uid_1, uid_2) 360 361 def test_uid_for_different_cluster_specs(self): 362 tf_config_1 = { 363 "cluster": { 364 run_config_lib.TaskType.PS: ["host1:1", "host2:2"], 365 run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"] 366 }, 367 } 368 369 tf_config_2 = { 370 "cluster": { 371 run_config_lib.TaskType.PS: ["host1:1"], 372 run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"] 373 }, 374 } 375 376 uid_1 = _create_run_config_with_cluster_spec(json.dumps(tf_config_1)).uid() 377 uid_2 = _create_run_config_with_cluster_spec(json.dumps(tf_config_2)).uid() 378 self.assertNotEqual(uid_1, uid_2) 379 380 def test_num_worker_replicas_counts_in_master_too(self): 381 tf_config = { 382 "cluster": { 383 run_config_lib.TaskType.PS: ["host1:1", "host2:2"], 384 run_config_lib.TaskType.MASTER: ["host6:6"], 385 run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"], 386 }, 387 "task": { 388 "type": run_config_lib.TaskType.WORKER, 389 "index": 1 390 } 391 } 392 393 config = _create_run_config_with_cluster_spec(json.dumps(tf_config)) 394 self.assertEqual(config.num_worker_replicas, 4) 395 396 397 if __name__ == "__main__": 398 test.main() 399