# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for training.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import json
import os
import random
import shutil
import tempfile
import time

import numpy as np

from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.estimator import training
from tensorflow.python.estimator.canned import dnn
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export as export_lib
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat

# Default values mirrored from training.py so the tests fail if they drift.
_DEFAULT_EVAL_STEPS = 100
_DEFAULT_EVAL_DELAY_SECS = 120
_DEFAULT_EVAL_THROTTLE_SECS = 600
_DELAY_SECS_PER_WORKER = 5
_GLOBAL_STEP_KEY = ops.GraphKeys.GLOBAL_STEP
# Expected substrings/regexes of the error messages raised by training.py.
_INVALID_INPUT_FN_MSG = '`input_fn` must be callable'
_INVALID_HOOK_MSG = 'All hooks must be `SessionRunHook` instances'
_INVALID_MAX_STEPS_MSG = 'Must specify max_steps > 0'
_INVALID_STEPS_MSG = 'Must specify steps > 0'
_INVALID_NAME_MSG = '`name` must be string'
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify start_delay_secs >= 0'
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
_STALE_CHECKPOINT_MSG = 'There was no new checkpoint after the training.'
_INVALID_EXPORTER_MSG = '`exporters` must be an Exporter'
_INVALID_EXPORTER_NAME_TYPE_MSG = 'An Exporter must have a string name'
_DUPLICATE_EXPORTER_NAMES_MSG = '`exporters` must have unique names.'
_NONE_EXPORTER_NAME_MSG = (
    'An Exporter cannot have a name that is `None` or empty.')
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`'
_INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'
_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'
_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'
_INVALID_TASK_TYPE = '`estimator.config` must have task_type set.'
# The message should NOT have 'local' word as part of it. As (?!word) is looking
# ahead, so, the $ (ending) check is required; otherwise, it will match
# partially and return successful.
_INVALID_TASK_TO_RUN = (
    'Task type .* is not supported. Supported task types are ((?!local).)*$')
_INVALID_EMPTY_EVAL_RESULT_ERR = (
    'Internal error: `Estimator.evaluate` should never return empty metrics')
_INVALID_EVAL_RESULT_TYPE_ERR = '`Estimator.evaluate` should return dict.'
_MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR = (
    'Internal error: `Estimator.evaluate` result should have `global_step`')
_INVALID_EVAL_TASK_ID_ERR = (
    'there can only be one `evaluator` task .*with task id 0')

# Canned TF_CONFIG environment payloads, one per distributed task type.
_TF_CONFIG_FOR_CHIEF = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.CHIEF,
        'index': 0
    }
}

_TF_CONFIG_FOR_MASTER = {
    'cluster': {
        run_config_lib.TaskType.MASTER: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.MASTER,
        'index': 0
    }
}

_TF_CONFIG_FOR_WORKER = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.WORKER,
        'index': 1
    }
}

_TF_CONFIG_FOR_PS = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.PS,
        'index': 1
    }
}

_TF_CONFIG_FOR_EVALUATOR = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.EVALUATOR,
        'index': 0
    }
}

# The 'google' environment disables the in-process std server startup.
_TF_CONFIG_FOR_GOOGLE = {'environment': 'google'}


class _FakeHook(session_run_hook.SessionRunHook):
  """Fake implementation of `SessionRunHook`."""


class _InvalidHook(object):
  """Invalid hook (not a subclass of `SessionRunHook`)."""


def _create_exporter(name):
  """Returns a no-op `Exporter` with the given `name` for spec validation."""

  class FakeExporter(exporter_lib.Exporter):

    def __init__(self, name):
      self._name = name

    @property
    def name(self):
      return self._name

    def export(self, *args, **kwargs):
      del args, kwargs

  return FakeExporter(name=name)


def _create_run_config_with_cluster_spec(tf_config):
  """Builds a `RunConfig` by temporarily patching TF_CONFIG in os.environ."""
  with test.mock.patch.dict('os.environ', {'TF_CONFIG': json.dumps(tf_config)}):
    return run_config_lib.RunConfig()


class TrainSpecTest(test.TestCase):
  """Tests TrainSpec."""

  def testRequiredArgumentsSet(self):
    """Tests that no errors are raised when all required arguments are set."""
    spec = training.TrainSpec(input_fn=lambda: 1)
    self.assertEqual(1, spec.input_fn())
    self.assertIsNone(spec.max_steps)
    self.assertEqual(0, len(spec.hooks))

  def testAllArgumentsSet(self):
    """Tests that no errors are raised when all arguments are set."""
    hooks = [_FakeHook()]
    spec = training.TrainSpec(input_fn=lambda: 1, max_steps=2, hooks=hooks)
    self.assertEqual(1, spec.input_fn())
    self.assertEqual(2, spec.max_steps)
    self.assertEqual(tuple(hooks), spec.hooks)

  def testInvalidInputFn(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
      training.TrainSpec(input_fn='invalid')

  def testInvalidMaxStep(self):
    with self.assertRaisesRegexp(ValueError, _INVALID_MAX_STEPS_MSG):
      training.TrainSpec(input_fn=lambda: 1, max_steps=0)

  def testInvalidHook(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
      training.TrainSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])


class EvalSpecTest(test.TestCase):
  """Tests EvalSpec."""

  def testRequiredArgumentsSet(self):
    """Tests that no errors are raised when all required arguments are set."""
    spec = training.EvalSpec(input_fn=lambda: 1)
    self.assertEqual(1, spec.input_fn())
    self.assertEqual(_DEFAULT_EVAL_STEPS, spec.steps)
    self.assertIsNone(spec.name)
    self.assertEqual(0, len(spec.hooks))
    self.assertEqual(0, len(spec.exporters))
    self.assertEqual(_DEFAULT_EVAL_DELAY_SECS, spec.start_delay_secs)
    self.assertEqual(_DEFAULT_EVAL_THROTTLE_SECS, spec.throttle_secs)

  def testAllArgumentsSet(self):
    """Tests that no errors are raised when all arguments are set."""
    hooks = [_FakeHook()]
    exporter = _create_exporter('a')

    spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        name='name',
        hooks=hooks,
        exporters=exporter,
        start_delay_secs=3,
        throttle_secs=4)
    self.assertEqual(1, spec.input_fn())
    self.assertEqual(2, spec.steps)
    self.assertEqual('name', spec.name)
    self.assertEqual(tuple(hooks), spec.hooks)
    # A single exporter is normalized to a one-element tuple.
    self.assertEqual((exporter,), spec.exporters)
    self.assertEqual(3, spec.start_delay_secs)
    self.assertEqual(4, spec.throttle_secs)

  def testListOfExporters(self):
    """Tests that no errors are raised with multiple exporters."""
    exporters = [_create_exporter('a'), _create_exporter('b')]

    spec = training.EvalSpec(input_fn=lambda: 1, exporters=exporters)
    self.assertEqual(1, spec.input_fn())
    self.assertEqual(tuple(exporters), spec.exporters)

  def testInvalidInputFn(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
      training.EvalSpec(input_fn='invalid')

  def testInvalidMaxStep(self):
    with self.assertRaisesRegexp(ValueError, _INVALID_STEPS_MSG):
      training.EvalSpec(input_fn=lambda: 1, steps=0)

  def testInvalidName(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_NAME_MSG):
      training.EvalSpec(input_fn=lambda: 1, name=123)

  def testInvalidHook(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
      training.EvalSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])

  def testInvalidDelaySecs(self):
    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_DELAY_SECS_MSG):
      training.EvalSpec(input_fn=lambda: 1, start_delay_secs=-1)

  def testInvalidThrottleSecs(self):
    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_THROTTLE_SECS_MSG):
      training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)

  def testInvalidTypeOfListOfExporters(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):
      training.EvalSpec(
          input_fn=lambda: 1, exporters=[_create_exporter('a'),
                                         _FakeHook()])

  def testInvalidTypeOfIndividualExporter(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):
      training.EvalSpec(input_fn=lambda: 1, exporters=_FakeHook())

  def testInvalidTypeOfExporterName(self):
    with self.assertRaisesRegexp(ValueError, _INVALID_EXPORTER_NAME_TYPE_MSG):
      training.EvalSpec(input_fn=lambda: 1,
                        exporters=_create_exporter(name=123))

  def testMultipleExportersWithTheSameName(self):
    with self.assertRaisesRegexp(ValueError, _DUPLICATE_EXPORTER_NAMES_MSG):
      training.EvalSpec(
          input_fn=lambda: 1,
          exporters=[_create_exporter('a'), _create_exporter('a')])

  def testMultipleExportersAndOneWithoutAName(self):
    with self.assertRaisesRegexp(ValueError, _NONE_EXPORTER_NAME_MSG):
      training.EvalSpec(
          input_fn=lambda: 1,
          exporters=[_create_exporter('a'),
                     _create_exporter(None)])

  def testSingleExporterWithoutAName(self):
    with self.assertRaisesRegexp(ValueError, _NONE_EXPORTER_NAME_MSG):
      training.EvalSpec(input_fn=lambda: 1, exporters=_create_exporter(None))


class TrainAndEvaluateTest(test.TestCase):

  def test_run_task(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
      mock_executor_instance = test.mock.Mock()
      mock_executor.return_value = mock_executor_instance
      training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
      mock_executor.assert_called_with(estimator=mock_est,
                                       train_spec=mock_train_spec,
                                       eval_spec=mock_eval_spec)
      self.assertTrue(mock_executor_instance.run.called)

  def test_error_out_if_evaluator_task_id_is_non_zero(self):
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
        },
        'task': {
            'type': run_config_lib.TaskType.EVALUATOR,
            'index': 1
        }
    }

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):
      training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)

  def test_invalid_estimator(self):
    invalid_estimator = object()
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
      training.train_and_evaluate(invalid_estimator, mock_train_spec,
                                  mock_eval_spec)


class TrainingExecutorConstructorTest(test.TestCase):
  """Tests constructor of _TrainingExecutor."""

  def testRequiredArgumentsSet(self):
    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)

    executor = training._TrainingExecutor(estimator, train_spec, eval_spec)
    self.assertEqual(estimator, executor.estimator)

  def test_invalid_estimator(self):
    invalid_estimator = object()
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)

    with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
      training._TrainingExecutor(invalid_estimator, train_spec, eval_spec)

  def test_invalid_train_spec(self):
    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    invalid_train_spec = object()
    eval_spec = training.EvalSpec(input_fn=lambda: 1)

    with self.assertRaisesRegexp(TypeError, _INVALID_TRAIN_SPEC_MSG):
      training._TrainingExecutor(estimator, invalid_train_spec, eval_spec)

  def test_invalid_eval_spec(self):
    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    invalid_eval_spec = object()

    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
      training._TrainingExecutor(estimator, train_spec, invalid_eval_spec)

  def test_invalid_train_hooks(self):
    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)
    invalid_train_hooks = [object()]

    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
      training._TrainingExecutor(
          estimator, train_spec, eval_spec, train_hooks=invalid_train_hooks)

  def test_invalid_continuous_eval_listener(self):
    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)
    invalid_continuous_eval_listener = object()

    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_LISTENER_MSG):
      training._TrainingExecutor(
          estimator,
          train_spec,
          eval_spec,
          continuous_eval_listener=invalid_continuous_eval_listener)


class _TrainingExecutorTrainingTest(object):
  """Tests training of _TrainingExecutor.

  Mixin base shared by the worker/chief test cases below; each subclass
  supplies the `RunConfig` that selects which `run_<task_type>` is exercised.
  """

  def __init__(self, run_config):
    self._run_config = run_config

  def _run_task(self, executor):
    # We should not call executor.run as the test here is intended to test
    # run_foo explicitly (foo is the task type).
    return getattr(executor, 'run_' + self._run_config.task_type)()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
    self._run_task(executor)

    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=test.mock.ANY,
        start=False)

    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=test.mock.ANY)
    mock_est.evaluate.assert_not_called()
    mock_est.export_savedmodel.assert_not_called()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
    self._run_task(executor)

    # Executor-level train_hooks are appended after the spec's own hooks.
    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks) + extra_hooks,
        saving_listeners=test.mock.ANY)

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
    with test.mock.patch.dict('os.environ', tf_config):
      self._run_task(executor)
      mock_server.assert_not_called()

  def test_fail_with_empty_cluster_spec(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = None
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  def test_fail_with_empty_master(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec(
        {'worker': ['dummy', 'dummy1']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_single_worker_node_with_empty_tf_master(
      self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    # Single node cluster.
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                              mock_eval_spec))
    self.assertTrue(mock_est.train.called)
    mock_server.assert_not_called()

  def test_fail_with_empty_task_type(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  def test_fail_with_none_task_id(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))


class TrainingExecutorRunWorkerTest(_TrainingExecutorTrainingTest,
                                    test.TestCase):
  """Tests run_worker of _TrainingExecutor."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    _TrainingExecutorTrainingTest.__init__(
        self,
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER))

  @test.mock.patch.object(server_lib, 'Server')
  def test_delay_for_worker(self, _):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    # Workers stagger their start by (task_id + 1) * _DELAY_SECS_PER_WORKER.
    expected_secs = (self._run_config.task_id + 1) * _DELAY_SECS_PER_WORKER
    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      mock_sleep.side_effect = lambda s: self.assertEqual(expected_secs, s)
      self._run_task(executor)
      self.assertTrue(mock_sleep.called)


class TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest,
                                   test.TestCase):
  """Tests run_chief of _TrainingExecutor."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    _TrainingExecutorTrainingTest.__init__(
        self,
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF))

  @test.mock.patch.object(server_lib, 'Server')
  def test_no_delay_for_chief(self, _):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      self._run_task(executor)
      mock_sleep.assert_not_called()


class TrainingExecutorRunMasterTest(test.TestCase):
  """Tests run_master of _TrainingExecutor."""

  def setUp(self):
    self._run_config = _create_run_config_with_cluster_spec(
        _TF_CONFIG_FOR_MASTER)

  @test.mock.patch.object(server_lib, 'Server')
  def test_no_delay_for_master(self, _):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      executor.run_master()
      mock_sleep.assert_not_called()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
    executor.run_master()

    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=test.mock.ANY,
        start=False)

    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=test.mock.ANY)
    mock_est.export_savedmodel.assert_not_called()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_hooks(self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
    executor.run_master()

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks) + extra_hooks,
        saving_listeners=test.mock.ANY)

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
    with test.mock.patch.dict('os.environ', tf_config):
      executor.run_master()
      mock_server.assert_not_called()

  def test_fail_with_empty_cluster_spec(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = None
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  def test_fail_with_empty_master(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec(
        {'master': ['dummy'], 'worker': ['dummy1']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 0

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_single_master_node_with_empty_tf_master(
      self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}

    mock_train_spec = test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec(
        {'master': ['dummy']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 0

    executor = training._TrainingExecutor(
        mock_est, mock_train_spec, mock_eval_spec)
    executor.run_master()

    mock_server.assert_not_called()
    self.assertTrue(mock_est.train.called)

  def test_fail_with_empty_task_type(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'master': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  def test_fail_with_none_task_id(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'master': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_triggers_evaluate_and_export(self, _):

    def estimator_train(saving_listeners, *args, **kwargs):
      # There shalt be a saving_listener.  Estimator is going to call
      # `after_save`.
      del args, kwargs
      saving_listeners[0].begin()
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est = test.mock.Mock(
        spec=estimator_lib.Estimator, model_dir='path/', train=estimator_train)
    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter)
    eval_result = {_GLOBAL_STEP_KEY: train_spec.max_steps}
    mock_est.evaluate.return_value = eval_result

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    mock_est.evaluate.assert_called_with(
        name=eval_spec.name,
        input_fn=eval_spec.input_fn,
        steps=eval_spec.steps,
        checkpoint_path='checkpoint_path/',
        hooks=eval_spec.hooks)
    self.assertEqual(1, exporter.export.call_count)
    exporter.export.assert_called_with(
        estimator=mock_est,
        export_path=os.path.join('path/', 'export', exporter.name),
        checkpoint_path='checkpoint_path/',
        eval_result=eval_result,
        is_the_final_export=True)

  @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_throttle_eval(self, _, mock_timer_class):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')

    mock_timer = test.mock.Mock()
    mock_timer_class.return_value = mock_timer

    def estimator_train(saving_listeners, *args, **kwargs):
      del args, kwargs
      saving_listeners[0].begin()

      # Call three times.  The throttle timer permits evals #1 and #3 only, so
      # only two evaluations (and two exports) should be observed.
      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est.train = estimator_train
    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)

    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps // 2},
        {_GLOBAL_STEP_KEY: train_spec.max_steps}
    ]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, exporter.export.call_count)

    is_final_export_list = [call[1]['is_the_final_export']
                            for call in exporter.export.call_args_list]
    self.assertEqual([False, True], is_final_export_list)

  @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_throttle_eval_which_skips_final_ckpt(
      self, _, mock_timer_class):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')

    mock_timer = test.mock.Mock()
    mock_timer_class.return_value = mock_timer

    def estimator_train(saving_listeners, *args, **kwargs):
      del args, kwargs
      saving_listeners[0].begin()

      # Call two times.
      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

      # The final ckpt is skipped by the timer. It will be picked up by the
      # final export check in the code.
      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est.train = estimator_train
    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)

    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps // 2},
        {_GLOBAL_STEP_KEY: train_spec.max_steps}
    ]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, exporter.export.call_count)

    is_final_export_list = [call[1]['is_the_final_export']
                            for call in exporter.export.call_args_list]
    self.assertEqual([False, True], is_final_export_list)


class TrainingExecutorRunEvaluatorTest(test.TestCase):
  """Tests run_evaluator of _TrainingExecutor."""

  def _set_up_mock_est_to_train_and_evaluate_once(self, mock_est,
                                                  mock_train_spec):
    """Sets global step in eval result to end the while True eval loop."""
    training_max_step = 200
    mock_est.evaluate.return_value =
{_GLOBAL_STEP_KEY: training_max_step} 960 mock_train_spec.max_steps = training_max_step 961 962 def test_evaluate_with_evaluate_spec(self): 963 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 964 mock_est.latest_checkpoint.return_value = 'latest_it_is' 965 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 966 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 967 968 eval_spec = training.EvalSpec( 969 input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='cont_eval', 970 start_delay_secs=0, throttle_secs=0) 971 972 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 973 executor.run_evaluator() 974 975 mock_est.evaluate.assert_called_with( 976 name='cont_eval', 977 input_fn=eval_spec.input_fn, 978 steps=eval_spec.steps, 979 checkpoint_path='latest_it_is', 980 hooks=eval_spec.hooks) 981 self.assertFalse(mock_est.train.called) 982 983 def test_evaluate_with_train_hooks(self): 984 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 985 mock_est.latest_checkpoint.return_value = 'latest_it_is' 986 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 987 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 988 989 eval_spec = training.EvalSpec( 990 input_fn=lambda: 1, 991 steps=2, 992 hooks=[_FakeHook()], 993 name='cont_eval', 994 start_delay_secs=0, 995 throttle_secs=0) 996 997 # The train_hooks will not be called during eval. 
998 mock_hook = test.mock.Mock(spec=session_run_hook.SessionRunHook) 999 executor = training._TrainingExecutor( 1000 mock_est, mock_train_spec, eval_spec, train_hooks=[mock_hook]) 1001 executor.run_evaluator() 1002 1003 mock_hook.begin.assert_not_called() 1004 1005 def test_evaluate_multiple_times(self): 1006 training_max_step = 200 1007 1008 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1009 mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) 1010 mock_est.evaluate.side_effect = [ 1011 {_GLOBAL_STEP_KEY: training_max_step // 2}, 1012 {_GLOBAL_STEP_KEY: training_max_step} 1013 ] 1014 mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2'] 1015 1016 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1017 mock_train_spec.max_steps = training_max_step 1018 1019 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1020 exporter.name = 'see_how_many_times_export_is_called' 1021 1022 mock_est.times_export_was_called = 0 1023 mock_est.times_final_export_was_called = 0 1024 def export(estimator, export_path, checkpoint_path, eval_result, 1025 is_the_final_export): 1026 del export_path, checkpoint_path, eval_result 1027 estimator.times_export_was_called += 1 1028 # final_export is happened at the end. 
1029 self.assertEqual(0, estimator.times_final_export_was_called) 1030 if is_the_final_export: 1031 estimator.times_final_export_was_called += 1 1032 1033 exporter.export = export 1034 1035 eval_spec = training.EvalSpec( 1036 input_fn=lambda: 1, 1037 start_delay_secs=0, 1038 throttle_secs=0, 1039 exporters=exporter) 1040 1041 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1042 executor.run_evaluator() 1043 1044 self.assertEqual(2, mock_est.evaluate.call_count) 1045 self.assertEqual(2, mock_est.times_export_was_called) 1046 self.assertEqual(1, mock_est.times_final_export_was_called) 1047 1048 def test_evaluate_listener_before_eval(self): 1049 training_max_step = 200 1050 1051 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1052 mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) 1053 # Without early stopping, this eval will be run twice. 1054 mock_est.evaluate.side_effect = [{ 1055 _GLOBAL_STEP_KEY: training_max_step // 2 1056 }, { 1057 _GLOBAL_STEP_KEY: training_max_step 1058 }] 1059 mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2'] 1060 1061 mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[]) 1062 mock_train_spec.max_steps = training_max_step 1063 1064 class _Listener(training._ContinuousEvalListener): 1065 1066 def __init__(self): 1067 self.call_count = 0 1068 1069 def before_eval(self): 1070 self.call_count += 1 1071 return self.call_count == 1 1072 1073 listener = _Listener() 1074 1075 eval_spec = training.EvalSpec( 1076 input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) 1077 1078 training._TrainingExecutor( 1079 mock_est, mock_train_spec, eval_spec, 1080 continuous_eval_listener=listener).run_evaluator() 1081 1082 # Before_eval returns False during the second time, so, evaluate will be 1083 # called once. 
1084 self.assertEqual(1, mock_est.evaluate.call_count) 1085 self.assertEqual(2, listener.call_count) 1086 1087 def test_evaluate_listener_after_eval(self): 1088 training_max_step = 200 1089 1090 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1091 mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) 1092 # Without early stopping, this eval will be run twice. 1093 expected_eval_metrics = [{ 1094 _GLOBAL_STEP_KEY: training_max_step // 2 1095 }, { 1096 _GLOBAL_STEP_KEY: training_max_step 1097 }] 1098 mock_est.evaluate.side_effect = expected_eval_metrics 1099 mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2'] 1100 1101 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1102 mock_train_spec.max_steps = training_max_step 1103 1104 class _Listener(training._ContinuousEvalListener): 1105 1106 def __init__(self): 1107 self.call_count = 0 1108 1109 def after_eval(self, eval_result): 1110 self.call_count += 1 1111 self.eval_result = eval_result 1112 return False 1113 1114 listener = _Listener() 1115 1116 eval_spec = training.EvalSpec( 1117 input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) 1118 1119 training._TrainingExecutor( 1120 mock_est, mock_train_spec, eval_spec, 1121 continuous_eval_listener=listener).run_evaluator() 1122 1123 # after_eval returns False during the first time, so, evaluate will be 1124 # called once. 
1125 self.assertEqual(1, mock_est.evaluate.call_count) 1126 self.assertEqual(1, listener.call_count) 1127 self.assertAllEqual(expected_eval_metrics[0], listener.eval_result.metrics) 1128 self.assertEqual('path_1', listener.eval_result.checkpoint_path) 1129 1130 def test_final_export_is_true_in_the_end(self): 1131 training_max_step = 200 1132 1133 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1134 mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) 1135 mock_est.evaluate.side_effect = [ 1136 {_GLOBAL_STEP_KEY: training_max_step // 2}, 1137 {_GLOBAL_STEP_KEY: training_max_step} 1138 ] 1139 mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2'] 1140 1141 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1142 mock_train_spec.max_steps = training_max_step 1143 1144 mock_est.times_export_fn_was_called = 0 1145 mock_est.times_the_final_export_was_true = 0 1146 def export(estimator, export_path, checkpoint_path, eval_result, 1147 is_the_final_export): 1148 del export_path, checkpoint_path, eval_result 1149 estimator.times_export_fn_was_called += 1 1150 if is_the_final_export: 1151 estimator.times_the_final_export_was_true += 1 1152 1153 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1154 exporter.name = 'see_how_many_times_export_is_called' 1155 exporter.export = export 1156 1157 eval_spec = training.EvalSpec( 1158 input_fn=lambda: 1, 1159 start_delay_secs=0, 1160 throttle_secs=0, 1161 exporters=exporter) 1162 1163 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1164 executor.run_evaluator() 1165 1166 self.assertEqual(2, mock_est.evaluate.call_count) 1167 self.assertEqual(2, mock_est.times_export_fn_was_called) 1168 self.assertEqual(1, mock_est.times_the_final_export_was_true) 1169 1170 def test_skip_evaluation_due_to_ckpt(self): 1171 training_max_step = 200 1172 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1173 mock_est.evaluate.side_effect = [ 1174 {_GLOBAL_STEP_KEY: training_max_step // 
2}, 1175 {_GLOBAL_STEP_KEY: training_max_step} 1176 ] 1177 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1178 mock_train_spec.max_steps = training_max_step 1179 1180 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 1181 1182 # First two items are invalid, next two items are same. 1183 mock_est.latest_checkpoint.side_effect = [ 1184 None, '', 'same', 'same', 'path_2' 1185 ] 1186 1187 eval_spec = training.EvalSpec( 1188 input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) 1189 1190 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1191 with test.mock.patch.object(logging, 'warning') as mock_log: 1192 executor.run_evaluator() 1193 1194 # Three checkpoint paths are invalid. 1195 self.assertEqual(5, mock_est.latest_checkpoint.call_count) 1196 self.assertEqual(2, mock_est.evaluate.call_count) 1197 1198 # Two warning logs are expected (last warning time is reset after a 1199 # successuful evaluation) 1200 self.assertEqual(2, mock_log.call_count) 1201 1202 def test_continuous_eval_listener_eval_result(self): 1203 training_max_step = 200 1204 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1205 expected_eval_metrics = [{ 1206 _GLOBAL_STEP_KEY: training_max_step // 2 1207 }, { 1208 _GLOBAL_STEP_KEY: training_max_step 1209 }] 1210 mock_est.evaluate.side_effect = expected_eval_metrics 1211 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1212 mock_train_spec.max_steps = training_max_step 1213 1214 class _Listener(training._ContinuousEvalListener): 1215 1216 def __init__(self): 1217 self.eval_results = [] 1218 1219 def after_eval(self, eval_result): 1220 self.eval_results.append(eval_result) 1221 return True 1222 1223 continuous_eval_listener = _Listener() 1224 1225 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 1226 1227 # First two items are invalid, next two items are same. 
1228 mock_est.latest_checkpoint.side_effect = [ 1229 None, '', 'same', 'same', 'path_2' 1230 ] 1231 expected_eval_results = [ 1232 training._EvalResult(training._EvalStatus.MISSING_CHECKPOINT), 1233 training._EvalResult(training._EvalStatus.MISSING_CHECKPOINT), 1234 training._EvalResult( 1235 training._EvalStatus.EVALUATED, 1236 metrics=expected_eval_metrics[0], 1237 checkpoint_path='same'), 1238 training._EvalResult(training._EvalStatus.NO_NEW_CHECKPOINT), 1239 training._EvalResult( 1240 training._EvalStatus.EVALUATED, 1241 metrics=expected_eval_metrics[1], 1242 checkpoint_path='path_2'), 1243 ] 1244 1245 eval_spec = training.EvalSpec( 1246 input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) 1247 1248 executor = training._TrainingExecutor( 1249 mock_est, 1250 mock_train_spec, 1251 eval_spec, 1252 continuous_eval_listener=continuous_eval_listener) 1253 executor.run_evaluator() 1254 1255 # Three checkpoint paths are invalid. 1256 self.assertEqual(5, mock_est.latest_checkpoint.call_count) 1257 self.assertEqual(2, mock_est.evaluate.call_count) 1258 1259 self.assertEqual(5, len(continuous_eval_listener.eval_results)) 1260 for i, result in enumerate(continuous_eval_listener.eval_results): 1261 self.assertEqual(expected_eval_results[i].status, result.status) 1262 self.assertAllEqual(expected_eval_results[i].metrics, result.metrics) 1263 self.assertEqual(expected_eval_results[i].checkpoint_path, 1264 result.checkpoint_path) 1265 1266 def test_sleep_start_delay_secs(self): 1267 training_max_step = 200 1268 start_delay_secs = 123 1269 1270 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1271 mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: training_max_step} 1272 mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) 1273 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1274 mock_train_spec.max_steps = training_max_step 1275 1276 eval_spec = training.EvalSpec( 1277 input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='cont_eval', 1278 
start_delay_secs=start_delay_secs, throttle_secs=0) 1279 1280 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1281 with test.mock.patch.object(time, 'sleep') as mock_sleep: 1282 executor.run_evaluator() 1283 mock_sleep.assert_called_with(start_delay_secs) 1284 self.assertTrue(mock_est.evaluate.called) 1285 1286 @test.mock.patch.object(time, 'time') 1287 @test.mock.patch.object(time, 'sleep') 1288 def test_throttle_secs(self, mock_sleep, mock_time): 1289 throttle_secs = 123 1290 operation_secs = 12 1291 1292 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1293 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1294 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 1295 1296 eval_spec = training.EvalSpec( 1297 input_fn=lambda: 1, start_delay_secs=0, throttle_secs=throttle_secs) 1298 1299 mock_time.side_effect = [921, 921 + operation_secs] 1300 1301 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1302 # Disable logging as it calls time.time also. 
1303 with test.mock.patch.object(logging, 'info'): 1304 executor.run_evaluator() 1305 mock_sleep.assert_called_with(throttle_secs - operation_secs) 1306 self.assertTrue(mock_est.evaluate.called) 1307 1308 def test_that_export_is_called(self): 1309 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1310 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1311 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 1312 1313 def export(estimator, *args, **kwargs): 1314 del args, kwargs 1315 estimator.export_was_called = True 1316 1317 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1318 exporter.name = 'see_whether_export_is_called' 1319 exporter.export = export 1320 1321 eval_spec = training.EvalSpec( 1322 input_fn=lambda: 1, 1323 steps=2, 1324 start_delay_secs=0, 1325 throttle_secs=0, 1326 exporters=exporter) 1327 1328 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1329 executor.run_evaluator() 1330 1331 # Verify that export was called on the right estimator. 
1332 self.assertTrue(mock_est.export_was_called) 1333 1334 def test_errors_out_if_evaluate_returns_empty_dict(self): 1335 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1336 train_spec = training.TrainSpec(input_fn=lambda: 1) 1337 eval_spec = training.EvalSpec(input_fn=(lambda: 1), 1338 start_delay_secs=0, throttle_secs=0) 1339 mock_est.evaluate.return_value = {} 1340 1341 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1342 with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR): 1343 executor.run_evaluator() 1344 1345 def test_errors_out_if_evaluate_returns_non_dict(self): 1346 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1347 train_spec = training.TrainSpec(input_fn=lambda: 1) 1348 eval_spec = training.EvalSpec(input_fn=(lambda: 1), 1349 start_delay_secs=0, throttle_secs=0) 1350 mock_est.evaluate.return_value = 123 1351 1352 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1353 with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR): 1354 executor.run_evaluator() 1355 1356 def test_errors_out_if_evaluate_returns_dict_without_global_step(self): 1357 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1358 train_spec = training.TrainSpec(input_fn=lambda: 1) 1359 eval_spec = training.EvalSpec(input_fn=(lambda: 1), 1360 start_delay_secs=0, throttle_secs=0) 1361 mock_est.evaluate.return_value = {'loss': 123} 1362 1363 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1364 with self.assertRaisesRegexp(ValueError, 1365 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR): 1366 executor.run_evaluator() 1367 1368 1369 class TrainingExecutorRunPsTest(test.TestCase): 1370 """Tests run_ps of _TrainingExecutor.""" 1371 1372 @test.mock.patch.object(server_lib, 'Server') 1373 def test_std_server(self, mock_server): 1374 mock_server_instance = test.mock.Mock() 1375 mock_server.return_value = mock_server_instance 1376 1377 mock_est = 
test.mock.Mock(spec=estimator_lib.Estimator) 1378 mock_est.config = _create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS) 1379 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1380 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1381 1382 executor = training._TrainingExecutor(mock_est, mock_train_spec, 1383 mock_eval_spec) 1384 executor.run_ps() 1385 1386 mock_server.assert_called_with( 1387 mock_est.config.cluster_spec, 1388 job_name=mock_est.config.task_type, 1389 task_index=mock_est.config.task_id, 1390 config=test.mock.ANY, 1391 start=False) 1392 1393 self.assertTrue(mock_server_instance.start.called) 1394 self.assertTrue(mock_server_instance.join.called) 1395 1396 def test_fail_with_empty_cluster_spec(self): 1397 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1398 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1399 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1400 1401 mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) 1402 mock_est.config.cluster_spec = None 1403 mock_est.config.master = 'grpc://...' 
1404 mock_est.config.task_type = 'ps' 1405 mock_est.config.task_id = 2 1406 1407 with self.assertRaisesRegexp(RuntimeError, 1408 _INVALID_CONFIG_FOR_STD_SERVER_MSG): 1409 training._TrainingExecutor(mock_est, mock_train_spec, 1410 mock_eval_spec).run_ps() 1411 1412 def test_fail_with_empty_master(self): 1413 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1414 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1415 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1416 1417 mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) 1418 mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']}) 1419 mock_est.config.master = '' 1420 mock_est.config.task_type = 'ps' 1421 mock_est.config.task_id = 2 1422 1423 with self.assertRaisesRegexp(RuntimeError, 1424 _INVALID_CONFIG_FOR_STD_SERVER_MSG): 1425 training._TrainingExecutor(mock_est, mock_train_spec, 1426 mock_eval_spec).run_ps() 1427 1428 def test_fail_with_empty_task_type(self): 1429 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1430 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1431 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1432 1433 mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) 1434 mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']}) 1435 mock_est.config.master = 'grpc://...' 
1436 mock_est.config.task_type = '' 1437 mock_est.config.task_id = 2 1438 1439 with self.assertRaisesRegexp(RuntimeError, 1440 _INVALID_CONFIG_FOR_STD_SERVER_MSG): 1441 training._TrainingExecutor(mock_est, mock_train_spec, 1442 mock_eval_spec).run_ps() 1443 1444 def test_fail_with_none_task_id(self): 1445 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1446 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1447 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1448 1449 mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) 1450 mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']}) 1451 mock_est.config.master = 'grpc://...' 1452 mock_est.config.task_type = 'ps' 1453 mock_est.config.task_id = None 1454 1455 with self.assertRaisesRegexp(RuntimeError, 1456 _INVALID_CONFIG_FOR_STD_SERVER_MSG): 1457 training._TrainingExecutor(mock_est, mock_train_spec, 1458 mock_eval_spec).run_ps() 1459 1460 1461 class StopAtSecsHookTest(test.TestCase): 1462 """Tests StopAtSecsHook.""" 1463 1464 @test.mock.patch.object(time, 'time') 1465 def test_stops_after_time(self, mock_time): 1466 mock_time.return_value = 1484695987.209386 1467 hook = training._StopAtSecsHook(1000) 1468 with ops.Graph().as_default(): 1469 no_op = control_flow_ops.no_op() 1470 # some time passed before training starts 1471 mock_time.return_value += 250 1472 with monitored_session.MonitoredSession(hooks=[hook]) as sess: 1473 self.assertFalse(sess.should_stop()) 1474 sess.run(no_op) 1475 self.assertFalse(sess.should_stop()) 1476 mock_time.return_value += 500 1477 sess.run(no_op) 1478 self.assertFalse(sess.should_stop()) 1479 mock_time.return_value += 400 1480 sess.run(no_op) 1481 self.assertFalse(sess.should_stop()) 1482 mock_time.return_value += 200 1483 sess.run(no_op) 1484 self.assertTrue(sess.should_stop()) 1485 1486 1487 class TrainingExecutorRunLocalTest(test.TestCase): 1488 """Tests run_local of _TrainingExecutor.""" 1489 1490 def 
unique_checkpoint_every_time_fn(self): 1491 return 'checkpoint_path_%s/' % random.random() 1492 1493 def test_send_stop_at_secs_to_train(self): 1494 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1495 mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn 1496 train_spec = training.TrainSpec( 1497 input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()]) 1498 eval_spec = training.EvalSpec( 1499 input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) 1500 mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} 1501 1502 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1503 executor.run_local() 1504 1505 stop_hook = mock_est.train.call_args[1]['hooks'][-1] 1506 self.assertIsInstance(stop_hook, training._StopAtSecsHook) 1507 self.assertEqual(eval_spec.throttle_secs, stop_hook._stop_after_secs) 1508 1509 def test_runs_in_a_loop_until_max_steps(self): 1510 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1511 mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn 1512 1513 mock_est.times_export_was_called = 0 1514 mock_est.times_final_export_was_called = 0 1515 def export(estimator, export_path, checkpoint_path, eval_result, 1516 is_the_final_export): 1517 del export_path, checkpoint_path, eval_result 1518 estimator.times_export_was_called += 1 1519 # final_export is happened at the end. 1520 self.assertEqual(0, estimator.times_final_export_was_called) 1521 if is_the_final_export: 1522 estimator.times_final_export_was_called += 1 1523 1524 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1525 exporter.name = 'see_how_many_times_export_is_called' 1526 exporter.export = export 1527 1528 train_spec = training.TrainSpec( 1529 input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) 1530 eval_spec = training.EvalSpec( 1531 input_fn=lambda: 1, 1532 hooks=[_FakeHook()], 1533 throttle_secs=100, 1534 exporters=exporter) 1535 # should be called 3 times. 
1536 mock_est.evaluate.side_effect = [{ 1537 _GLOBAL_STEP_KEY: train_spec.max_steps - 100 1538 }, { 1539 _GLOBAL_STEP_KEY: train_spec.max_steps - 50 1540 }, { 1541 _GLOBAL_STEP_KEY: train_spec.max_steps 1542 }] 1543 1544 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1545 executor.run_local() 1546 1547 self.assertEqual(3, mock_est.train.call_count) 1548 self.assertEqual(3, mock_est.evaluate.call_count) 1549 self.assertEqual(3, mock_est.times_export_was_called) 1550 self.assertEqual(1, mock_est.times_final_export_was_called) 1551 1552 def test_handles_no_new_checkpoint_found(self): 1553 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1554 mock_est.latest_checkpoint.return_value = ( 1555 'no_new_checkpoints_after_the_first_train_step') 1556 train_spec = training.TrainSpec( 1557 input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) 1558 eval_spec = training.EvalSpec( 1559 input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) 1560 # It was going to be called 3 times. 
1561 mock_est.evaluate.side_effect = [{ 1562 _GLOBAL_STEP_KEY: train_spec.max_steps - 100 1563 }, { 1564 _GLOBAL_STEP_KEY: train_spec.max_steps - 50 1565 }, { 1566 _GLOBAL_STEP_KEY: train_spec.max_steps 1567 }] 1568 1569 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1570 with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG): 1571 executor.run_local() 1572 1573 def test_final_export_is_true_in_the_end(self): 1574 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1575 mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn 1576 1577 mock_est.times_export_fn_was_called = 0 1578 mock_est.times_the_final_export_was_true = 0 1579 def export(estimator, export_path, checkpoint_path, eval_result, 1580 is_the_final_export): 1581 del export_path, checkpoint_path, eval_result 1582 estimator.times_export_fn_was_called += 1 1583 if is_the_final_export: 1584 estimator.times_the_final_export_was_true += 1 1585 1586 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1587 exporter.name = 'see_how_many_times_export_is_called' 1588 exporter.export = export 1589 1590 train_spec = training.TrainSpec( 1591 input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) 1592 eval_spec = training.EvalSpec( 1593 input_fn=lambda: 1, 1594 hooks=[_FakeHook()], 1595 throttle_secs=100, 1596 exporters=exporter) 1597 # should be called 3 times. 
1598 mock_est.evaluate.side_effect = [{ 1599 _GLOBAL_STEP_KEY: train_spec.max_steps - 100 1600 }, { 1601 _GLOBAL_STEP_KEY: train_spec.max_steps - 50 1602 }, { 1603 _GLOBAL_STEP_KEY: train_spec.max_steps 1604 }] 1605 1606 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1607 executor.run_local() 1608 1609 self.assertEqual(3, mock_est.train.call_count) 1610 self.assertEqual(3, mock_est.evaluate.call_count) 1611 self.assertEqual(3, mock_est.times_export_fn_was_called) 1612 self.assertEqual(1, mock_est.times_the_final_export_was_true) 1613 1614 def test_train_and_evaluate_args(self): 1615 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1616 mock_est.latest_checkpoint.return_value = 'checkpoint_path/' 1617 train_spec = training.TrainSpec( 1618 input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) 1619 eval_spec = training.EvalSpec( 1620 input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval') 1621 mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} 1622 1623 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1624 executor.run_local() 1625 1626 mock_est.evaluate.assert_called_with( 1627 name=eval_spec.name, 1628 input_fn=eval_spec.input_fn, 1629 steps=eval_spec.steps, 1630 checkpoint_path='checkpoint_path/', 1631 hooks=eval_spec.hooks) 1632 1633 train_args = mock_est.train.call_args[1] 1634 self.assertEqual(list(train_spec.hooks), list(train_args['hooks'][:-1])) 1635 self.assertEqual(train_spec.input_fn, train_args['input_fn']) 1636 self.assertEqual(train_spec.max_steps, train_args['max_steps']) 1637 1638 def test_train_hooks(self): 1639 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1640 mock_est.latest_checkpoint.return_value = 'checkpoint_path/' 1641 train_spec = training.TrainSpec( 1642 input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) 1643 eval_spec = training.EvalSpec(input_fn=lambda: 1, steps=2) 1644 
mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} 1645 extra_hooks = [_FakeHook()] 1646 1647 executor = training._TrainingExecutor( 1648 mock_est, train_spec, eval_spec, train_hooks=extra_hooks) 1649 executor.run_local() 1650 1651 train_args = mock_est.train.call_args[1] 1652 self.assertEqual( 1653 list(train_spec.hooks) + extra_hooks, [ 1654 h for h in train_args['hooks'] 1655 if not isinstance(h, training._StopAtSecsHook) 1656 ]) 1657 1658 def test_errors_out_if_throttle_secs_is_zero(self): 1659 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1660 train_spec = training.TrainSpec(input_fn=lambda: 1) 1661 eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=0) 1662 1663 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1664 with self.assertRaisesRegexp(ValueError, 'throttle_secs'): 1665 executor.run_local() 1666 1667 def test_that_export_is_called_with_run_local(self): 1668 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1669 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1670 mock_train_spec.max_steps = 200 1671 mock_est.evaluate.return_value = { 1672 _GLOBAL_STEP_KEY: mock_train_spec.max_steps 1673 } 1674 # _validate_hooks would have made sure that train_spec.hooks is [], when 1675 # None were passed. 
    mock_train_spec.hooks = []

    # Fake exporter body: record on the estimator itself that export ran.
    def export(estimator, *args, **kwargs):
      del args, kwargs
      estimator.export_was_called = True

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'
    exporter.export = export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        start_delay_secs=0,
        throttle_secs=213,
        exporters=exporter)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_local()

    self.assertTrue(mock_est.export_was_called)

  def test_errors_out_if_evaluate_returns_empty_dict(self):
    """An empty eval-result dict must raise ValueError."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
    mock_est.evaluate.return_value = {}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):
      executor.run_local()

  def test_errors_out_if_evaluate_returns_non_dict(self):
    """A non-dict eval result must raise TypeError."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
    mock_est.evaluate.return_value = 123

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
      executor.run_local()

  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
    """An eval result missing the global-step key must raise ValueError."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
    mock_est.evaluate.return_value = {'loss': 123}

    executor = training._TrainingExecutor(mock_est, train_spec,
                                          eval_spec)
    with self.assertRaisesRegexp(ValueError,
                                 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
      executor.run_local()


class TrainAndEvaluateRunTest(test.TestCase):
  """Verifies _TrainingExecutor.run dispatches to the right run_<task>."""

  def _test_run_task_and_executor(self, run_config):
    """Builds an executor whose run_* methods only record being called."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    # Maps task name -> 1 once the corresponding run_<task> was invoked.
    executor.call_task = {}

    def task_fn(name):

      def _fn():
        executor.call_task[name] = 1

      return _fn

    executor.run_chief = task_fn('chief')
    executor.run_master = task_fn('master')
    executor.run_ps = task_fn('ps')
    executor.run_evaluator = task_fn('evaluator')
    executor.run_worker = task_fn('worker')
    executor.run_local = task_fn('local')
    return executor

  def test_run_chief(self):
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF))
    executor.run()
    self.assertEqual(1, executor.call_task['chief'])

  def test_run_worker(self):
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER))
    executor.run()
    self.assertEqual(1, executor.call_task['worker'])

  def test_run_ps(self):
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS))
    executor.run()
    self.assertEqual(1, executor.call_task['ps'])

  def test_run_evaluator(self):
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(
            _TF_CONFIG_FOR_EVALUATOR))
    executor.run()
    self.assertEqual(1, executor.call_task['evaluator'])

  def test_run_local(self):
    # A default RunConfig has no cluster spec, so run() falls back to local.
    executor = self._test_run_task_and_executor(
        run_config=run_config_lib.RunConfig())
    executor.run()
    self.assertEqual(1, executor.call_task['local'])

  def test_invalid_local_task(self):
    """A TF_CONFIG cluster using task type 'local' must be rejected."""
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
            'local': ['hos1:1'],
        },
        'task': {
            'type': 'local',  # invalid task type.
            'index': 0
        }
    }
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_LOCAL_TASK_WITH_CLUSTER):
      executor.run()

  def test_unsupported_task_due_to_missing_run_task(self):
    """A task type with no matching run_<task> method must raise."""
    unsupported_task = 'alloc'
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
            unsupported_task: ['hos1:1'],
        },
        'task': {
            'type': unsupported_task,
            'index': 0
        }
    }
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):
      executor.run()

  def test_unsupported_task_due_to_not_callable(self):
    """A run_<task> attribute that exists but is not callable must raise."""
    unsupported_task = 'alloc'
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
            unsupported_task: ['hos1:1'],
        },
        'task': {
            'type': unsupported_task,
            'index': 0
        }
    }
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = \
        _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    executor.run_alloc = 123  # not callable
    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):
      executor.run()

  def test_invalid_task_type(self):
    """An empty task_type alongside a non-empty cluster spec must raise."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    # NOTE(review): this first config mock is immediately replaced below;
    # the assignment looks redundant.
    mock_est.config = test.mock.Mock()
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.Mock()
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'1': ['dummy']})
    mock_est.config.task_type = ''

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE):
      executor.run()


class TrainAndEvaluateIntegrationTest(test.TestCase):
  """End-to-end train_and_evaluate run with a real DNNClassifier."""

  def setUp(self):
    # Fresh model dir per test; removed again in tearDown.
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      shutil.rmtree(self._model_dir)

  def _as_label(self, data_in_float):
    # Round float targets to the nearest integer class id.
    return np.rint(data_in_float).astype(np.int64)

  def _get_exporter(self, name, fc):
    """Returns a LatestExporter that parses tf.Examples for columns `fc`."""
    feature_spec = feature_column.make_parse_example_spec(fc)
    serving_input_receiver_fn = (
        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))
    return exporter_lib.LatestExporter(
        name, serving_input_receiver_fn=serving_input_receiver_fn)

  def _extract_loss_and_global_step(self, event_folder):
    """Returns the loss and global step in last event."""
    event_paths = glob.glob(os.path.join(event_folder, 'events*'))

    loss = None
    global_step_count = None

    # Scan only the most recent events file.
    for e in summary_iterator.summary_iterator(event_paths[-1]):
      current_loss = None
      # Pull the scalar 'loss' summary out of this event, if present.
      for v in e.summary.value:
        if v.tag == 'loss':
          current_loss = v.simple_value

      # If loss is not found, global step is meaningless.
      if current_loss is None:
        continue

      # Keep the loss from the event with the largest global step seen.
      current_global_step = e.step
      if global_step_count is None or current_global_step > global_step_count:
        global_step_count = current_global_step
        loss = current_loss

    return (loss, global_step_count)

  def test_complete_flow_with_non_distributed_configuration(self):
    """Trains, evaluates, exports and predicts with a local run config."""
    n_classes = 3
    input_dimension = 2
    batch_size = 10

    eval_name = 'foo'
    exporter_name = 'saved_model_exporter'

    # max_steps should be larger than save_summary_steps
    max_steps = 10
    save_summary_steps = 2

    data = np.linspace(
        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
    x_data = data.reshape(batch_size, input_dimension)
    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))

    # learn y = x
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)

    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        num_epochs=1,
        shuffle=False)

    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        batch_size=batch_size,
        shuffle=False)

    feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))]

    est = dnn.DNNClassifier(
        hidden_units=(2, 2),
        feature_columns=feature_columns,
        n_classes=n_classes,
        config=run_config_lib.RunConfig(save_summary_steps=save_summary_steps),
        model_dir=self._model_dir)

    train_spec = training.TrainSpec(input_fn=train_input_fn,
                                    max_steps=max_steps)

    eval_spec = training.EvalSpec(
        name=eval_name, input_fn=eval_input_fn, steps=None,
        exporters=self._get_exporter(exporter_name, feature_columns),
        throttle_secs=2)

    training.train_and_evaluate(est, train_spec, eval_spec)

    # Make sure nothing is stuck in limbo.
    writer_cache.FileWriterCache.clear()

    # Examine the training events. Use a range to check global step to avoid
    # flakyness due to global step race condition.
    training_loss, training_global_step = self._extract_loss_and_global_step(
        est.model_dir)
    self.assertIsNotNone(training_loss)
    self.assertTrue(
        max_steps - save_summary_steps < training_global_step <= max_steps)

    # Examine the eval events. The global step should be accurate.
    eval_loss, eval_global_step = self._extract_loss_and_global_step(
        event_folder=os.path.join(est.model_dir, 'eval_' + eval_name))
    self.assertIsNotNone(eval_loss)
    self.assertEqual(max_steps, eval_global_step)

    # Examine the export folder.
    export_dir = os.path.join(os.path.join(est.model_dir, 'export'),
                              exporter_name)
    self.assertTrue(gfile.Exists(export_dir))

    # Examine the ckpt for predict.
    predicted_proba = np.array([
        x[prediction_keys.PredictionKeys.PROBABILITIES]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)


if __name__ == '__main__':
  test.main()