# Source listing: tensorflow/python/estimator/training_test.py
# (code-viewer navigation header removed).
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Tests for training.py."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import glob
     23 import json
     24 import os
     25 import random
     26 import shutil
     27 import tempfile
     28 import time
     29 
     30 import numpy as np
     31 
     32 from tensorflow.python.estimator import estimator as estimator_lib
     33 from tensorflow.python.estimator import exporter as exporter_lib
     34 from tensorflow.python.estimator import run_config as run_config_lib
     35 from tensorflow.python.estimator import training
     36 from tensorflow.python.estimator.canned import dnn
     37 from tensorflow.python.estimator.canned import prediction_keys
     38 from tensorflow.python.estimator.export import export as export_lib
     39 from tensorflow.python.estimator.inputs import numpy_io
     40 from tensorflow.python.feature_column import feature_column
     41 from tensorflow.python.framework import ops
     42 from tensorflow.python.ops import control_flow_ops
     43 from tensorflow.python.platform import gfile
     44 from tensorflow.python.platform import test
     45 from tensorflow.python.platform import tf_logging as logging
     46 from tensorflow.python.summary import summary_iterator
     47 from tensorflow.python.summary.writer import writer_cache
     48 from tensorflow.python.training import basic_session_run_hooks
     49 from tensorflow.python.training import monitored_session
     50 from tensorflow.python.training import server_lib
     51 from tensorflow.python.training import session_run_hook
     52 from tensorflow.python.util import compat
     53 
# Default values mirrored from training.py; the EvalSpec tests below assert
# that a spec built with only `input_fn` falls back to these.
_DEFAULT_EVAL_STEPS = 100
_DEFAULT_EVAL_DELAY_SECS = 120
_DEFAULT_EVAL_THROTTLE_SECS = 600
# Per-worker startup delay in seconds (see test_delay_for_worker).
_DELAY_SECS_PER_WORKER = 5
_GLOBAL_STEP_KEY = ops.GraphKeys.GLOBAL_STEP
# Regex patterns matching the error messages raised by training.py; used with
# `assertRaisesRegexp` throughout this file.
_INVALID_INPUT_FN_MSG = '`input_fn` must be callable'
_INVALID_HOOK_MSG = 'All hooks must be `SessionRunHook` instances'
_INVALID_MAX_STEPS_MSG = 'Must specify max_steps > 0'
_INVALID_STEPS_MSG = 'Must specify steps > 0'
_INVALID_NAME_MSG = '`name` must be string'
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify start_delay_secs >= 0'
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
_STALE_CHECKPOINT_MSG = 'There was no new checkpoint after the training.'
_INVALID_EXPORTER_MSG = '`exporters` must be an Exporter'
_INVALID_EXPORTER_NAME_TYPE_MSG = 'An Exporter must have a string name'
_DUPLICATE_EXPORTER_NAMES_MSG = '`exporters` must have unique names.'
_NONE_EXPORTER_NAME_MSG = (
    'An Exporter cannot have a name that is `None` or empty.')
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`'
_INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'
_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'
_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'
_INVALID_TASK_TYPE = '`estimator.config` must have task_type set.'
# The matched message should NOT contain the word 'local'. Because (?!local) is
# a lookahead, the $ (end) anchor is required; otherwise the pattern could
# match a prefix only and succeed spuriously.
_INVALID_TASK_TO_RUN = (
    'Task type .* is not supported. Supported task types are ((?!local).)*$')
_INVALID_EMPTY_EVAL_RESULT_ERR = (
    'Internal error: `Estimator.evaluate` should never return empty metrics')
_INVALID_EVAL_RESULT_TYPE_ERR = '`Estimator.evaluate` should return dict.'
_MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR = (
    'Internal error: `Estimator.evaluate` result should have `global_step`')
_INVALID_EVAL_TASK_ID_ERR = (
    'there can only be one `evaluator` task .*with task id 0')

# Canned TF_CONFIG dicts, one per task type; fed through
# `_create_run_config_with_cluster_spec` to build the RunConfigs under test.
_TF_CONFIG_FOR_CHIEF = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.CHIEF,
        'index': 0
    }
}

_TF_CONFIG_FOR_MASTER = {
    'cluster': {
        run_config_lib.TaskType.MASTER: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.MASTER,
        'index': 0
    }
}

_TF_CONFIG_FOR_WORKER = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.WORKER,
        'index': 1
    }
}

_TF_CONFIG_FOR_PS = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.PS,
        'index': 1
    }
}

_TF_CONFIG_FOR_EVALUATOR = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.EVALUATOR,
        'index': 0
    }
}

# 'environment': 'google' suppresses std-server startup (see
# test_no_server_startup_in_google below).
_TF_CONFIG_FOR_GOOGLE = {'environment': 'google'}
    154 
    155 class _FakeHook(session_run_hook.SessionRunHook):
    156   """Fake implementation of `SessionRunHook`."""
    157 
    158 
    159 class _InvalidHook(object):
    160   """Invalid hook (not a subclass of `SessionRunHook`)."""
    161 
    162 
    163 def _create_exporter(name):
    164   class FakeExporter(exporter_lib.Exporter):
    165 
    166     def __init__(self, name):
    167       self._name = name
    168 
    169     @property
    170     def name(self):
    171       return self._name
    172 
    173     def export(self, *args, **kwargs):
    174       del args, kwargs
    175 
    176   return FakeExporter(name=name)
    177 
    178 
    179 def _create_run_config_with_cluster_spec(tf_config):
    180   with test.mock.patch.dict('os.environ', {'TF_CONFIG': json.dumps(tf_config)}):
    181     return run_config_lib.RunConfig()
    182 
    183 
    184 class TrainSpecTest(test.TestCase):
    185   """Tests TrainSpec."""
    186 
    187   def testRequiredArgumentsSet(self):
    188     """Tests that no errors are raised when all required arguments are set."""
    189     spec = training.TrainSpec(input_fn=lambda: 1)
    190     self.assertEqual(1, spec.input_fn())
    191     self.assertIsNone(spec.max_steps)
    192     self.assertEqual(0, len(spec.hooks))
    193 
    194   def testAllArgumentsSet(self):
    195     """Tests that no errors are raised when all arguments are set."""
    196     hooks = [_FakeHook()]
    197     spec = training.TrainSpec(input_fn=lambda: 1, max_steps=2, hooks=hooks)
    198     self.assertEqual(1, spec.input_fn())
    199     self.assertEqual(2, spec.max_steps)
    200     self.assertEqual(tuple(hooks), spec.hooks)
    201 
    202   def testInvalidInputFn(self):
    203     with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
    204       training.TrainSpec(input_fn='invalid')
    205 
    206   def testInvalidMaxStep(self):
    207     with self.assertRaisesRegexp(ValueError, _INVALID_MAX_STEPS_MSG):
    208       training.TrainSpec(input_fn=lambda: 1, max_steps=0)
    209 
    210   def testInvalidHook(self):
    211     with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
    212       training.TrainSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])
    213 
    214 
    215 class EvalSpecTest(test.TestCase):
    216   """Tests EvalSpec."""
    217 
    218   def testRequiredArgumentsSet(self):
    219     """Tests that no errors are raised when all required arguments are set."""
    220     spec = training.EvalSpec(input_fn=lambda: 1)
    221     self.assertEqual(1, spec.input_fn())
    222     self.assertEqual(_DEFAULT_EVAL_STEPS, spec.steps)
    223     self.assertIsNone(spec.name)
    224     self.assertEqual(0, len(spec.hooks))
    225     self.assertEqual(0, len(spec.exporters))
    226     self.assertEqual(_DEFAULT_EVAL_DELAY_SECS, spec.start_delay_secs)
    227     self.assertEqual(_DEFAULT_EVAL_THROTTLE_SECS, spec.throttle_secs)
    228 
    229   def testAllArgumentsSet(self):
    230     """Tests that no errors are raised when all arguments are set."""
    231     hooks = [_FakeHook()]
    232     exporter = _create_exporter('a')
    233 
    234     spec = training.EvalSpec(
    235         input_fn=lambda: 1,
    236         steps=2,
    237         name='name',
    238         hooks=hooks,
    239         exporters=exporter,
    240         start_delay_secs=3,
    241         throttle_secs=4)
    242     self.assertEqual(1, spec.input_fn())
    243     self.assertEqual(2, spec.steps)
    244     self.assertEqual('name', spec.name)
    245     self.assertEqual(tuple(hooks), spec.hooks)
    246     self.assertEqual((exporter,), spec.exporters)
    247     self.assertEqual(3, spec.start_delay_secs)
    248     self.assertEqual(4, spec.throttle_secs)
    249 
    250   def testListOfExporters(self):
    251     """Tests that no errors are raised with multiple exporters."""
    252     exporters = [_create_exporter('a'), _create_exporter('b')]
    253 
    254     spec = training.EvalSpec(input_fn=lambda: 1, exporters=exporters)
    255     self.assertEqual(1, spec.input_fn())
    256     self.assertEqual(tuple(exporters), spec.exporters)
    257 
    258   def testInvalidInputFn(self):
    259     with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
    260       training.EvalSpec(input_fn='invalid')
    261 
    262   def testInvalidMaxStep(self):
    263     with self.assertRaisesRegexp(ValueError, _INVALID_STEPS_MSG):
    264       training.EvalSpec(input_fn=lambda: 1, steps=0)
    265 
    266   def testInvalidName(self):
    267     with self.assertRaisesRegexp(TypeError, _INVALID_NAME_MSG):
    268       training.EvalSpec(input_fn=lambda: 1, name=123)
    269 
    270   def testInvalidHook(self):
    271     with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
    272       training.EvalSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])
    273 
    274   def testInvalidDelaySecs(self):
    275     with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_DELAY_SECS_MSG):
    276       training.EvalSpec(input_fn=lambda: 1, start_delay_secs=-1)
    277 
    278   def testInvalidThrottleSecs(self):
    279     with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_THROTTLE_SECS_MSG):
    280       training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)
    281 
    282   def testInvalidTypeOfListOfExporters(self):
    283     with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):
    284       training.EvalSpec(
    285           input_fn=lambda: 1, exporters=[_create_exporter('a'),
    286                                          _FakeHook()])
    287 
    288   def testInvalidTypeOfIndividualExporter(self):
    289     with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):
    290       training.EvalSpec(input_fn=lambda: 1, exporters=_FakeHook())
    291 
    292   def testInvalidTypeOfExporterName(self):
    293     with self.assertRaisesRegexp(ValueError, _INVALID_EXPORTER_NAME_TYPE_MSG):
    294       training.EvalSpec(input_fn=lambda: 1,
    295                         exporters=_create_exporter(name=123))
    296 
    297   def testMultipleExportersWithTheSameName(self):
    298     with self.assertRaisesRegexp(ValueError, _DUPLICATE_EXPORTER_NAMES_MSG):
    299       training.EvalSpec(
    300           input_fn=lambda: 1,
    301           exporters=[_create_exporter('a'), _create_exporter('a')])
    302 
    303   def testMultipleExportersAndOneWithoutAName(self):
    304     with self.assertRaisesRegexp(ValueError, _NONE_EXPORTER_NAME_MSG):
    305       training.EvalSpec(
    306           input_fn=lambda: 1,
    307           exporters=[_create_exporter('a'),
    308                      _create_exporter(None)])
    309 
    310   def testSingleExporterWithoutAName(self):
    311     with self.assertRaisesRegexp(ValueError, _NONE_EXPORTER_NAME_MSG):
    312       training.EvalSpec(input_fn=lambda: 1, exporters=_create_exporter(None))
    313 
    314 
    315 class TrainAndEvaluateTest(test.TestCase):
    316 
    317   def test_run_task(self):
    318     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    319     mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    320     mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    321 
    322     with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
    323       mock_executor_instance = test.mock.Mock()
    324       mock_executor.return_value = mock_executor_instance
    325       training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
    326       mock_executor.assert_called_with(estimator=mock_est,
    327                                        train_spec=mock_train_spec,
    328                                        eval_spec=mock_eval_spec)
    329       self.assertTrue(mock_executor_instance.run.called)
    330 
    331   def test_error_out_if_evaluator_task_id_is_non_zero(self):
    332     tf_config = {
    333         'cluster': {
    334             run_config_lib.TaskType.CHIEF: ['host0:0'],
    335         },
    336         'task': {
    337             'type': run_config_lib.TaskType.EVALUATOR,
    338             'index': 1
    339         }
    340     }
    341 
    342     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    343     mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    344     mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    345     mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    346 
    347     with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):
    348       training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
    349 
    350   def test_invalid_estimator(self):
    351     invalid_estimator = object()
    352     mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    353     mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    354 
    355     with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
    356       training.train_and_evaluate(invalid_estimator, mock_train_spec,
    357                                   mock_eval_spec)
    358 
    359 
    360 class TrainingExecutorConstructorTest(test.TestCase):
    361   """Tests constructor of _TrainingExecutor."""
    362 
    363   def testRequiredArgumentsSet(self):
    364     estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    365     train_spec = training.TrainSpec(input_fn=lambda: 1)
    366     eval_spec = training.EvalSpec(input_fn=lambda: 1)
    367 
    368     executor = training._TrainingExecutor(estimator, train_spec, eval_spec)
    369     self.assertEqual(estimator, executor.estimator)
    370 
    371   def test_invalid_estimator(self):
    372     invalid_estimator = object()
    373     train_spec = training.TrainSpec(input_fn=lambda: 1)
    374     eval_spec = training.EvalSpec(input_fn=lambda: 1)
    375 
    376     with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
    377       training._TrainingExecutor(invalid_estimator, train_spec, eval_spec)
    378 
    379   def test_invalid_train_spec(self):
    380     estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    381     invalid_train_spec = object()
    382     eval_spec = training.EvalSpec(input_fn=lambda: 1)
    383 
    384     with self.assertRaisesRegexp(TypeError, _INVALID_TRAIN_SPEC_MSG):
    385       training._TrainingExecutor(estimator, invalid_train_spec, eval_spec)
    386 
    387   def test_invalid_eval_spec(self):
    388     estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    389     train_spec = training.TrainSpec(input_fn=lambda: 1)
    390     invalid_eval_spec = object()
    391 
    392     with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
    393       training._TrainingExecutor(estimator, train_spec, invalid_eval_spec)
    394 
    395   def test_invalid_train_hooks(self):
    396     estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    397     train_spec = training.TrainSpec(input_fn=lambda: 1)
    398     eval_spec = training.EvalSpec(input_fn=lambda: 1)
    399     invalid_train_hooks = [object()]
    400 
    401     with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
    402       training._TrainingExecutor(
    403           estimator, train_spec, eval_spec, train_hooks=invalid_train_hooks)
    404 
    405   def test_invalid_continuous_eval_listener(self):
    406     estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    407     train_spec = training.TrainSpec(input_fn=lambda: 1)
    408     eval_spec = training.EvalSpec(input_fn=lambda: 1)
    409     invalid_continuous_eval_listener = object()
    410 
    411     with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_LISTENER_MSG):
    412       training._TrainingExecutor(
    413           estimator,
    414           train_spec,
    415           eval_spec,
    416           continuous_eval_listener=invalid_continuous_eval_listener)
    417 
    418 
class _TrainingExecutorTrainingTest(object):
  """Tests training of _TrainingExecutor.

  This is a mixin (not a TestCase itself): concrete subclasses below pair it
  with `test.TestCase` and supply a `run_config` whose `task_type` selects
  which `run_<task_type>` executor method `_run_task` invokes.
  """

  def __init__(self, run_config):
    # RunConfig describing the task (chief/worker/...) being simulated.
    self._run_config = run_config

  def _run_task(self, executor):
    # We should not call executor.run as the test here is intended to test
    # run_foo explicitly (foo is the task type).
    return getattr(executor, 'run_' + self._run_config.task_type)()

  # NOTE: decorators apply bottom-up, so the bottom-most patch (Server) binds
  # to the first mock argument and `sleep` to the second.
  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    """Training starts the std server and calls train() with spec values."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
    self._run_task(executor)

    # Server must be created (not auto-started) with the cluster info from
    # the estimator's RunConfig ...
    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=test.mock.ANY,
        start=False)

    # ... and then started explicitly.
    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=test.mock.ANY)
    # A training-only task must never evaluate or export.
    mock_est.evaluate.assert_not_called()
    mock_est.export_savedmodel.assert_not_called()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep):
    """Executor-level train_hooks are appended to the spec's hooks."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
    self._run_task(executor)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks) + extra_hooks,
        saving_listeners=test.mock.ANY)

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
    """With TF_CONFIG environment 'google', no std server is created."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
    with test.mock.patch.dict('os.environ', tf_config):
      self._run_task(executor)
      mock_server.assert_not_called()

  def test_fail_with_empty_cluster_spec(self):
    """A missing cluster_spec makes std-server startup fail."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = None
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  def test_fail_with_empty_master(self):
    """An empty master address makes std-server startup fail."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    # Multi-node cluster, so the empty master is a genuine error here
    # (contrast with the single-node test below).
    mock_est.config.cluster_spec = server_lib.ClusterSpec(
        {'worker': ['dummy', 'dummy1']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_single_worker_node_with_empty_tf_master(
      self, mock_server, unused_mock_sleep):
    """A single-node cluster with empty master trains without a server."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    # Single node cluster.
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                              mock_eval_spec))
    self.assertTrue(mock_est.train.called)
    mock_server.assert_not_called()

  def test_fail_with_empty_task_type(self):
    """An empty task_type makes std-server startup fail."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  def test_fail_with_none_task_id(self):
    """A missing (None) task_id makes std-server startup fail."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))
    579 
    580 
    581 class TrainingExecutorRunWorkerTest(_TrainingExecutorTrainingTest,
    582                                     test.TestCase):
    583   """Tests run_worker of _TrainingExecutor."""
    584 
    585   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    586     test.TestCase.__init__(self, methodName)
    587     _TrainingExecutorTrainingTest.__init__(
    588         self,
    589         run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER))
    590 
    591   @test.mock.patch.object(server_lib, 'Server')
    592   def test_delay_for_worker(self, _):
    593     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    594     mock_est.config = self._run_config
    595     mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    596     mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    597 
    598     executor = training._TrainingExecutor(mock_est, mock_train_spec,
    599                                           mock_eval_spec)
    600 
    601     expected_secs = (self._run_config.task_id + 1) * _DELAY_SECS_PER_WORKER
    602     with test.mock.patch.object(time, 'sleep') as mock_sleep:
    603       mock_sleep.side_effect = lambda s: self.assertEqual(expected_secs, s)
    604       self._run_task(executor)
    605       self.assertTrue(mock_sleep.called)
    606 
    607 
    608 class TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest,
    609                                    test.TestCase):
    610   """Tests run_chief of _TrainingExecutor."""
    611 
    612   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    613     test.TestCase.__init__(self, methodName)
    614     _TrainingExecutorTrainingTest.__init__(
    615         self,
    616         run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF))
    617 
    618   @test.mock.patch.object(server_lib, 'Server')
    619   def test_no_delay_for_chief(self, _):
    620     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    621     mock_est.config = self._run_config
    622     mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    623     mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    624 
    625     executor = training._TrainingExecutor(mock_est, mock_train_spec,
    626                                           mock_eval_spec)
    627 
    628     with test.mock.patch.object(time, 'sleep') as mock_sleep:
    629       self._run_task(executor)
    630       mock_sleep.assert_not_called()
    631 
    632 
class TrainingExecutorRunMasterTest(test.TestCase):
  """Tests run_master of _TrainingExecutor."""

  def setUp(self):
    # All tests in this class run with a 'master' task in the cluster spec
    # unless a test overrides mock_est.config explicitly.
    self._run_config = _create_run_config_with_cluster_spec(
        _TF_CONFIG_FOR_MASTER)

  @test.mock.patch.object(server_lib, 'Server')
  def test_no_delay_for_master(self, _):
    """Master should start training immediately, without a startup delay."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    # Return max_steps as the global step so the executor's eval loop ends.
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    # Any call to time.sleep would indicate an (unwanted) startup delay.
    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      executor.run_master()
      mock_sleep.assert_not_called()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    """run_master starts a std server and trains with the TrainSpec args."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
    executor.run_master()

    # The server must be built from the run config and with start=False so
    # the executor controls when it starts.
    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=test.mock.ANY,
        start=False)

    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=test.mock.ANY)
    # Export happens through exporters (none here), never directly.
    mock_est.export_savedmodel.assert_not_called()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_hooks(self, mock_server, unused_mock_sleep):
    """Extra executor-level train_hooks are appended to the spec's hooks."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
    executor.run_master()

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks) + extra_hooks,
        saving_listeners=test.mock.ANY)

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
    """No std server is started when TF_CONFIG marks a Google environment."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
    with test.mock.patch.dict('os.environ', tf_config):
      executor.run_master()
      mock_server.assert_not_called()

  def test_fail_with_empty_cluster_spec(self):
    """A std server cannot be created without a cluster spec."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = None
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  def test_fail_with_empty_master(self):
    """A std server cannot be created without a master address."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    # Multi-node cluster, so the empty master address cannot be excused as
    # the single-node local case.
    mock_est.config.cluster_spec = server_lib.ClusterSpec(
        {'master': ['dummy'], 'worker': ['dummy1']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 0

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_single_master_node_with_empty_tf_master(
      self, mock_server, unused_mock_sleep):
    """A single-node master with an empty tf master runs without a server."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}

    mock_train_spec = test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    # Only one node in the cluster: local training, no server needed.
    mock_est.config.cluster_spec = server_lib.ClusterSpec(
        {'master': ['dummy']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 0

    executor = training._TrainingExecutor(
        mock_est, mock_train_spec, mock_eval_spec)
    executor.run_master()

    mock_server.assert_not_called()
    self.assertTrue(mock_est.train.called)

  def test_fail_with_empty_task_type(self):
    """A std server cannot be created without a task type."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'master': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  def test_fail_with_none_task_id(self):
    """A std server cannot be created without a task id."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'master': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_triggers_evaluate_and_export(self, _):
    """Saving a checkpoint during train triggers evaluate and export."""

    def estimator_train(saving_listeners, *args, **kwargs):
      # run_master must register a saving_listener; Estimator is going to
      # call `after_save` on it after each checkpoint save.
      del args, kwargs
      saving_listeners[0].begin()
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est = test.mock.Mock(
        spec=estimator_lib.Estimator, model_dir='path/', train=estimator_train)
    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter)
    # Reporting max_steps as the global step marks training as finished.
    eval_result = {_GLOBAL_STEP_KEY: train_spec.max_steps}
    mock_est.evaluate.return_value = eval_result

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    mock_est.evaluate.assert_called_with(
        name=eval_spec.name,
        input_fn=eval_spec.input_fn,
        steps=eval_spec.steps,
        checkpoint_path='checkpoint_path/',
        hooks=eval_spec.hooks)
    self.assertEqual(1, exporter.export.call_count)
    exporter.export.assert_called_with(
        estimator=mock_est,
        export_path=os.path.join('path/', 'export', exporter.name),
        checkpoint_path='checkpoint_path/',
        eval_result=eval_result,
        is_the_final_export=True)

  @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_throttle_eval(self, _, mock_timer_class):
    """throttle_secs skips an eval when the timer says not to trigger."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')

    mock_timer = test.mock.Mock()
    mock_timer_class.return_value = mock_timer

    def estimator_train(saving_listeners, *args, **kwargs):
      del args, kwargs
      saving_listeners[0].begin()

      # Call three times; the middle one is throttled away by the timer.
      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est.train = estimator_train
    # One checkpoint path per un-throttled eval.
    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)

    # First eval is mid-training; second reports training as complete.
    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps //2},
        {_GLOBAL_STEP_KEY: train_spec.max_steps}
    ]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, exporter.export.call_count)

    # Only the very last export may be flagged as the final one.
    is_final_export_list = [call[1]['is_the_final_export']
                            for call in exporter.export.call_args_list]
    self.assertEqual([False, True], is_final_export_list)

  @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_throttle_eval_which_skips_final_ckpt(
      self, _, mock_timer_class):
    """A throttled final checkpoint is still evaluated and exported."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')

    mock_timer = test.mock.Mock()
    mock_timer_class.return_value = mock_timer

    def estimator_train(saving_listeners, *args, **kwargs):
      del args, kwargs
      saving_listeners[0].begin()

      # Call two times.
      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

      # The final ckpt is skipped by the timer. It will be picked up by the
      # final export check in the code.
      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est.train = estimator_train
    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)

    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps //2},
        {_GLOBAL_STEP_KEY: train_spec.max_steps}
    ]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, exporter.export.call_count)

    is_final_export_list = [call[1]['is_the_final_export']
                            for call in exporter.export.call_args_list]
    self.assertEqual([False, True], is_final_export_list)
    951 
class TrainingExecutorRunEvaluatorTest(test.TestCase):
  """Tests run_evaluator of _TrainingExecutor."""

  def _set_up_mock_est_to_train_and_evaluate_once(self, mock_est,
                                                  mock_train_spec):
    """Sets global step in eval result to end the while True eval loop."""
    training_max_step = 200
    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: training_max_step}
    mock_train_spec.max_steps = training_max_step

  def test_evaluate_with_evaluate_spec(self):
    """run_evaluator evaluates with the EvalSpec args and never trains."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.latest_checkpoint.return_value = 'latest_it_is'
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='cont_eval',
        start_delay_secs=0, throttle_secs=0)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    mock_est.evaluate.assert_called_with(
        name='cont_eval',
        input_fn=eval_spec.input_fn,
        steps=eval_spec.steps,
        checkpoint_path='latest_it_is',
        hooks=eval_spec.hooks)
    self.assertFalse(mock_est.train.called)

  def test_evaluate_with_train_hooks(self):
    """train_hooks passed to the executor must not run during eval."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.latest_checkpoint.return_value = 'latest_it_is'
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        hooks=[_FakeHook()],
        name='cont_eval',
        start_delay_secs=0,
        throttle_secs=0)

    # The train_hooks will not be called during eval.
    mock_hook = test.mock.Mock(spec=session_run_hook.SessionRunHook)
    executor = training._TrainingExecutor(
        mock_est, mock_train_spec, eval_spec, train_hooks=[mock_hook])
    executor.run_evaluator()

    mock_hook.begin.assert_not_called()

  def test_evaluate_multiple_times(self):
    """Continuous eval runs per new checkpoint; only the last is final."""
    training_max_step = 200

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    # First eval is mid-training; second reports training as complete,
    # ending the continuous-eval loop.
    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: training_max_step // 2},
        {_GLOBAL_STEP_KEY: training_max_step}
    ]
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_how_many_times_export_is_called'

    # Counters are stored on the mock estimator so the export closure can
    # update them via its `estimator` argument.
    mock_est.times_export_was_called = 0
    mock_est.times_final_export_was_called = 0
    def export(estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
      del export_path, checkpoint_path, eval_result
      estimator.times_export_was_called += 1
      # The final export happens at the end, after all other exports.
      self.assertEqual(0, estimator.times_final_export_was_called)
      if is_the_final_export:
        estimator.times_final_export_was_called += 1

    exporter.export = export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        start_delay_secs=0,
        throttle_secs=0,
        exporters=exporter)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, mock_est.times_export_was_called)
    self.assertEqual(1, mock_est.times_final_export_was_called)

  def test_evaluate_listener_before_eval(self):
    """A listener returning False from before_eval stops the eval loop."""
    training_max_step = 200

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    # Without early stopping, this eval will be run twice.
    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: training_max_step // 2
    }, {
        _GLOBAL_STEP_KEY: training_max_step
    }]
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_train_spec.max_steps = training_max_step

    class _Listener(training._ContinuousEvalListener):
      # Allows only the first eval; the second before_eval returns False.

      def __init__(self):
        self.call_count = 0

      def before_eval(self):
        self.call_count += 1
        return  self.call_count == 1

    listener = _Listener()

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)

    training._TrainingExecutor(
        mock_est, mock_train_spec, eval_spec,
        continuous_eval_listener=listener).run_evaluator()

    # Before_eval returns False during the second time, so, evaluate will be
    # called once.
    self.assertEqual(1, mock_est.evaluate.call_count)
    self.assertEqual(2, listener.call_count)

  def test_evaluate_listener_after_eval(self):
    """A listener returning False from after_eval stops the eval loop."""
    training_max_step = 200

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    # Without early stopping, this eval will be run twice.
    expected_eval_metrics = [{
        _GLOBAL_STEP_KEY: training_max_step // 2
    }, {
        _GLOBAL_STEP_KEY: training_max_step
    }]
    mock_est.evaluate.side_effect = expected_eval_metrics
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    class _Listener(training._ContinuousEvalListener):
      # Records the eval result and stops the loop after the first eval.

      def __init__(self):
        self.call_count = 0

      def after_eval(self, eval_result):
        self.call_count += 1
        self.eval_result = eval_result
        return False

    listener = _Listener()

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)

    training._TrainingExecutor(
        mock_est, mock_train_spec, eval_spec,
        continuous_eval_listener=listener).run_evaluator()

    # after_eval returns False during the first time, so, evaluate will be
    # called once.
    self.assertEqual(1, mock_est.evaluate.call_count)
    self.assertEqual(1, listener.call_count)
    self.assertAllEqual(expected_eval_metrics[0], listener.eval_result.metrics)
    self.assertEqual('path_1', listener.eval_result.checkpoint_path)

  def test_final_export_is_true_in_the_end(self):
    """Exactly one export call carries is_the_final_export=True."""
    training_max_step = 200

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: training_max_step // 2},
        {_GLOBAL_STEP_KEY: training_max_step}
    ]
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    # Counters live on the mock estimator so the export closure can reach
    # them through its `estimator` argument.
    mock_est.times_export_fn_was_called = 0
    mock_est.times_the_final_export_was_true = 0
    def export(estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
      del export_path, checkpoint_path, eval_result
      estimator.times_export_fn_was_called += 1
      if is_the_final_export:
        estimator.times_the_final_export_was_true += 1

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_how_many_times_export_is_called'
    exporter.export = export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        start_delay_secs=0,
        throttle_secs=0,
        exporters=exporter)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, mock_est.times_export_fn_was_called)
    self.assertEqual(1, mock_est.times_the_final_export_was_true)

  def test_skip_evaluation_due_to_ckpt(self):
    """Missing or unchanged checkpoints skip the eval and log a warning."""
    training_max_step = 200
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: training_max_step // 2},
        {_GLOBAL_STEP_KEY: training_max_step}
    ]
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    # First two items are invalid, next two items are same.
    mock_est.latest_checkpoint.side_effect = [
        None, '', 'same', 'same', 'path_2'
    ]

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    with test.mock.patch.object(logging, 'warning') as mock_log:
      executor.run_evaluator()

    # Three checkpoint paths are invalid.
    self.assertEqual(5, mock_est.latest_checkpoint.call_count)
    self.assertEqual(2, mock_est.evaluate.call_count)

    # Two warning logs are expected (last warning time is reset after a
    # successful evaluation)
    self.assertEqual(2, mock_log.call_count)

  def test_continuous_eval_listener_eval_result(self):
    """after_eval receives an _EvalResult describing each loop iteration."""
    training_max_step = 200
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    expected_eval_metrics = [{
        _GLOBAL_STEP_KEY: training_max_step // 2
    }, {
        _GLOBAL_STEP_KEY: training_max_step
    }]
    mock_est.evaluate.side_effect = expected_eval_metrics
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    class _Listener(training._ContinuousEvalListener):
      # Collects every _EvalResult without ever stopping the loop.

      def __init__(self):
        self.eval_results = []

      def after_eval(self, eval_result):
        self.eval_results.append(eval_result)
        return True

    continuous_eval_listener = _Listener()

    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    # First two items are invalid, next two items are same.
    mock_est.latest_checkpoint.side_effect = [
        None, '', 'same', 'same', 'path_2'
    ]
    # One expected _EvalResult per latest_checkpoint value above.
    expected_eval_results = [
        training._EvalResult(training._EvalStatus.MISSING_CHECKPOINT),
        training._EvalResult(training._EvalStatus.MISSING_CHECKPOINT),
        training._EvalResult(
            training._EvalStatus.EVALUATED,
            metrics=expected_eval_metrics[0],
            checkpoint_path='same'),
        training._EvalResult(training._EvalStatus.NO_NEW_CHECKPOINT),
        training._EvalResult(
            training._EvalStatus.EVALUATED,
            metrics=expected_eval_metrics[1],
            checkpoint_path='path_2'),
    ]

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)

    executor = training._TrainingExecutor(
        mock_est,
        mock_train_spec,
        eval_spec,
        continuous_eval_listener=continuous_eval_listener)
    executor.run_evaluator()

    # Three checkpoint paths are invalid.
    self.assertEqual(5, mock_est.latest_checkpoint.call_count)
    self.assertEqual(2, mock_est.evaluate.call_count)

    self.assertEqual(5, len(continuous_eval_listener.eval_results))
    for i, result in enumerate(continuous_eval_listener.eval_results):
      self.assertEqual(expected_eval_results[i].status, result.status)
      self.assertAllEqual(expected_eval_results[i].metrics, result.metrics)
      self.assertEqual(expected_eval_results[i].checkpoint_path,
                       result.checkpoint_path)

  def test_sleep_start_delay_secs(self):
    """The evaluator sleeps start_delay_secs before the first eval."""
    training_max_step = 200
    start_delay_secs = 123

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: training_max_step}
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='cont_eval',
        start_delay_secs=start_delay_secs, throttle_secs=0)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      executor.run_evaluator()
      mock_sleep.assert_called_with(start_delay_secs)
      self.assertTrue(mock_est.evaluate.called)

  @test.mock.patch.object(time, 'time')
  @test.mock.patch.object(time, 'sleep')
  def test_throttle_secs(self, mock_sleep, mock_time):
    """The evaluator sleeps for the throttle time minus the eval time."""
    throttle_secs = 123
    operation_secs = 12

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=throttle_secs)

    # Two time.time() readings: before and after the (mocked) evaluation.
    mock_time.side_effect = [921, 921 + operation_secs]

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    # Disable logging as it calls time.time also.
    with test.mock.patch.object(logging, 'info'):
      executor.run_evaluator()
    mock_sleep.assert_called_with(throttle_secs - operation_secs)
    self.assertTrue(mock_est.evaluate.called)

  def test_that_export_is_called(self):
    """The exporter is invoked with the estimator under evaluation."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    def export(estimator, *args, **kwargs):
      del args, kwargs
      estimator.export_was_called = True

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'
    exporter.export = export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        start_delay_secs=0,
        throttle_secs=0,
        exporters=exporter)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    # Verify that export was called on the right estimator.
    self.assertTrue(mock_est.export_was_called)

  def test_errors_out_if_evaluate_returns_empty_dict(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1),
                                  start_delay_secs=0, throttle_secs=0)
    mock_est.evaluate.return_value = {}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):
      executor.run_evaluator()

  def test_errors_out_if_evaluate_returns_non_dict(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1),
                                  start_delay_secs=0, throttle_secs=0)
    mock_est.evaluate.return_value = 123

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
      executor.run_evaluator()

  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1),
                                  start_delay_secs=0, throttle_secs=0)
    mock_est.evaluate.return_value = {'loss': 123}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError,
                                 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
      executor.run_evaluator()
   1368 
   1369 class TrainingExecutorRunPsTest(test.TestCase):
   1370   """Tests run_ps of _TrainingExecutor."""
   1371 
   1372   @test.mock.patch.object(server_lib, 'Server')
   1373   def test_std_server(self, mock_server):
   1374     mock_server_instance = test.mock.Mock()
   1375     mock_server.return_value = mock_server_instance
   1376 
   1377     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
   1378     mock_est.config = _create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS)
   1379     mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
   1380     mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
   1381 
   1382     executor = training._TrainingExecutor(mock_est, mock_train_spec,
   1383                                           mock_eval_spec)
   1384     executor.run_ps()
   1385 
   1386     mock_server.assert_called_with(
   1387         mock_est.config.cluster_spec,
   1388         job_name=mock_est.config.task_type,
   1389         task_index=mock_est.config.task_id,
   1390         config=test.mock.ANY,
   1391         start=False)
   1392 
   1393     self.assertTrue(mock_server_instance.start.called)
   1394     self.assertTrue(mock_server_instance.join.called)
   1395 
   1396   def test_fail_with_empty_cluster_spec(self):
   1397     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
   1398     mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
   1399     mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
   1400 
   1401     mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
   1402     mock_est.config.cluster_spec = None
   1403     mock_est.config.master = 'grpc://...'
   1404     mock_est.config.task_type = 'ps'
   1405     mock_est.config.task_id = 2
   1406 
   1407     with self.assertRaisesRegexp(RuntimeError,
   1408                                  _INVALID_CONFIG_FOR_STD_SERVER_MSG):
   1409       training._TrainingExecutor(mock_est, mock_train_spec,
   1410                                  mock_eval_spec).run_ps()
   1411 
   1412   def test_fail_with_empty_master(self):
   1413     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
   1414     mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
   1415     mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
   1416 
   1417     mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
   1418     mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']})
   1419     mock_est.config.master = ''
   1420     mock_est.config.task_type = 'ps'
   1421     mock_est.config.task_id = 2
   1422 
   1423     with self.assertRaisesRegexp(RuntimeError,
   1424                                  _INVALID_CONFIG_FOR_STD_SERVER_MSG):
   1425       training._TrainingExecutor(mock_est, mock_train_spec,
   1426                                  mock_eval_spec).run_ps()
   1427 
   1428   def test_fail_with_empty_task_type(self):
   1429     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
   1430     mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
   1431     mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
   1432 
   1433     mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
   1434     mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']})
   1435     mock_est.config.master = 'grpc://...'
   1436     mock_est.config.task_type = ''
   1437     mock_est.config.task_id = 2
   1438 
   1439     with self.assertRaisesRegexp(RuntimeError,
   1440                                  _INVALID_CONFIG_FOR_STD_SERVER_MSG):
   1441       training._TrainingExecutor(mock_est, mock_train_spec,
   1442                                  mock_eval_spec).run_ps()
   1443 
   1444   def test_fail_with_none_task_id(self):
   1445     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
   1446     mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
   1447     mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
   1448 
   1449     mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
   1450     mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']})
   1451     mock_est.config.master = 'grpc://...'
   1452     mock_est.config.task_type = 'ps'
   1453     mock_est.config.task_id = None
   1454 
   1455     with self.assertRaisesRegexp(RuntimeError,
   1456                                  _INVALID_CONFIG_FOR_STD_SERVER_MSG):
   1457       training._TrainingExecutor(mock_est, mock_train_spec,
   1458                                  mock_eval_spec).run_ps()
   1459 
   1460 
   1461 class StopAtSecsHookTest(test.TestCase):
   1462   """Tests StopAtSecsHook."""
   1463 
   1464   @test.mock.patch.object(time, 'time')
   1465   def test_stops_after_time(self, mock_time):
   1466     mock_time.return_value = 1484695987.209386
   1467     hook = training._StopAtSecsHook(1000)
   1468     with ops.Graph().as_default():
   1469       no_op = control_flow_ops.no_op()
   1470       # some time passed before training starts
   1471       mock_time.return_value += 250
   1472       with monitored_session.MonitoredSession(hooks=[hook]) as sess:
   1473         self.assertFalse(sess.should_stop())
   1474         sess.run(no_op)
   1475         self.assertFalse(sess.should_stop())
   1476         mock_time.return_value += 500
   1477         sess.run(no_op)
   1478         self.assertFalse(sess.should_stop())
   1479         mock_time.return_value += 400
   1480         sess.run(no_op)
   1481         self.assertFalse(sess.should_stop())
   1482         mock_time.return_value += 200
   1483         sess.run(no_op)
   1484         self.assertTrue(sess.should_stop())
   1485 
   1486 
class TrainingExecutorRunLocalTest(test.TestCase):
  """Tests run_local of _TrainingExecutor."""

  def unique_checkpoint_every_time_fn(self):
    # Returns a fresh (random) checkpoint path on every call, so each
    # train/eval iteration of run_local sees a "new" checkpoint.
    return 'checkpoint_path_%s/' % random.random()

  def test_send_stop_at_secs_to_train(self):
    """run_local appends a _StopAtSecsHook configured from throttle_secs."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_local()

    # The stop hook is expected to be the last hook passed to train().
    stop_hook = mock_est.train.call_args[1]['hooks'][-1]
    self.assertIsInstance(stop_hook, training._StopAtSecsHook)
    self.assertEqual(eval_spec.throttle_secs, stop_hook._stop_after_secs)

  def test_runs_in_a_loop_until_max_steps(self):
    """run_local loops train/evaluate/export until max_steps is reached."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn

    # Counters stored on the mock estimator so the export callback below can
    # record how it was invoked.
    mock_est.times_export_was_called = 0
    mock_est.times_final_export_was_called = 0
    def export(estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
      del export_path, checkpoint_path, eval_result
      estimator.times_export_was_called += 1
      # The final export happens at the end (never before a non-final one).
      self.assertEqual(0, estimator.times_final_export_was_called)
      if is_the_final_export:
        estimator.times_final_export_was_called += 1

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_how_many_times_export_is_called'
    exporter.export = export

    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        hooks=[_FakeHook()],
        throttle_secs=100,
        exporters=exporter)
    # evaluate() should be called 3 times; only the last result reaches
    # max_steps, which terminates the loop.
    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: train_spec.max_steps - 100
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps - 50
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps
    }]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_local()

    self.assertEqual(3, mock_est.train.call_count)
    self.assertEqual(3, mock_est.evaluate.call_count)
    self.assertEqual(3, mock_est.times_export_was_called)
    self.assertEqual(1, mock_est.times_final_export_was_called)

  def test_handles_no_new_checkpoint_found(self):
    """run_local raises when training produces no new checkpoint."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    # latest_checkpoint always returns the same path, simulating a train()
    # that never writes a new checkpoint.
    mock_est.latest_checkpoint.return_value = (
        'no_new_checkpoints_after_the_first_train_step')
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
    # It was going to be called 3 times.
    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: train_spec.max_steps - 100
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps - 50
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps
    }]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG):
      executor.run_local()

  def test_final_export_is_true_in_the_end(self):
    """Exactly one export call receives is_the_final_export=True."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn

    # Counters stored on the mock estimator for the export callback below.
    mock_est.times_export_fn_was_called = 0
    mock_est.times_the_final_export_was_true = 0
    def export(estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
      del export_path, checkpoint_path, eval_result
      estimator.times_export_fn_was_called += 1
      if is_the_final_export:
        estimator.times_the_final_export_was_true += 1

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_how_many_times_export_is_called'
    exporter.export = export

    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        hooks=[_FakeHook()],
        throttle_secs=100,
        exporters=exporter)
    # evaluate() should be called 3 times; the last result hits max_steps.
    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: train_spec.max_steps - 100
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps - 50
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps
    }]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_local()

    self.assertEqual(3, mock_est.train.call_count)
    self.assertEqual(3, mock_est.evaluate.call_count)
    self.assertEqual(3, mock_est.times_export_fn_was_called)
    self.assertEqual(1, mock_est.times_the_final_export_was_true)

  def test_train_and_evaluate_args(self):
    """run_local forwards the spec fields to train() and evaluate()."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval')
    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_local()

    mock_est.evaluate.assert_called_with(
        name=eval_spec.name,
        input_fn=eval_spec.input_fn,
        steps=eval_spec.steps,
        checkpoint_path='checkpoint_path/',
        hooks=eval_spec.hooks)

    # train() receives the spec hooks plus an extra (stop) hook appended
    # at the end, hence the [:-1] slice.
    train_args = mock_est.train.call_args[1]
    self.assertEqual(list(train_spec.hooks), list(train_args['hooks'][:-1]))
    self.assertEqual(train_spec.input_fn, train_args['input_fn'])
    self.assertEqual(train_spec.max_steps, train_args['max_steps'])

  def test_train_hooks(self):
    """Extra train_hooks given to the executor are passed to train()."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(input_fn=lambda: 1, steps=2)
    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, eval_spec, train_hooks=extra_hooks)
    executor.run_local()

    # Compare everything except the stop hook run_local adds itself.
    train_args = mock_est.train.call_args[1]
    self.assertEqual(
        list(train_spec.hooks) + extra_hooks, [
            h for h in train_args['hooks']
            if not isinstance(h, training._StopAtSecsHook)
        ])

  def test_errors_out_if_throttle_secs_is_zero(self):
    """run_local rejects an EvalSpec with throttle_secs=0."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=0)

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError, 'throttle_secs'):
      executor.run_local()

  def test_that_export_is_called_with_run_local(self):
    """run_local invokes the exporter on the estimator under test."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = 200
    mock_est.evaluate.return_value = {
        _GLOBAL_STEP_KEY: mock_train_spec.max_steps
    }
    # _validate_hooks would have made sure that train_spec.hooks is [], when
    # None were passed.
    mock_train_spec.hooks = []

    def export(estimator, *args, **kwargs):
      del args, kwargs
      # Flag set on the estimator so the assertion below can verify export
      # was called with the right estimator instance.
      estimator.export_was_called = True

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'
    exporter.export = export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        start_delay_secs=0,
        throttle_secs=213,
        exporters=exporter)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_local()

    self.assertTrue(mock_est.export_was_called)

  def test_errors_out_if_evaluate_returns_empty_dict(self):
    """run_local raises ValueError when evaluate() returns an empty dict."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
    mock_est.evaluate.return_value = {}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):
      executor.run_local()

  def test_errors_out_if_evaluate_returns_non_dict(self):
    """run_local raises TypeError when evaluate() returns a non-dict."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
    mock_est.evaluate.return_value = 123

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
      executor.run_local()

  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
    """run_local raises ValueError when the eval result lacks global step."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
    mock_est.evaluate.return_value = {'loss': 123}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError,
                                 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
      executor.run_local()
   1728 
   1729 
class TrainAndEvaluateRunTest(test.TestCase):
  """Tests that _TrainingExecutor.run dispatches to the right task runner."""

  def _test_run_task_and_executor(self, run_config):
    """Builds an executor whose run_* methods record which task was run.

    Args:
      run_config: The RunConfig to attach to the mocked estimator.

    Returns:
      A _TrainingExecutor whose run_chief/master/ps/evaluator/worker/local
      methods are stubbed to record invocation in `executor.call_task`.
    """
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    # Records which task runner was invoked: {task_name: 1}.
    executor.call_task = {}

    def task_fn(name):

      def _fn():
        executor.call_task[name] = 1

      return _fn

    # Replace each task runner with a recording stub.
    executor.run_chief = task_fn('chief')
    executor.run_master = task_fn('master')
    executor.run_ps = task_fn('ps')
    executor.run_evaluator = task_fn('evaluator')
    executor.run_worker = task_fn('worker')
    executor.run_local = task_fn('local')
    return executor

  def test_run_chief(self):
    """A chief-task config dispatches to run_chief."""
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF))
    executor.run()
    self.assertEqual(1, executor.call_task['chief'])

  def test_run_worker(self):
    """A worker-task config dispatches to run_worker."""
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER))
    executor.run()
    self.assertEqual(1, executor.call_task['worker'])

  def test_run_ps(self):
    """A ps-task config dispatches to run_ps."""
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS))
    executor.run()
    self.assertEqual(1, executor.call_task['ps'])

  def test_run_evaluator(self):
    """An evaluator-task config dispatches to run_evaluator."""
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(
            _TF_CONFIG_FOR_EVALUATOR))
    executor.run()
    self.assertEqual(1, executor.call_task['evaluator'])

  def test_run_local(self):
    """A default (cluster-less) RunConfig dispatches to run_local."""
    executor = self._test_run_task_and_executor(
        run_config=run_config_lib.RunConfig())
    executor.run()
    self.assertEqual(1, executor.call_task['local'])

  def test_invalid_local_task(self):
    """'local' as a task type inside a cluster spec is rejected."""
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
            'local': ['hos1:1'],
        },
        'task': {
            'type': 'local',  # invalid task type.
            'index': 0
        }
    }
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_LOCAL_TASK_WITH_CLUSTER):
      executor.run()

  def test_unsupported_task_due_to_missing_run_task(self):
    """A task type with no matching run_<task> method is rejected."""
    unsupported_task = 'alloc'
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
            unsupported_task: ['hos1:1'],
        },
        'task': {
            'type': unsupported_task,
            'index': 0
        }
    }
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):
      executor.run()

  def test_unsupported_task_due_to_not_callable(self):
    """A run_<task> attribute that is not callable is rejected."""
    unsupported_task = 'alloc'
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
            unsupported_task: ['hos1:1'],
        },
        'task': {
            'type': unsupported_task,
            'index': 0
        }
    }
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    executor.run_alloc = 123  # not callable
    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):
      executor.run()

  def test_invalid_task_type(self):
    """An empty task type with a cluster spec raises an invalid-type error."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    # NOTE(review): this assignment is redundant; mock_est.config is
    # re-assigned just below.
    mock_est.config = test.mock.Mock()
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.Mock()
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'1': ['dummy']})
    mock_est.config.task_type = ''

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE):
      executor.run()
   1869 
   1870 
class TrainAndEvaluateIntegrationTest(test.TestCase):
  """End-to-end test of train_and_evaluate with a real DNNClassifier."""

  def setUp(self):
    # Fresh model directory per test; removed in tearDown.
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      shutil.rmtree(self._model_dir)

  def _as_label(self, data_in_float):
    """Rounds float data to the nearest integer label array (int64)."""
    return np.rint(data_in_float).astype(np.int64)

  def _get_exporter(self, name, fc):
    """Builds a LatestExporter parsing serialized Examples for columns fc."""
    feature_spec = feature_column.make_parse_example_spec(fc)
    serving_input_receiver_fn = (
        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))
    return exporter_lib.LatestExporter(
        name, serving_input_receiver_fn=serving_input_receiver_fn)

  def _extract_loss_and_global_step(self, event_folder):
    """Returns the loss and global step in last event."""
    event_paths = glob.glob(os.path.join(event_folder, 'events*'))

    loss = None
    global_step_count = None

    # Scan the newest event file for the 'loss' summary with the largest
    # global step.
    for e in summary_iterator.summary_iterator(event_paths[-1]):
      current_loss = None
      for v in e.summary.value:
        if v.tag == 'loss':
          current_loss = v.simple_value

      # If loss is not found, global step is meaningless.
      if current_loss is None:
        continue

      current_global_step = e.step
      if global_step_count is None or current_global_step > global_step_count:
        global_step_count = current_global_step
        loss = current_loss

    return (loss, global_step_count)

  def test_complete_flow_with_non_distributed_configuration(self):
    """Trains, evaluates, exports, and predicts with a local configuration."""
    n_classes = 3
    input_dimension = 2
    batch_size = 10

    eval_name = 'foo'
    exporter_name = 'saved_model_exporter'

    # max_steps should be larger than save_summary_steps
    max_steps = 10
    save_summary_steps = 2

    data = np.linspace(
        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
    x_data = data.reshape(batch_size, input_dimension)
    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))

    # learn y = x
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)

    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        num_epochs=1,
        shuffle=False)

    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        batch_size=batch_size,
        shuffle=False)

    feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))]

    est = dnn.DNNClassifier(
        hidden_units=(2, 2),
        feature_columns=feature_columns,
        n_classes=n_classes,
        config=run_config_lib.RunConfig(save_summary_steps=save_summary_steps),
        model_dir=self._model_dir)

    train_spec = training.TrainSpec(input_fn=train_input_fn,
                                    max_steps=max_steps)

    eval_spec = training.EvalSpec(
        name=eval_name, input_fn=eval_input_fn, steps=None,
        exporters=self._get_exporter(exporter_name, feature_columns),
        throttle_secs=2)

    training.train_and_evaluate(est, train_spec, eval_spec)

    # Make sure nothing is stuck in limbo.
    writer_cache.FileWriterCache.clear()

    # Examine the training events. Use a range to check global step to avoid
    # flakyness due to global step race condition.
    training_loss, training_global_step = self._extract_loss_and_global_step(
        est.model_dir)
    self.assertIsNotNone(training_loss)
    self.assertTrue(
        max_steps - save_summary_steps < training_global_step <= max_steps)

    # Examine the eval events. The global step should be accurate.
    eval_loss, eval_global_step = self._extract_loss_and_global_step(
        event_folder=os.path.join(est.model_dir, 'eval_' + eval_name))
    self.assertIsNotNone(eval_loss)
    self.assertEqual(max_steps, eval_global_step)

    # Examine the export folder.
    export_dir = os.path.join(os.path.join(est.model_dir, 'export'),
                              exporter_name)
    self.assertTrue(gfile.Exists(export_dir))

    # Examine the ckpt for predict.
    predicted_proba = np.array([
        x[prediction_keys.PredictionKeys.PROBABILITIES]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
   1999 
   2000 
if __name__ == '__main__':
  # Run all test cases in this module through the TensorFlow test runner.
  test.main()
   2003