#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Tests for TaskRunner and Experiment class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import tempfile
import time

from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn import run_config
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import dnn
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
from tensorflow.contrib.learn.python.learn.estimators import test_data
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
from tensorflow.python.util import tf_inspect


class SheepCounter(object):
  """To be patched in for the time module, replacing sleep() and time()."""

  def __init__(self):
    self._total_time = 0
    self._sleeptimes = []
    self._time_calls = 0

  def sleep(self, t):
    self._total_time += t
    self._sleeptimes += [t]

  def time(self):
    self._time_calls += 1
    return self._total_time

  @property
  def sleep_times(self):
    return self._sleeptimes

  @property
  def time_calls(self):
    return self._time_calls


class TestBaseEstimator(object):
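  """Fake estimator shared by the test doubles below.

  Records how many times fit/train, evaluate, and export_savedmodel are
  called so that tests can assert on the counts, and raises StopIteration
  once evaluate() has been called more than `max_evals` times.
  """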

  def __init__(self, config, max_evals, eval_dict):
    self.eval_count = 0
    self.fit_count = 0
    self._max_evals = max_evals
    self.export_count = 0
    self.monitors = []
    self.eval_hooks = []
    self._config = config or run_config.RunConfig()
    self._model_dir = tempfile.mkdtemp()
    self._eval_dict = eval_dict

  @property
  def model_dir(self):
    return self._model_dir

  @property
  def config(self):
    return self._config

  def evaluate(self, **kwargs):
    tf_logging.info('evaluate called with args: %s' % kwargs)
    if 'hooks' in kwargs:
      self.eval_hooks = kwargs['hooks']
    self.eval_count += 1
    if self.eval_count > self._max_evals:
      tf_logging.info('Ran %d evals. Done.' % self.eval_count)
      raise StopIteration()
    return self._eval_dict

  def fake_checkpoint(self):
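    """Writes a minimal checkpoint so callers find one in model_dir."""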
    save_path = os.path.join(self.model_dir, 'model.ckpt')
    with session.Session() as sess:
      var = variables.Variable(1.0, name='var0')
      save = saver.Saver({var.op.name: var})
      var.initializer.run()
      save.save(sess, save_path, global_step=0)

  def train(self, **kwargs):
    self.fake_checkpoint()
    tf_logging.info('fit called with args: %s' % kwargs)
    self.fit_count += 1

    return [(key, kwargs[key]) for key in sorted(kwargs.keys())]

  def export_savedmodel(self, export_dir_base, serving_input_fn, **kwargs):
    tf_logging.info('export_savedmodel called with args: %s, %s, %s' %
                    (export_dir_base, serving_input_fn, kwargs))
    self.export_count += 1
    return os.path.join(
        compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp'))


def _check_method_supports_args(method, kwargs):
  """Checks that the given method supports the given args."""
  supported_args = tuple(tf_inspect.getargspec(method).args)
  for kwarg in kwargs:
    if kwarg not in supported_args:
      raise ValueError(
          'Argument `{}` is not supported in method {}.'.format(kwarg, method))


class TestEstimator(
    TestBaseEstimator, evaluable.Evaluable, trainable.Trainable):
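  """Fake contrib-style estimator.

  Forwards calls to TestBaseEstimator while checking, via
  _check_method_supports_args, that only keyword arguments supported by the
  real contrib interfaces are passed in.
  """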

  def __init__(self, config=None, max_evals=5, eval_dict=None):
    super(TestEstimator, self).__init__(config, max_evals, eval_dict)
    tf_logging.info('Create Estimator')

  def evaluate(self, **kwargs):
    _check_method_supports_args(evaluable.Evaluable.evaluate, kwargs)
    return super(TestEstimator, self).evaluate(**kwargs)

  def fit(self, **kwargs):
    _check_method_supports_args(trainable.Trainable.fit, kwargs)
    if 'monitors' in kwargs:
      self.monitors = kwargs['monitors']
    return super(TestEstimator, self).train(**kwargs)

  def train(self, **kwargs):
    raise ValueError('`train` is not defined in Estimator.')

  def export_savedmodel(
      self, export_dir_base, serving_input_fn, **kwargs):
    _check_method_supports_args(
        estimator_lib.Estimator.export_savedmodel, kwargs)
    return super(TestEstimator, self).export_savedmodel(
        export_dir_base, serving_input_fn, **kwargs)


class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
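  """Fake core (tf.estimator) estimator.

  Like TestEstimator, but validates kwargs against the core Estimator
  interface, which exposes train()/hooks rather than fit()/monitors.
  """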

  def __init__(self, config=None, max_evals=5, eval_dict=None):
    super(TestCoreEstimator, self).__init__(config, max_evals, eval_dict)
    tf_logging.info('Create Core Estimator')

  def evaluate(self, **kwargs):
    _check_method_supports_args(core_estimator.Estimator.evaluate, kwargs)
    return super(TestCoreEstimator, self).evaluate(**kwargs)

  def train(self, **kwargs):
    _check_method_supports_args(core_estimator.Estimator.train, kwargs)
    if 'hooks' in kwargs:
      self.monitors = kwargs['hooks']
    return super(TestCoreEstimator, self).train(**kwargs)

  def export_savedmodel(
      self, export_dir_base, serving_input_receiver_fn, **kwargs):
    _check_method_supports_args(
        core_estimator.Estimator.export_savedmodel, kwargs)
    return super(TestCoreEstimator, self).export_savedmodel(
        export_dir_base, serving_input_receiver_fn, **kwargs)


class _NoopHook(session_run_hook.SessionRunHook):
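  """A hook that does nothing; used to verify hooks are passed through."""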
  pass


class ExperimentTest(test.TestCase):

  def _cluster_spec(self):
    return {
        run_config_lib.TaskType.PS: ['host1:2222', 'host2:2222'],
        run_config_lib.TaskType.WORKER:
            ['host3:2222', 'host4:2222', 'host5:2222']
    }

  def _estimators_for_tests(self, config=None, eval_dict=None):
    return [TestEstimator(config=config, eval_dict=eval_dict),
            TestCoreEstimator(config=config, eval_dict=eval_dict)]

  def test_eval_metrics_for_core_estimator(self):
    est = TestCoreEstimator()
    with self.assertRaisesRegexp(
        ValueError, '`eval_metrics` must be `None`'):
      experiment.Experiment(
          est,
          train_input_fn='train_input',
          train_steps='train_steps',
          eval_input_fn='eval_input',
          eval_metrics='eval_metrics')

  def test_default_output_alternative_key_core_estimator(self):
    est = TestCoreEstimator()
    export_strategy = saved_model_export_utils.make_export_strategy(
        est,
        default_output_alternative_key='export_key',
        exports_to_keep=None)
    ex = experiment.Experiment(
        est,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps=100,
        eval_steps=100,
        export_strategies=export_strategy)
    with self.assertRaisesRegexp(
        ValueError, 'default_output_alternative_key is not supported'):
      ex.train_and_evaluate()

  def test_train(self):
    for est in self._estimators_for_tests():
      if isinstance(est, core_estimator.Estimator):
        eval_metrics = None
        saving_listeners = 'saving_listeners'
      else:
        eval_metrics = 'eval_metrics'
        saving_listeners = None
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          train_steps='train_steps',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          saving_listeners=saving_listeners)
      fit_args = ex.train(delay_secs=0)
      self.assertEqual(1, est.fit_count)
      self.assertIn(('max_steps', 'train_steps'), fit_args)
      self.assertEqual(0, est.eval_count)

  def test_train_delay(self):
    for est in self._estimators_for_tests():
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input')
      for delay in [0, 1, 3]:
        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            ex.train(delay_secs=delay)
            self.assertAlmostEqual(delay, sheep.time(), delta=1e-4)

  def test_train_default_delay(self):
    for task_id in [0, 1, 3]:
      tf_config = {'task': {'index': task_id}}
      with test.mock.patch.dict('os.environ',
                                {'TF_CONFIG': json.dumps(tf_config)}):
        config = run_config.RunConfig()
      for est in self._estimators_for_tests(config):
        ex = experiment.Experiment(
            est, train_input_fn='train_input', eval_input_fn='eval_input')

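        # With no explicit delay_secs, the test expects Experiment to stagger
        # startup by five seconds per task index (task_id * 5 below).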
        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            ex.train()
            self.assertAlmostEqual(task_id * 5, sheep.time(), delta=1e-4)

  @test.mock.patch.object(server_lib, 'Server')
  def test_train_starts_server(self, mock_server):
    # Arrange.
    tf_config = {
        'cluster': self._cluster_spec(),
        'environment': run_config_lib.Environment.CLOUD,
        'task': {
            'type': run_config_lib.TaskType.WORKER,
            'index': 1
        }
    }
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': json.dumps(tf_config)}):
      config = run_config_lib.RunConfig(
          master='host4:2222', num_cores=15, gpu_memory_fraction=0.314)

    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input')

      # Act.
      # We want to make sure we discount the time it takes to start the server
      # in our accounting of the delay, so we set a small delay here.
      sheep = SheepCounter()
      with test.mock.patch.object(time, 'time', sheep.time):
        with test.mock.patch.object(time, 'sleep', sheep.sleep):
          ex.train(delay_secs=1)
          # Ensure that the delay accounts for the time taken to start the
          # server.
          self.assertAlmostEqual(1, sheep.time(), delta=1e-4)

      # Assert.
      expected_config_proto = config_pb2.ConfigProto()
      expected_config_proto.inter_op_parallelism_threads = 15
      expected_config_proto.intra_op_parallelism_threads = 15
      expected_config_proto.gpu_options.per_process_gpu_memory_fraction = 0.314
      mock_server.assert_called_with(
          config.cluster_spec,
          job_name=run_config_lib.TaskType.WORKER,
          task_index=1,
          config=expected_config_proto,
          start=False)
      mock_server.assert_has_calls([test.mock.call().start()])

  @test.mock.patch.object(server_lib, 'Server')
  def test_train_server_does_not_start_without_cluster_spec(self, mock_server):
    config = run_config_lib.RunConfig(master='host4:2222')
    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input')
      ex.train()

      # The server should not have started because there was no ClusterSpec.
      self.assertFalse(mock_server.called)

  @test.mock.patch.object(server_lib, 'Server')
  def test_train_server_does_not_start_with_empty_master(self, mock_server):
    tf_config = {'cluster': self._cluster_spec()}
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': json.dumps(tf_config)}):
      config = run_config_lib.RunConfig(master='')
    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input')
      ex.train()
      # The server should not have started because master was the empty string.
      self.assertFalse(mock_server.called)

  def test_train_raises_if_job_name_is_missing(self):
    tf_config = {
        'cluster': self._cluster_spec(),
        'environment': run_config_lib.Environment.CLOUD,
        'task': {
            'index': 1
        }
    }
    with test.mock.patch.dict(
        'os.environ',
        {'TF_CONFIG': json.dumps(tf_config)}), self.assertRaises(ValueError):
      config = run_config_lib.RunConfig(
          master='host3:2222'  # Normally selected by task type.
      )
      for est in self._estimators_for_tests(config):
        ex = experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input')
        ex.train()

  def test_evaluate(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_steps='steps',
          eval_delay_secs=0)
      ex.evaluate()
      self.assertEqual(0, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_evaluate_delay(self):
    for est in self._estimators_for_tests():
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input',
          eval_hooks=[noop_hook])

      for delay in [0, 1, 3]:
        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            ex.evaluate(delay_secs=delay)
        self.assertAlmostEqual(delay, sheep.time(), delta=1e-4)
        self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
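      # TestBaseEstimator raises StopIteration on the sixth evaluate() call
      # (max_evals defaults to 5), which is what terminates the loop here.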
      self.assertRaises(StopIteration, ex.continuous_eval,
                        evaluate_checkpoint_only_once=False)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(6, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval_ends_after_train_step(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0,
          train_steps=100)
      ex.continuous_eval()
      self.assertEqual(0, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval_throttle_delay(self):
    for delay in [0, 1, 2]:
      for est in self._estimators_for_tests():
        eval_metrics = 'eval_metrics' if not isinstance(
            est, core_estimator.Estimator) else None
        est.fake_checkpoint()
        noop_hook = _NoopHook()
        ex = experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            eval_metrics=eval_metrics,
            eval_hooks=[noop_hook],
            continuous_eval_throttle_secs=delay,
            eval_delay_secs=0)
        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            self.assertRaises(
                StopIteration,
                ex.continuous_eval,
                evaluate_checkpoint_only_once=False)
            self.assertAlmostEqual(5 * delay, sheep.time(), delta=1e-4)

  def test_continuous_eval_predicate_fn(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()

      def _predicate_fn(unused_eval_result):
        return est.eval_count < 3  # pylint: disable=cell-var-from-loop

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      ex.continuous_eval(evaluate_checkpoint_only_once=False,
                         continuous_eval_predicate_fn=_predicate_fn)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(3, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval_predicate_fn_with_checkpoint(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()

      def _predicate_fn(eval_result, checkpoint_path):
        self.assertEqual(eval_result is None,
                         checkpoint_path is None)
        return est.eval_count < 3  # pylint: disable=cell-var-from-loop

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      ex.continuous_eval(
          evaluate_checkpoint_only_once=False,
          continuous_eval_predicate_fn=_predicate_fn)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(3, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_run_local(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          local_eval_frequency=10)
      ex.local_run()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, len(est.monitors))
      self.assertEqual([noop_hook], est.eval_hooks)
      self.assertTrue(isinstance(est.monitors[0],
                                 session_run_hook.SessionRunHook))

  def test_train_hooks_extend_does_not_mutate_input_hooks(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      input_hooks = [noop_hook]

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          train_monitors=input_hooks)
      self.assertAllEqual([noop_hook], ex._train_monitors)

      another_noop_hook = _NoopHook()
      # Assert that the extend API mutates the experiment's hooks, but not
      # the input hooks.
      ex.extend_train_hooks([another_noop_hook])
      self.assertAllEqual([noop_hook, another_noop_hook], ex._train_monitors)
      self.assertAllEqual([noop_hook], input_hooks)

  def test_invalid_export_strategies(self):
    for est in self._estimators_for_tests():
      with self.assertRaisesRegexp(ValueError, 'ExportStrategy'):
        experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps=100,
            eval_steps=100,
            export_strategies='not_an_export_strategy')
      with self.assertRaisesRegexp(ValueError, 'ExportStrategy'):
        experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps=100,
            eval_steps=100,
            export_strategies=['not_an_export_strategy'])

  def test_export_strategies_reset(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      export_strategy_1 = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_1',
          exports_to_keep=None)

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          train_steps=100,
          eval_steps=100,
          export_strategies=(export_strategy_1,))
      ex.train_and_evaluate()
      self.assertEqual(1, est.export_count)

      # After a reset with no new strategies (None), the export count stays
      # the same, and the previously configured strategies are returned
      # intact.
      old_es = ex.reset_export_strategies()
      ex.train_and_evaluate()
      self.assertAllEqual([export_strategy_1], old_es)
      self.assertEqual(1, est.export_count)

      # After a reset with a new list, the export count should increase by
      # the number of strategies in that list.
      export_strategy_2 = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_2',
          exports_to_keep=None)
      export_strategy_3 = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_3',
          exports_to_keep=None)

      old_es = ex.reset_export_strategies(
          [export_strategy_2, export_strategy_3])
      ex.train_and_evaluate()
      self.assertAllEqual([], old_es)
      self.assertEqual(3, est.export_count)

  def test_train_and_evaluate(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      export_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          export_strategies=export_strategy)
      ex.train_and_evaluate()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, est.export_count)
      self.assertEqual(1, len(est.monitors))
      self.assertEqual([noop_hook], est.eval_hooks)
      self.assertTrue(isinstance(est.monitors[0],
                                 session_run_hook.SessionRunHook))

  def test_train_and_evaluate_with_no_eval_during_training(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          min_eval_frequency=0)
      ex.train_and_evaluate()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(0, len(est.monitors))

  def test_min_eval_frequency_defaults(self):
    def dummy_model_fn(features, labels):  # pylint: disable=unused-argument
      pass
    estimator = core_estimator.Estimator(dummy_model_fn, '/tmp/dummy')
    ex = experiment.Experiment(
        estimator, train_input_fn=None, eval_input_fn=None)
    self.assertEqual(ex._min_eval_frequency, 1)

  def test_continuous_train_and_eval(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      if isinstance(est, core_estimator.Estimator):
        eval_metrics = None
        saving_listeners = 'saving_listeners'
      else:
        eval_metrics = 'eval_metrics'
        saving_listeners = None
      noop_hook = _NoopHook()
      export_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          export_strategies=export_strategy,
          saving_listeners=saving_listeners)
      ex.continuous_train_and_eval()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, est.export_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_train_and_eval_with_predicate_fn(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      export_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          train_steps=100000000000,  # Large enough to never stop on its own.
          eval_steps=100,
          export_strategies=export_strategy)

      def predicate_fn(eval_result):
        del eval_result  # Unused; present only to match the fn signature.
        return False

      ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(0, est.eval_count)
      self.assertEqual(0, est.export_count)

  def test_continuous_train_and_eval_with_adapted_steps_per_iteration(self):
    mock_estimator = test.mock.Mock(core_estimator.Estimator)
    type(mock_estimator).model_dir = test.mock.PropertyMock(
        return_value='test_dir')

    total_steps = 100000000000000
    ex = experiment.Experiment(
        mock_estimator,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps=total_steps)

    def predicate_fn(eval_result):
      # Allow only the first invocation.
      return eval_result is None

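    # With train_steps set but train_steps_per_iteration unset, each
    # iteration is expected to train total_steps / 10 steps (asserted below).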
    ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
    mock_estimator.train.assert_called_once_with(
        input_fn='train_input',
        steps=int(total_steps / 10),
        max_steps=test.mock.ANY,
        hooks=test.mock.ANY,
        saving_listeners=test.mock.ANY)

  def test_continuous_train_and_eval_with_steps_per_iteration_from_user(self):
    mock_estimator = test.mock.Mock(core_estimator.Estimator)
    type(mock_estimator).model_dir = test.mock.PropertyMock(
        return_value='test_dir')

    total_steps = 100000000000000
    ex = experiment.Experiment(
        mock_estimator,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps_per_iteration=1234,
        train_steps=total_steps)

    def predicate_fn(eval_result):
      # Allow only the first invocation.
      return eval_result is None

    ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
    mock_estimator.train.assert_called_once_with(
        input_fn='train_input',
        steps=1234,
        max_steps=test.mock.ANY,
        hooks=test.mock.ANY,
        saving_listeners=test.mock.ANY)

  def test_continuous_train_and_eval_with_default_steps_per_iteration(self):
    mock_estimator = test.mock.Mock(core_estimator.Estimator)
    type(mock_estimator).model_dir = test.mock.PropertyMock(
        return_value='test_dir')

    ex = experiment.Experiment(
        mock_estimator,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps_per_iteration=None,
        train_steps=None)

    def predicate_fn(eval_result):
      # Allow only the first invocation.
      return eval_result is None

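    # With neither train_steps nor train_steps_per_iteration set, a default
    # of 1000 steps per iteration is expected (asserted below).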
    ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
    mock_estimator.train.assert_called_once_with(
        input_fn='train_input',
        steps=1000,
        max_steps=test.mock.ANY,
        hooks=test.mock.ANY,
        saving_listeners=test.mock.ANY)

  def test_continuous_train_and_eval_with_invalid_predicate_fn(self):
    for est in self._estimators_for_tests():
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input')
      with self.assertRaisesRegexp(
          ValueError, '`continuous_eval_predicate_fn` must be a callable'):
        ex.continuous_train_and_eval(continuous_eval_predicate_fn='fn')

  def test_continuous_train_and_eval_with_invalid_train_steps_iterations(self):
    for est in self._estimators_for_tests():
      with self.assertRaisesRegexp(
          ValueError, '`train_steps_per_iteration` must be an integer.'):
        experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps_per_iteration='123')

  @test.mock.patch.object(server_lib, 'Server')
  def test_run_std_server(self, mock_server):
    # Arrange.
    tf_config = {
        'cluster': self._cluster_spec(),
        'task': {
            'type': run_config_lib.TaskType.PS,
            'index': 1
        }
    }
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': json.dumps(tf_config)}):
      config = run_config_lib.RunConfig(
          master='host2:2222',
          num_cores=15,
          gpu_memory_fraction=0.314)
    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input')

      # Act.
      ex.run_std_server()

      # Assert.
      mock_server.assert_has_calls(
          [test.mock.call().start(), test.mock.call().join()])

  @test.mock.patch.object(server_lib, 'Server')
  def test_run_std_server_raises_without_cluster_spec(self, mock_server):
    config = run_config_lib.RunConfig(master='host4:2222')
    for est in self._estimators_for_tests(config):
      with self.assertRaises(ValueError):
        ex = experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input')
        ex.run_std_server()

  def test_test(self):
    for est in self._estimators_for_tests():
      exp_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      if isinstance(est, core_estimator.Estimator):
        eval_metrics = None
        saving_listeners = 'saving_listeners'
      else:
        eval_metrics = 'eval_metrics'
        saving_listeners = None
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          export_strategies=(exp_strategy,),
          eval_metrics=eval_metrics,
          saving_listeners=saving_listeners)
      ex.test()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, est.export_count)

  def test_continuous_eval_evaluates_checkpoint_once(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()

      result = {
          'called': 0,
          'called_with_eval_result': 0,
      }
      # pylint: disable=cell-var-from-loop
      def _predicate_fn(eval_result):
        result['called'] += 1
        if eval_result:
          # If eval_result is neither None nor empty, the checkpoint has been
          # evaluated.
          result['called_with_eval_result'] += 1
        # Looping 300 times gives ample evidence that the checkpoint is
        # evaluated only once.
        return result['called'] < 300
      # pylint: enable=cell-var-from-loop

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      ex.continuous_eval(evaluate_checkpoint_only_once=True,
                         continuous_eval_predicate_fn=_predicate_fn)

      self.assertEqual(0, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(300, result['called'])
      self.assertEqual(1, result['called_with_eval_result'])

  def test_checkpoint_and_export(self):
    model_dir = tempfile.mkdtemp()
    config = run_config_lib.RunConfig(save_checkpoints_steps=3)
    est = dnn.DNNClassifier(
        n_classes=3,
        feature_columns=[
            feature_column.real_valued_column('feature', dimension=4)
        ],
        hidden_units=[3, 3],
        model_dir=model_dir,
        config=config)

    exp_strategy = saved_model_export_utils.make_export_strategy(
        est, 'export_input', exports_to_keep=None)

    ex = experiment.Experiment(
        est,
        train_input_fn=test_data.iris_input_multiclass_fn,
        eval_input_fn=test_data.iris_input_multiclass_fn,
        export_strategies=(exp_strategy,),
        train_steps=8,
        checkpoint_and_export=True,
        eval_delay_secs=0)

    with test.mock.patch.object(ex, '_maybe_export'):
      with test.mock.patch.object(ex, '_call_evaluate'):
        ex.train_and_evaluate()
        # Eval and export are called after steps 1, 4, 7, and 8 (after training
        # is completed).
        self.assertEqual(ex._maybe_export.call_count, 4)
        self.assertEqual(ex._call_evaluate.call_count, 4)


if __name__ == '__main__':
  test.main()
    963