Home | History | Annotate | Download | only in estimators
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """run_config.py tests."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import copy
     22 import json
     23 
     24 from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
     25 from tensorflow.core.protobuf import config_pb2
     26 from tensorflow.python.estimator import run_config as core_run_config
     27 from tensorflow.python.platform import test
     28 from tensorflow.python.training import server_lib
     29 
# Model directories passed to RunConfig (and placed in TF_CONFIG) by the tests.
TEST_DIR = "test_dir"
ANOTHER_TEST_DIR = "another_test_dir"
# NOTE(review): MASTER is not referenced anywhere in the visible part of this
# file -- confirm whether it is still needed.
MASTER = "master_"
# Seed passed as `tf_random_seed` when the tests build a RunConfig.
RANDOM_SEED = 123

# Short alias for mock.patch; the tests use `patch.dict` to set TF_CONFIG in
# os.environ for the duration of a `with` block.
patch = test.mock.patch
     36 
     37 
     38 def _create_run_config_with_cluster_spec(tf_config_str):
     39   with patch.dict("os.environ", {"TF_CONFIG": tf_config_str}):
     40     return run_config_lib.RunConfig(
     41         tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
     42 
     43 
     44 class RunConfigTest(test.TestCase):
     45 
     46   def test_instance_of_core_run_config(self):
     47     config = run_config_lib.RunConfig()
     48     self.assertTrue(isinstance(config, core_run_config.RunConfig))
     49 
     50   def test_defaults_with_no_tf_config(self):
     51     config = run_config_lib.RunConfig()
     52     self.assertEqual(config.master, "")
     53     self.assertEqual(config.task_id, 0)
     54     self.assertEqual(config.num_ps_replicas, 0)
     55     self.assertEqual(config.cluster_spec, {})
     56     self.assertIsNone(config.task_type)
     57     self.assertTrue(config.is_chief)
     58     self.assertEqual(config.evaluation_master, "")
     59 
     60   def test_values_from_tf_config(self):
     61     tf_config = {
     62         "cluster": {
     63             run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
     64             run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"]
     65         },
     66         "task": {
     67             "type": run_config_lib.TaskType.WORKER,
     68             "index": 1
     69         }
     70     }
     71     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
     72       config = run_config_lib.RunConfig()
     73 
     74     self.assertEqual(config.master, "grpc://host4:4")
     75     self.assertEqual(config.task_id, 1)
     76     self.assertEqual(config.num_ps_replicas, 2)
     77     self.assertEqual(config.num_worker_replicas, 3)
     78     self.assertEqual(config.cluster_spec.as_dict(), tf_config["cluster"])
     79     self.assertEqual(config.task_type, run_config_lib.TaskType.WORKER)
     80     self.assertFalse(config.is_chief)
     81     self.assertEqual(config.evaluation_master, "")
     82 
     83   def test_explicitly_specified_values(self):
     84     cluster_spec = {
     85         run_config_lib.TaskType.PS: ["localhost:9990"],
     86         "my_job_name": ["localhost:9991", "localhost:9992", "localhost:0"]
     87     }
     88     tf_config = {
     89         "cluster": cluster_spec,
     90         "task": {
     91             "type": run_config_lib.TaskType.WORKER,
     92             "index": 2
     93         }
     94     }
     95     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
     96       config = run_config_lib.RunConfig(
     97           master="localhost:0", evaluation_master="localhost:9991")
     98 
     99     self.assertEqual(config.master, "localhost:0")
    100     self.assertEqual(config.task_id, 2)
    101     self.assertEqual(config.num_ps_replicas, 1)
    102     self.assertEqual(config.num_worker_replicas, 0)
    103     self.assertEqual(config.cluster_spec, server_lib.ClusterSpec(cluster_spec))
    104     self.assertEqual(config.task_type, run_config_lib.TaskType.WORKER)
    105     self.assertFalse(config.is_chief)
    106     self.assertEqual(config.evaluation_master, "localhost:9991")
    107 
    108   def test_single_node_in_cluster_spec_produces_empty_master(self):
    109     tf_config = {"cluster": {run_config_lib.TaskType.WORKER: ["host1:1"]}}
    110     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
    111       config = run_config_lib.RunConfig()
    112       self.assertEqual(config.master, "")
    113 
    114   def test_no_task_type_produces_empty_master(self):
    115     tf_config = {
    116         "cluster": {
    117             run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
    118             run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"]
    119         },
    120         # Omits "task": {"type": "worker}
    121     }
    122     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
    123       config = run_config_lib.RunConfig()
    124       self.assertEqual(config.master, "")
    125 
    126   def test_invalid_job_name_raises(self):
    127     tf_config = {
    128         "cluster": {
    129             run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
    130             run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"]
    131         },
    132         "task": {
    133             "type": "not_in_cluster_spec"
    134         }
    135     }
    136     expected_msg_regexp = "not_in_cluster_spec is not a valid task"
    137     with patch.dict(
    138         "os.environ",
    139         {"TF_CONFIG": json.dumps(tf_config)}), self.assertRaisesRegexp(
    140             ValueError, expected_msg_regexp):
    141       run_config_lib.RunConfig()
    142 
    143   def test_illegal_task_index_raises(self):
    144     tf_config = {
    145         "cluster": {
    146             run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
    147             run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"]
    148         },
    149         "task": {
    150             "type": run_config_lib.TaskType.WORKER,
    151             "index": 3
    152         }
    153     }
    154     expected_msg_regexp = "3 is not a valid task_id"
    155     with patch.dict(
    156         "os.environ",
    157         {"TF_CONFIG": json.dumps(tf_config)}), self.assertRaisesRegexp(
    158             ValueError, expected_msg_regexp):
    159       run_config_lib.RunConfig()
    160 
    161   def test_is_chief_from_cloud_tf_config(self):
    162     # is_chief should be true when ["task"]["type"] == "master" and
    163     # index == 0 and ["task"]["environment"] == "cloud". Note that
    164     # test_values_from_tf_config covers the non-master case.
    165     tf_config = {
    166         "cluster": {
    167             run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
    168             run_config_lib.TaskType.MASTER: ["host3:3"],
    169             run_config_lib.TaskType.WORKER: ["host4:4", "host5:5", "host6:6"]
    170         },
    171         "task": {
    172             "type": run_config_lib.TaskType.MASTER,
    173             "index": 0
    174         },
    175         "environment": "cloud"
    176     }
    177     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
    178       config = run_config_lib.RunConfig()
    179 
    180     self.assertTrue(config.is_chief)
    181 
    182   def test_is_chief_from_noncloud_tf_config(self):
    183     # is_chief should be true when ["task"]["type"] == "worker" and
    184     # index == 0 if ["task"]["environment"] != "cloud".
    185     tf_config = {
    186         "cluster": {
    187             run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
    188             run_config_lib.TaskType.MASTER: ["host3:3"],
    189             run_config_lib.TaskType.WORKER: ["host4:4", "host5:5", "host6:6"]
    190         },
    191         "task": {
    192             "type": run_config_lib.TaskType.WORKER,
    193             "index": 0
    194         },
    195         "environment": "random"
    196     }
    197     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
    198       config = run_config_lib.RunConfig()
    199 
    200     self.assertTrue(config.is_chief)
    201 
    202     # But task 0 for a job named "master" should not be.
    203     tf_config = {
    204         "cluster": {
    205             run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
    206             run_config_lib.TaskType.MASTER: ["host3:3"],
    207             run_config_lib.TaskType.WORKER: ["host4:4", "host5:5", "host6:6"]
    208         },
    209         "task": {
    210             "type": run_config_lib.TaskType.MASTER,
    211             "index": 0
    212         },
    213         "environment": "random"
    214     }
    215     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
    216       config = run_config_lib.RunConfig()
    217 
    218     self.assertFalse(config.is_chief)
    219 
    220   def test_default_is_chief_from_tf_config_without_job_name(self):
    221     tf_config = {"cluster": {}, "task": {}}
    222     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
    223       config = run_config_lib.RunConfig()
    224 
    225     self.assertTrue(config.is_chief)
    226 
    227   def test_model_dir(self):
    228     empty_config = run_config_lib.RunConfig()
    229     self.assertIsNone(empty_config.model_dir)
    230 
    231     config = run_config_lib.RunConfig(model_dir=TEST_DIR)
    232     self.assertEqual(TEST_DIR, config.model_dir)
    233 
    234   def test_model_dir_in_tf_config(self):
    235     tf_config = {"model_dir": TEST_DIR}
    236     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
    237       run_config = run_config_lib.RunConfig()
    238     self.assertEqual(TEST_DIR, run_config.model_dir)
    239 
    240   def test_model_dir_both_in_tf_config_and_constructor(self):
    241     tf_config = {"model_dir": TEST_DIR}
    242     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
    243       run_config = run_config_lib.RunConfig(model_dir=TEST_DIR)
    244     self.assertEqual(TEST_DIR, run_config.model_dir)
    245 
    246   def test_model_dir_fail_if_constructor_value_mismatch_tf_config(self):
    247     tf_config = {"model_dir": TEST_DIR}
    248     with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
    249       with self.assertRaisesRegexp(
    250           ValueError,
    251           "`model_dir` provided in RunConfig .* must have "
    252           "the same value .* in TF_CONFIG"):
    253         run_config_lib.RunConfig(model_dir=TEST_DIR + "/sub_dir")
    254 
    255   def test_replace(self):
    256     config = run_config_lib.RunConfig(
    257         tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
    258     self.assertEqual(TEST_DIR, config.model_dir)
    259     self.assertEqual(RANDOM_SEED, config.tf_random_seed)
    260 
    261     new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
    262     self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
    263     self.assertEqual(RANDOM_SEED, new_config.tf_random_seed)
    264     self.assertEqual(RANDOM_SEED, config.tf_random_seed)
    265 
    266   def test_uid_for_different_configs(self):
    267     config = run_config_lib.RunConfig(
    268         tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
    269 
    270     expected_uid = config.uid()
    271     # Check for 10 times, which should prove something.
    272     for _ in range(10):
    273       self.assertEqual(expected_uid, config.uid())
    274 
    275     new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
    276     self.assertEqual(TEST_DIR, config.model_dir)
    277     self.assertNotEqual(expected_uid, new_config.uid())
    278     self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
    279 
    280   def test_uid_for_whitelist(self):
    281     whitelist = ["model_dir"]
    282     config = run_config_lib.RunConfig(
    283         tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
    284 
    285     expected_uid = config.uid(whitelist)
    286     self.assertEqual(expected_uid, config.uid(whitelist))
    287 
    288     new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
    289     self.assertEqual(TEST_DIR, config.model_dir)
    290     self.assertEqual(expected_uid, new_config.uid(whitelist))
    291     self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
    292 
    293   def test_uid_for_default_whitelist(self):
    294     config = run_config_lib.RunConfig(
    295         tf_random_seed=11,
    296         save_summary_steps=12,
    297         save_checkpoints_steps=13,
    298         save_checkpoints_secs=14,
    299         session_config=config_pb2.ConfigProto(allow_soft_placement=True),
    300         keep_checkpoint_max=16,
    301         keep_checkpoint_every_n_hours=17)
    302     self.assertEqual(11, config.tf_random_seed)
    303     self.assertEqual(12, config.save_summary_steps)
    304     self.assertEqual(13, config.save_checkpoints_steps)
    305     self.assertEqual(14, config.save_checkpoints_secs)
    306     self.assertEqual(config_pb2.ConfigProto(allow_soft_placement=True),
    307                      config.session_config)
    308     self.assertEqual(16, config.keep_checkpoint_max)
    309     self.assertEqual(17, config.keep_checkpoint_every_n_hours)
    310 
    311     new_config = run_config_lib.RunConfig(
    312         tf_random_seed=21,
    313         save_summary_steps=22,
    314         save_checkpoints_steps=23,
    315         save_checkpoints_secs=24,
    316         session_config=config_pb2.ConfigProto(allow_soft_placement=False),
    317         keep_checkpoint_max=26,
    318         keep_checkpoint_every_n_hours=27)
    319     self.assertEqual(config.uid(), new_config.uid())
    320     # model_dir is not on the default whitelist.
    321     self.assertNotEqual(config.uid(whitelist=[]),
    322                         new_config.uid(whitelist=[]))
    323     new_config = new_config.replace(model_dir=ANOTHER_TEST_DIR)
    324     self.assertNotEqual(config.uid(), new_config.uid())
    325 
    326   def test_uid_for_deepcopy(self):
    327     tf_config = {
    328         "cluster": {
    329             run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
    330             run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"]
    331         },
    332         "task": {
    333             "type": run_config_lib.TaskType.WORKER,
    334             "index": 1
    335         }
    336     }
    337 
    338     config = _create_run_config_with_cluster_spec(json.dumps(tf_config))
    339     expected_uid = config.uid()
    340     self.assertEqual(tf_config["cluster"], config.cluster_spec.as_dict())
    341 
    342     new_config = copy.deepcopy(config)
    343     self.assertEqual(tf_config["cluster"], new_config.cluster_spec.as_dict())
    344     self.assertEqual(expected_uid, new_config.uid())
    345 
    346   def test_uid_for_different_cluster_spec_order(self):
    347     tf_config_1_str = (
    348         "{\"cluster\": {\"ps\": [\"host1:1\", \"host2:2\"], "
    349         "\"worker\": [\"host3:3\", \"host4:4\", \"host5:5\"]}}")
    350 
    351     tf_config_2_str = (
    352         "{\"cluster\": {\"worker\": [\"host3:3\", \"host4:4\", \"host5:5\"],"
    353         "\"ps\": [\"host1:1\", \"host2:2\"]}}")
    354 
    355     # Wraps in a loop to check flakiness.
    356     for _ in range(100):
    357       uid_1 = _create_run_config_with_cluster_spec(tf_config_1_str).uid()
    358       uid_2 = _create_run_config_with_cluster_spec(tf_config_2_str).uid()
    359       self.assertEqual(uid_1, uid_2)
    360 
    361   def test_uid_for_different_cluster_specs(self):
    362     tf_config_1 = {
    363         "cluster": {
    364             run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
    365             run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"]
    366         },
    367     }
    368 
    369     tf_config_2 = {
    370         "cluster": {
    371             run_config_lib.TaskType.PS: ["host1:1"],
    372             run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"]
    373         },
    374     }
    375 
    376     uid_1 = _create_run_config_with_cluster_spec(json.dumps(tf_config_1)).uid()
    377     uid_2 = _create_run_config_with_cluster_spec(json.dumps(tf_config_2)).uid()
    378     self.assertNotEqual(uid_1, uid_2)
    379 
    380   def test_num_worker_replicas_counts_in_master_too(self):
    381     tf_config = {
    382         "cluster": {
    383             run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
    384             run_config_lib.TaskType.MASTER: ["host6:6"],
    385             run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"],
    386         },
    387         "task": {
    388             "type": run_config_lib.TaskType.WORKER,
    389             "index": 1
    390         }
    391     }
    392 
    393     config = _create_run_config_with_cluster_spec(json.dumps(tf_config))
    394     self.assertEqual(config.num_worker_replicas, 4)
    395 
    396 
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == "__main__":
  test.main()
    399