# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TaskRunner and Experiment class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import tempfile
import time

from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn import run_config
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import dnn
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
from tensorflow.contrib.learn.python.learn.estimators import test_data
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
from tensorflow.python.util import tf_inspect


class SheepCounter(object):
  """To be patched in for the time module, replacing sleep() and time()."""

  def __init__(self):
    self._total_time = 0
    self._sleeptimes = []
    self._time_calls = 0

  def sleep(self, t):
    self._total_time += t
    self._sleeptimes += [t]

  def time(self):
    self._time_calls += 1
    return self._total_time

  @property
  def sleep_times(self):
    return self._sleeptimes

  @property
  def time_calls(self):
    return self._time_calls


class TestBaseEstimator(object):

  def __init__(self, config, max_evals, eval_dict):
    self.eval_count = 0
    self.fit_count = 0
    self._max_evals = max_evals
    self.export_count = 0
    self.monitors = []
    self.eval_hooks = []
    self._config = config or run_config.RunConfig()
    self._model_dir = tempfile.mkdtemp()
    self._eval_dict = eval_dict

  @property
  def model_dir(self):
    return self._model_dir

  @property
  def config(self):
    return self._config

  def evaluate(self, **kwargs):
    tf_logging.info('evaluate called with args: %s' % kwargs)
    if 'hooks' in kwargs:
      self.eval_hooks = kwargs['hooks']
    self.eval_count += 1
    if self.eval_count > self._max_evals:
      tf_logging.info('Ran %d evals. Done.' % self.eval_count)
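      # NOTE: StopIteration is the sentinel these fakes use to end
      # Experiment's continuous-eval loops once the eval budget is spent;
      # several tests below assert that it is raised.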
      raise StopIteration()
    return self._eval_dict

  def fake_checkpoint(self):
    save_path = os.path.join(self.model_dir, 'model.ckpt')
    with session.Session() as sess:
      var = variables.Variable(1.0, name='var0')
      save = saver.Saver({var.op.name: var})
      var.initializer.run()
      save.save(sess, save_path, global_step=0)

  def train(self, **kwargs):
    self.fake_checkpoint()
    tf_logging.info('fit called with args: %s' % kwargs)
    self.fit_count += 1

    return [(key, kwargs[key]) for key in sorted(kwargs.keys())]

  def export_savedmodel(self, export_dir_base, serving_input_fn, **kwargs):
    tf_logging.info('export_savedmodel called with args: %s, %s, %s' %
                    (export_dir_base, serving_input_fn, kwargs))
    self.export_count += 1
    return os.path.join(
        compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp'))


def _check_method_supports_args(method, kwargs):
  """Checks that the given method supports the given args."""
  supported_args = tuple(tf_inspect.getargspec(method).args)
  for kwarg in kwargs:
    if kwarg not in supported_args:
      raise ValueError(
          'Argument `{}` is not supported in method {}.'.format(kwarg, method))


class TestEstimator(
    TestBaseEstimator, evaluable.Evaluable, trainable.Trainable):

  def __init__(self, config=None, max_evals=5, eval_dict=None):
    super(TestEstimator, self).__init__(config, max_evals, eval_dict)
    tf_logging.info('Create Estimator')

  def evaluate(self, **kwargs):
    _check_method_supports_args(evaluable.Evaluable.evaluate, kwargs)
    return super(TestEstimator, self).evaluate(**kwargs)

  def fit(self, **kwargs):
    _check_method_supports_args(trainable.Trainable.fit, kwargs)
    if 'monitors' in kwargs:
      self.monitors = kwargs['monitors']
    return super(TestEstimator, self).train(**kwargs)

  def train(self, **kwargs):
    raise ValueError('`train` is not defined in Estimator.')

  def export_savedmodel(
      self, export_dir_base, serving_input_fn, **kwargs):
    _check_method_supports_args(
        estimator_lib.Estimator.export_savedmodel, kwargs)
    return super(TestEstimator, self).export_savedmodel(
        export_dir_base, serving_input_fn, **kwargs)


class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):

  def __init__(self, config=None, max_evals=5, eval_dict=None):
    super(TestCoreEstimator, self).__init__(config, max_evals, eval_dict)
    tf_logging.info('Create Core Estimator')

  def evaluate(self, **kwargs):
    _check_method_supports_args(core_estimator.Estimator.evaluate, kwargs)
    return super(TestCoreEstimator, self).evaluate(**kwargs)

  def train(self, **kwargs):
    _check_method_supports_args(core_estimator.Estimator.train, kwargs)
    if 'hooks' in kwargs:
      self.monitors = kwargs['hooks']
    return super(TestCoreEstimator, self).train(**kwargs)

  def export_savedmodel(
      self, export_dir_base, serving_input_receiver_fn, **kwargs):
    _check_method_supports_args(
        core_estimator.Estimator.export_savedmodel, kwargs)
    return super(TestCoreEstimator, self).export_savedmodel(
        export_dir_base, serving_input_receiver_fn, **kwargs)


class _NoopHook(session_run_hook.SessionRunHook):
  pass


class ExperimentTest(test.TestCase):

  def _cluster_spec(self):
    return {
        run_config_lib.TaskType.PS: ['host1:2222', 'host2:2222'],
        run_config_lib.TaskType.WORKER:
            ['host3:2222', 'host4:2222', 'host5:2222']
    }

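  # Most tests below run against both estimator flavors returned here: the
  # contrib-style TestEstimator and the core TestCoreEstimator.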
  def _estimators_for_tests(self, config=None, eval_dict=None):
    return [TestEstimator(config=config, eval_dict=eval_dict),
            TestCoreEstimator(config=config, eval_dict=eval_dict)]

  def test_eval_metrics_for_core_estimator(self):
    est = TestCoreEstimator()
    with self.assertRaisesRegexp(
        ValueError, '`eval_metrics` must be `None`'):
      experiment.Experiment(
          est,
          train_input_fn='train_input',
          train_steps='train_steps',
          eval_input_fn='eval_input',
          eval_metrics='eval_metrics')

  def test_default_output_alternative_key_core_estimator(self):
    est = TestCoreEstimator()
    export_strategy = saved_model_export_utils.make_export_strategy(
        est,
        default_output_alternative_key='export_key',
        exports_to_keep=None)
    ex = experiment.Experiment(
        est,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps=100,
        eval_steps=100,
        export_strategies=export_strategy)
    with self.assertRaisesRegexp(
        ValueError, 'default_output_alternative_key is not supported'):
      ex.train_and_evaluate()

  def test_train(self):
    for est in self._estimators_for_tests():
      if isinstance(est, core_estimator.Estimator):
        eval_metrics = None
        saving_listeners = 'saving_listeners'
      else:
        eval_metrics = 'eval_metrics'
        saving_listeners = None
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          train_steps='train_steps',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          saving_listeners=saving_listeners)
      fit_args = ex.train(delay_secs=0)
      self.assertEqual(1, est.fit_count)
      self.assertIn(('max_steps', 'train_steps'), fit_args)
      self.assertEqual(0, est.eval_count)

  def test_train_delay(self):
    for est in self._estimators_for_tests():
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input')
      for delay in [0, 1, 3]:
        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            ex.train(delay_secs=delay)
        self.assertAlmostEqual(delay, sheep.time(), delta=1e-4)

  def test_train_default_delay(self):
    for task_id in [0, 1, 3]:
      tf_config = {'task': {'index': task_id}}
      with test.mock.patch.dict('os.environ',
                                {'TF_CONFIG': json.dumps(tf_config)}):
        config = run_config.RunConfig()
      for est in self._estimators_for_tests(config):
        ex = experiment.Experiment(
            est, train_input_fn='train_input', eval_input_fn='eval_input')

        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            ex.train()
        self.assertAlmostEqual(task_id * 5, sheep.time(), delta=1e-4)

  @test.mock.patch.object(server_lib, 'Server')
  def test_train_starts_server(self, mock_server):
    # Arrange.
    tf_config = {
        'cluster': self._cluster_spec(),
        'environment': run_config_lib.Environment.CLOUD,
        'task': {
            'type': run_config_lib.TaskType.WORKER,
            'index': 1
        }
    }
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': json.dumps(tf_config)}):
      config = run_config_lib.RunConfig(
          master='host4:2222', num_cores=15, gpu_memory_fraction=0.314)

    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input')

      # Act.
      # We want the delay accounting to discount the time it takes to start
      # the server, so we set a small delay here.
      sheep = SheepCounter()
      with test.mock.patch.object(time, 'time', sheep.time):
        with test.mock.patch.object(time, 'sleep', sheep.sleep):
          ex.train(delay_secs=1)
          # Ensure that the delay takes into account the time to start
          # the server.
          self.assertAlmostEqual(1, sheep.time(), delta=1e-4)

      # Assert.
      expected_config_proto = config_pb2.ConfigProto()
      expected_config_proto.inter_op_parallelism_threads = 15
      expected_config_proto.intra_op_parallelism_threads = 15
      expected_config_proto.gpu_options.per_process_gpu_memory_fraction = 0.314
      mock_server.assert_called_with(
          config.cluster_spec,
          job_name=run_config_lib.TaskType.WORKER,
          task_index=1,
          config=expected_config_proto,
          start=False)
      mock_server.assert_has_calls([test.mock.call().start()])

  @test.mock.patch.object(server_lib, 'Server')
  def test_train_server_does_not_start_without_cluster_spec(self, mock_server):
    config = run_config_lib.RunConfig(master='host4:2222')
    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input')
      ex.train()

      # The server should not have started because there was no ClusterSpec.
      self.assertFalse(mock_server.called)

  @test.mock.patch.object(server_lib, 'Server')
  def test_train_server_does_not_start_with_empty_master(self, mock_server):
    tf_config = {'cluster': self._cluster_spec()}
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': json.dumps(tf_config)}):
      config = run_config_lib.RunConfig(master='')
    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input')
      ex.train()
      # The server should not have started because master was the empty
      # string.
      self.assertFalse(mock_server.called)

  def test_train_raises_if_job_name_is_missing(self):
    tf_config = {
        'cluster': self._cluster_spec(),
        'environment': run_config_lib.Environment.CLOUD,
        'task': {
            'index': 1
        }
    }
    with test.mock.patch.dict(
        'os.environ',
        {'TF_CONFIG': json.dumps(tf_config)}), self.assertRaises(ValueError):
      config = run_config_lib.RunConfig(
          master='host3:2222'  # Normally selected by task type.
      )
      for est in self._estimators_for_tests(config):
        ex = experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input')
        ex.train()

  def test_evaluate(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_steps='steps',
          eval_delay_secs=0)
      ex.evaluate()
      self.assertEqual(0, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_evaluate_delay(self):
    for est in self._estimators_for_tests():
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input',
          eval_hooks=[noop_hook])

      for delay in [0, 1, 3]:
        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            ex.evaluate(delay_secs=delay)
        self.assertAlmostEqual(delay, sheep.time(), delta=1e-4)
        self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      self.assertRaises(StopIteration, ex.continuous_eval,
                        evaluate_checkpoint_only_once=False)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(6, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval_ends_after_train_step(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0,
          train_steps=100)
      ex.continuous_eval()
      self.assertEqual(0, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval_throttle_delay(self):
    for delay in [0, 1, 2]:
      for est in self._estimators_for_tests():
        eval_metrics = 'eval_metrics' if not isinstance(
            est, core_estimator.Estimator) else None
        est.fake_checkpoint()
        noop_hook = _NoopHook()
        ex = experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            eval_metrics=eval_metrics,
            eval_hooks=[noop_hook],
            continuous_eval_throttle_secs=delay,
            eval_delay_secs=0)
        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            self.assertRaises(
                StopIteration,
                ex.continuous_eval,
                evaluate_checkpoint_only_once=False)
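        # The assertion below expects the throttle to have slept once per
        # completed eval round (max_evals defaults to 5 in these fakes), so
        # the fake clock should have accumulated 5 * delay seconds.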
        self.assertAlmostEqual(5 * delay, sheep.time(), delta=1e-4)

  def test_continuous_eval_predicate_fn(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()

      def _predicate_fn(unused_eval_result):
        return est.eval_count < 3  # pylint: disable=cell-var-from-loop

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      ex.continuous_eval(evaluate_checkpoint_only_once=False,
                         continuous_eval_predicate_fn=_predicate_fn)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(3, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval_predicate_fn_with_checkpoint(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()

      def _predicate_fn(eval_result, checkpoint_path):
        self.assertEqual(eval_result is None,
                         checkpoint_path is None)
        return est.eval_count < 3  # pylint: disable=cell-var-from-loop

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      ex.continuous_eval(
          evaluate_checkpoint_only_once=False,
          continuous_eval_predicate_fn=_predicate_fn)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(3, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_run_local(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          local_eval_frequency=10)
      ex.local_run()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, len(est.monitors))
      self.assertEqual([noop_hook], est.eval_hooks)
      self.assertTrue(isinstance(est.monitors[0],
                                 session_run_hook.SessionRunHook))

  def test_train_hooks_extend_does_not_mutate_input_hooks(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      input_hooks = [noop_hook]

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          train_monitors=input_hooks)
      self.assertAllEqual([noop_hook], ex._train_monitors)

      another_noop_hook = _NoopHook()
      # Assert that the extend API mutates the hooks, but not the input
      # hooks.
      ex.extend_train_hooks([another_noop_hook])
      self.assertAllEqual([noop_hook, another_noop_hook], ex._train_monitors)
      self.assertAllEqual([noop_hook], input_hooks)

  def test_invalid_export_strategies(self):
    for est in self._estimators_for_tests():
      with self.assertRaisesRegexp(ValueError, 'ExportStrategy'):
        experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps=100,
            eval_steps=100,
            export_strategies='not_an_export_strategy')
      with self.assertRaisesRegexp(ValueError, 'ExportStrategy'):
        experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps=100,
            eval_steps=100,
            export_strategies=['not_an_export_strategy'])

  def test_export_strategies_reset(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      export_strategy_1 = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_1',
          exports_to_keep=None)

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          train_steps=100,
          eval_steps=100,
          export_strategies=(export_strategy_1,))
      ex.train_and_evaluate()
      self.assertEqual(1, est.export_count)

      # After a reset with an empty argument (None), the export count does
      # not change, and the previously configured export strategies are
      # returned intact.
      old_es = ex.reset_export_strategies()
      ex.train_and_evaluate()
      self.assertAllEqual([export_strategy_1], old_es)
      self.assertEqual(1, est.export_count)

      # After a reset with a non-empty list, the count should increase by
      # the number of items in the list.
      export_strategy_2 = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_2',
          exports_to_keep=None)
      export_strategy_3 = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_3',
          exports_to_keep=None)

      old_es = ex.reset_export_strategies(
          [export_strategy_2, export_strategy_3])
      ex.train_and_evaluate()
      self.assertAllEqual([], old_es)
      self.assertEqual(3, est.export_count)

  def test_train_and_evaluate(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      export_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          export_strategies=export_strategy)
      ex.train_and_evaluate()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, est.export_count)
      self.assertEqual(1, len(est.monitors))
      self.assertEqual([noop_hook], est.eval_hooks)
      self.assertTrue(isinstance(est.monitors[0],
                                 session_run_hook.SessionRunHook))

  def test_train_and_evaluate_with_no_eval_during_training(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          min_eval_frequency=0)
      ex.train_and_evaluate()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(0, len(est.monitors))

  def test_min_eval_frequency_defaults(self):
    def dummy_model_fn(features, labels):  # pylint: disable=unused-argument
      pass

    estimator = core_estimator.Estimator(dummy_model_fn, '/tmp/dummy')
    ex = experiment.Experiment(
        estimator, train_input_fn=None, eval_input_fn=None)
    self.assertEqual(ex._min_eval_frequency, 1)

  def test_continuous_train_and_eval(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      if isinstance(est, core_estimator.Estimator):
        eval_metrics = None
        saving_listeners = 'saving_listeners'
      else:
        eval_metrics = 'eval_metrics'
        saving_listeners = None
      noop_hook = _NoopHook()
      export_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          export_strategies=export_strategy,
          saving_listeners=saving_listeners)
      ex.continuous_train_and_eval()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, est.export_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_train_and_eval_with_predicate_fn(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      export_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          # A value large enough that `ex` never stops on its own.
          train_steps=100000000000,
          eval_steps=100,
          export_strategies=export_strategy)

      def predicate_fn(eval_result):
        del eval_result  # Unused; present only for the fn signature.
        return False

      ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(0, est.eval_count)
      self.assertEqual(0, est.export_count)

  def test_continuous_train_and_eval_with_adapted_steps_per_iteration(self):
    mock_estimator = test.mock.Mock(core_estimator.Estimator)
    type(mock_estimator).model_dir = test.mock.PropertyMock(
        return_value='test_dir')

    total_steps = 100000000000000
    ex = experiment.Experiment(
        mock_estimator,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps=total_steps)

    def predicate_fn(eval_result):
      # Allow only the first invocation.
      return eval_result is None

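    # The Experiment is expected to adapt the per-iteration step count to
    # one-tenth of train_steps when train_steps_per_iteration is not given;
    # the assertion on `steps` below verifies exactly that.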
    ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
    mock_estimator.train.assert_called_once_with(
        input_fn='train_input',
        steps=int(total_steps / 10),
        max_steps=test.mock.ANY,
        hooks=test.mock.ANY,
        saving_listeners=test.mock.ANY)

  def test_continuous_train_and_eval_with_steps_per_iteration_from_user(self):
    mock_estimator = test.mock.Mock(core_estimator.Estimator)
    type(mock_estimator).model_dir = test.mock.PropertyMock(
        return_value='test_dir')

    total_steps = 100000000000000
    ex = experiment.Experiment(
        mock_estimator,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps_per_iteration=1234,
        train_steps=total_steps)

    def predicate_fn(eval_result):
      # Allow only the first invocation.
      return eval_result is None

    ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
    mock_estimator.train.assert_called_once_with(
        input_fn='train_input',
        steps=1234,
        max_steps=test.mock.ANY,
        hooks=test.mock.ANY,
        saving_listeners=test.mock.ANY)

  def test_continuous_train_and_eval_with_default_steps_per_iteration(self):
    mock_estimator = test.mock.Mock(core_estimator.Estimator)
    type(mock_estimator).model_dir = test.mock.PropertyMock(
        return_value='test_dir')

    ex = experiment.Experiment(
        mock_estimator,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps_per_iteration=None,
        train_steps=None)

    def predicate_fn(eval_result):
      # Allow only the first invocation.
      return eval_result is None

    ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
    mock_estimator.train.assert_called_once_with(
        input_fn='train_input',
        steps=1000,
        max_steps=test.mock.ANY,
        hooks=test.mock.ANY,
        saving_listeners=test.mock.ANY)

  def test_continuous_train_and_eval_with_invalid_predicate_fn(self):
    for est in self._estimators_for_tests():
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input')
      with self.assertRaisesRegexp(
          ValueError, '`continuous_eval_predicate_fn` must be a callable'):
        ex.continuous_train_and_eval(continuous_eval_predicate_fn='fn')

  def test_continuous_train_and_eval_with_invalid_train_steps_iterations(self):
    for est in self._estimators_for_tests():
      with self.assertRaisesRegexp(
          ValueError, '`train_steps_per_iteration` must be an integer.'):
        experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps_per_iteration='123')

  @test.mock.patch.object(server_lib, 'Server')
  def test_run_std_server(self, mock_server):
    # Arrange.
    tf_config = {
        'cluster': self._cluster_spec(),
        'task': {
            'type': run_config_lib.TaskType.PS,
            'index': 1
        }
    }
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': json.dumps(tf_config)}):
      config = run_config_lib.RunConfig(
          master='host2:2222',
          num_cores=15,
          gpu_memory_fraction=0.314)
    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input')

      # Act.
      ex.run_std_server()

      # Assert.
      mock_server.assert_has_calls(
          [test.mock.call().start(), test.mock.call().join()])

  @test.mock.patch.object(server_lib, 'Server')
  def test_run_std_server_raises_without_cluster_spec(self, mock_server):
    config = run_config_lib.RunConfig(master='host4:2222')
    for est in self._estimators_for_tests(config):
      with self.assertRaises(ValueError):
        ex = experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input')
        ex.run_std_server()

  def test_test(self):
    for est in self._estimators_for_tests():
      exp_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      if isinstance(est, core_estimator.Estimator):
        eval_metrics = None
        saving_listeners = 'saving_listeners'
      else:
        eval_metrics = 'eval_metrics'
        saving_listeners = None
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          export_strategies=(exp_strategy,),
          eval_metrics=eval_metrics,
          saving_listeners=saving_listeners)
      ex.test()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, est.export_count)

  def test_continuous_eval_evaluates_checkpoint_once(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()

      result = {
          'called': 0,
          'called_with_eval_result': 0,
      }

      # pylint: disable=cell-var-from-loop
      def _predicate_fn(eval_result):
        result['called'] += 1
        if eval_result:
          # A non-empty eval_result means the checkpoint was actually
          # evaluated on this round.
          result['called_with_eval_result'] += 1
        # Running the predicate 300 times gives the loop ample opportunity
        # to (incorrectly) re-evaluate the same checkpoint.
        return result['called'] < 300
      # pylint: enable=cell-var-from-loop

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      ex.continuous_eval(evaluate_checkpoint_only_once=True,
                         continuous_eval_predicate_fn=_predicate_fn)

      self.assertEqual(0, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(300, result['called'])
      self.assertEqual(1, result['called_with_eval_result'])

  def test_checkpoint_and_export(self):
    model_dir = tempfile.mkdtemp()
    config = run_config_lib.RunConfig(save_checkpoints_steps=3)
    est = dnn.DNNClassifier(
        n_classes=3,
        feature_columns=[
            feature_column.real_valued_column('feature', dimension=4)
        ],
        hidden_units=[3, 3],
        model_dir=model_dir,
        config=config)

    exp_strategy = saved_model_export_utils.make_export_strategy(
        est, 'export_input', exports_to_keep=None)

    ex = experiment.Experiment(
        est,
        train_input_fn=test_data.iris_input_multiclass_fn,
        eval_input_fn=test_data.iris_input_multiclass_fn,
        export_strategies=(exp_strategy,),
        train_steps=8,
        checkpoint_and_export=True,
        eval_delay_secs=0)

    with test.mock.patch.object(ex, '_maybe_export'):
      with test.mock.patch.object(ex, '_call_evaluate'):
        ex.train_and_evaluate()
        # Eval and export are called after steps 1, 4, 7, and 8 (the last
        # one after training has completed).
        self.assertEqual(ex._maybe_export.call_count, 4)
        self.assertEqual(ex._call_evaluate.call_count, 4)


if __name__ == '__main__':
  test.main()