Home | History | Annotate | Download | only in learn
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Experiment class collecting information needed for a single training run."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import contextlib
     22 import functools
     23 import math
     24 import os
     25 import time
     26 
     27 from tensorflow.contrib.framework import deprecated
     28 from tensorflow.contrib.framework import deprecated_args
     29 from tensorflow.contrib.framework.python.framework import experimental
     30 from tensorflow.contrib.learn.python.learn import evaluable
     31 from tensorflow.contrib.learn.python.learn import export_strategy
     32 from tensorflow.contrib.learn.python.learn import monitors
     33 from tensorflow.contrib.learn.python.learn import trainable
     34 from tensorflow.contrib.learn.python.learn.estimators import run_config
     35 from tensorflow.contrib.tpu.python.tpu import tpu_estimator
     36 from tensorflow.python.estimator import estimator as core_estimator
     37 from tensorflow.python.estimator import util as estimator_util
     38 from tensorflow.python.framework import ops
     39 from tensorflow.python.platform import tf_logging as logging
     40 from tensorflow.python.training import basic_session_run_hooks
     41 from tensorflow.python.training import saver
     42 from tensorflow.python.training import server_lib
     43 from tensorflow.python.util import compat
     44 
     45 __all__ = ["Experiment"]
     46 
     47 
     48 def _get_standardized_predicate_fn(predicate_fn):
     49   pred_fn_args = estimator_util.fn_args(predicate_fn)
     50   if "checkpoint_path" not in pred_fn_args:
     51     # pylint: disable=unused-argument
     52     def _pred_fn_wrapper(eval_results, checkpoint_path):
     53       return predicate_fn(eval_results)
     54 
     55     return _pred_fn_wrapper
     56   else:
     57     return predicate_fn
     58 
     59 
     60 class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener):
     61   """Listener that evaluates and exports a model after creating a checkpoint.
     62 
     63   The `EvalAndExportListener` waits for the associated `CheckpointSaverHook`
     64   to save a checkpoint. It then uses the provided `eval_fn` and `export_fn` to
     65   first evaluate the model using the newly-created checkpoint, and then export
     66   the model according to the `export_strategies` provided in the `Experiment`.
     67 
     68   This listener is experimental and may be changed or removed in the future.
     69   """
     70 
     71   def __init__(self, eval_fn, export_fn, model_dir):
     72     """Initializes an `EvalAndExportListener`.
     73 
     74     Args:
     75       eval_fn: function which evaluates the model with the following signature:
     76         `(name, checkpoint_path) -> eval_result`
     77       export_fn: function which exports the model according to a set of export
     78         strategies. Has the following signature:
     79         `(eval_result, checkpoint_path) -> export_results`
     80       model_dir: directory which contains estimator parameters and checkpoints.
     81     """
     82     self._eval_fn = eval_fn
     83     self._export_fn = export_fn
     84     self._model_dir = model_dir
     85     self._latest_path = None
     86     self._eval_result = None
     87     self._export_results = None
     88 
     89   def after_save(self, session, global_step_value):
     90     """Evaluates and exports the model after a checkpoint is created."""
     91     # Load and cache the path of the most recent checkpoint to avoid duplicate
     92     # searches on GCS.
     93     logging.info("Checking for checkpoint in %s", self._model_dir)
     94     latest_path = saver.latest_checkpoint(self._model_dir)
     95 
     96     if not latest_path:
     97       logging.warning("Skipping evaluation and export since model has not been "
     98                       "saved yet.")
     99     elif latest_path == self._latest_path:
    100       logging.warning("Skipping evaluation due to same latest checkpoint %s.",
    101                       latest_path)
    102     else:
    103       self._latest_path = latest_path
    104       self._eval_result = self._eval_fn(
    105           name="intermediate_export", checkpoint_path=latest_path)
    106       self._export_results = self._export_fn(
    107           self._eval_result, checkpoint_path=latest_path)
    108 
    109   @property
    110   def eval_result(self):
    111     return self._eval_result
    112 
    113   @property
    114   def export_results(self):
    115     return self._export_results
    116 
    117 
    118 class Experiment(object):
    119   """Experiment is a class containing all information needed to train a model.
    120 
    121   After an experiment is created (by passing an Estimator and inputs for
    122   training and evaluation), an Experiment instance knows how to invoke training
    123   and eval loops in a sensible fashion for distributed training.
    124   """
    125 
    126   # TODO(ispir): remove delay_workers_by_global_step and make global step based
    127   # waiting as only behavior.
    128   @deprecated_args(
    129       "2016-10-23",
    130       "local_eval_frequency is deprecated as local_run will be renamed to "
    131       "train_and_evaluate. Use min_eval_frequency and call train_and_evaluate "
    132       "instead. Note, however, that the default for min_eval_frequency is 1, "
    133       "meaning models will be evaluated every time a new checkpoint is "
    134       "available. In contrast, the default for local_eval_frequency is None, "
    135       "resulting in evaluation occurring only after training has completed. "
    136       "min_eval_frequency is ignored when calling the deprecated local_run.",
    137       "local_eval_frequency")
    138   def __init__(self,
    139                estimator,
    140                train_input_fn,
    141                eval_input_fn,
    142                eval_metrics=None,
    143                train_steps=None,
    144                eval_steps=100,
    145                train_monitors=None,
    146                eval_hooks=None,
    147                local_eval_frequency=None,
    148                eval_delay_secs=120,
    149                continuous_eval_throttle_secs=60,
    150                min_eval_frequency=None,
    151                delay_workers_by_global_step=False,
    152                export_strategies=None,
    153                train_steps_per_iteration=None,
    154                checkpoint_and_export=False,
    155                saving_listeners=None):
    156     """Constructor for `Experiment`.
    157 
    158     Creates an Experiment instance. None of the functions passed to this
    159     constructor are executed at construction time. They are stored and used
    160     when a method is executed which requires it.
    161 
    162     Args:
    163       estimator: Object implementing Estimator interface, which could be a
    164         combination of @{tf.contrib.learn.Trainable} and
    165         @{tf.contrib.learn.Evaluable} (deprecated), or
    166         @{tf.estimator.Estimator}.
    167       train_input_fn: function, returns features and labels for training.
    168       eval_input_fn: function, returns features and labels for evaluation. If
    169         `eval_steps` is `None`, this should be configured only to produce for a
    170         finite number of batches (generally, 1 epoch over the evaluation data).
    171       eval_metrics: `dict` of string, metric function. If `None`, default set
    172         is used. This should be `None` if the `estimator` is
    173         @{tf.estimator.Estimator}. If metrics are provided they will be
    174         *appended* to the default set.
    175       train_steps: Perform this many steps of training. `None`, the default,
    176         means train forever.
    177       eval_steps: `evaluate` runs until input is exhausted (or another exception
    178         is raised), or for `eval_steps` steps, if specified.
    179       train_monitors: A list of monitors to pass to the `Estimator`'s `fit`
    180         function.
    181       eval_hooks: A list of `SessionRunHook` hooks to pass to the
    182         `Estimator`'s `evaluate` function.
    183       local_eval_frequency: (applies only to local_run) Frequency of running
    184         eval in steps. If `None`, runs evaluation only at the end of training.
    185       eval_delay_secs: Start evaluating after waiting for this many seconds.
    186       continuous_eval_throttle_secs: Do not re-evaluate unless the last
    187         evaluation was started at least this many seconds ago for
    188         continuous_eval().
    189       min_eval_frequency: (applies only to train_and_evaluate). the minimum
    190         number of steps between evaluations. Of course, evaluation does not
    191         occur if no new snapshot is available, hence, this is the minimum.
    192         If 0, the evaluation will only happen after training.
    193         If None, defaults to 1, unless model_dir is on GCS, in which case the
    194         default is 1000.
    195       delay_workers_by_global_step: if `True` delays training workers
    196         based on global step instead of time.
    197       export_strategies: Iterable of `ExportStrategy`s, or a single one, or
    198         `None`.
    199       train_steps_per_iteration: (applies only to continuous_train_and_eval).
    200         Perform this many (integer) number of train steps for each
    201         training-evaluation iteration. With a small value, the model will be
    202         evaluated more frequently with more checkpoints saved. If `None`, will
    203         use a default value (which is smaller than `train_steps` if provided).
    204       checkpoint_and_export: (applies only to train_and_evaluate). If `True`,
    205         performs intermediate model checkpoints and exports during the training
    206         process, rather than only once model training is complete. This
    207         parameter is experimental and may be changed or removed in the future.
    208         Setting this parameter leads to the following: the value of
    209         `min_eval_frequency` will be ignored, and the number of steps between
    210         evaluations and exports will instead be determined by the Estimator
    211         configuration parameters `save_checkpoints_secs` and
    212         `save_checkpoints_steps`. Also, this parameter leads to the creation of
    213         a default `CheckpointSaverHook` instead of a `ValidationMonitor`, so the
    214         provided `train_monitors` will need to be adjusted accordingly.
    215       saving_listeners: list of `CheckpointSaverListener` objects. Used by
    216         tf.estimator.Estimator for callbacks that run immediately before or
    217         after checkpoint savings.
    218 
    219     Raises:
    220       ValueError: if `estimator` does not implement Estimator interface,
    221         or if export_strategies has the wrong type.
    222     """
    223     if isinstance(estimator, core_estimator.Estimator):
    224       self._core_estimator_used = True
    225       if eval_metrics is not None:
    226         raise ValueError(
    227             "`eval_metrics` must be `None` with `tf.estimator.Estimator`. "
    228             "Use `eval_metric_ops` in `tf.estimator.EstimatorSpec` instead.")
    229     else:
    230       self._core_estimator_used = False
    231       if not isinstance(estimator, evaluable.Evaluable):
    232         raise ValueError(
    233             "`estimator` must implement `tf.contrib.learn.Evaluable` "
    234             "or `tf.estimator.Estimator`.")
    235       if not isinstance(estimator, trainable.Trainable):
    236         raise ValueError(
    237             "`estimator` must implement `tf.contrib.learn.Trainable`"
    238             "or `tf.estimator.`Estimator`.")
    239       if saving_listeners is not None:
    240         raise ValueError("`saving_listeners` must be `None` with "
    241                          "`tf.contrib.learn.Estimator`.")
    242 
    243     if isinstance(estimator, tpu_estimator.TPUEstimator):
    244       logging.warn(
    245           "`Experiment` class cannot work with `tf.contrib.tpu.TPUEstimator`. "
    246           "Please call `TPUEstimator` train/evaluate directly. \n"
    247           "Details: `Experiment` class is designed for between-graph "
    248           "distributed training, while `TPUEstimator` is working in in-graph "
    249           "distributed mode. Use with care.")
    250 
    251     super(Experiment, self).__init__()
    252     # Immutable fields.
    253     self._estimator = estimator
    254     self._train_input_fn = train_input_fn
    255     self._eval_input_fn = eval_input_fn
    256     self._eval_metrics = eval_metrics
    257     self._train_steps = train_steps
    258     self._eval_steps = eval_steps
    259     self._local_eval_frequency = local_eval_frequency
    260     self._eval_delay_secs = eval_delay_secs
    261     self._continuous_eval_throttle_secs = continuous_eval_throttle_secs
    262     self._checkpoint_and_export = checkpoint_and_export
    263     self._saving_listeners = saving_listeners
    264     # Using 1 on a non-cached file system requires a lot of overhead to
    265     # read the checkpoint state file. This is particular bad on GCS, so
    266     # we use a different default. This is a temporary band-aid, to be
    267     # fixed holistically later (b/36498507).
    268     default_min_eval_frequency = 1000 if _is_gcs(estimator.model_dir) else 1
    269     self._min_eval_frequency = min_eval_frequency if (
    270         min_eval_frequency is not None) else default_min_eval_frequency
    271     self._delay_workers_by_global_step = delay_workers_by_global_step
    272     self._train_monitors = train_monitors[:] if train_monitors else []
    273     self._eval_hooks = eval_hooks[:] if eval_hooks else []
    274     self._set_export_strategies(export_strategies)
    275 
    276     self._train_steps_per_iteration = train_steps_per_iteration
    277     if (self._train_steps_per_iteration is not None and
    278         not isinstance(self._train_steps_per_iteration, int)):
    279       raise ValueError("`train_steps_per_iteration` must be an integer.")
    280 
    281   @property
    282   def estimator(self):
    283     return self._estimator
    284 
    285   @property
    286   def eval_metrics(self):
    287     return self._eval_metrics
    288 
    289   @property
    290   def train_steps(self):
    291     return self._train_steps
    292 
    293   @property
    294   def eval_steps(self):
    295     return self._eval_steps
    296 
    297   def _set_export_strategies(self, values):  # pylint: disable=missing-docstring
    298     export_strategies = []
    299     if values:
    300       if isinstance(values, export_strategy.ExportStrategy):
    301         export_strategies.append(values)
    302       else:
    303         for value in values:
    304           if not isinstance(value, export_strategy.ExportStrategy):
    305             raise ValueError("`export_strategies` must be an ExportStrategy,"
    306                              " an iterable of ExportStrategy, or `None`,"
    307                              " found %s." % value)
    308           export_strategies.append(value)
    309     self._export_strategies = tuple(export_strategies)
    310 
    311   def extend_train_hooks(self, additional_hooks):
    312     """Extends the hooks for training."""
    313     self._train_monitors.extend(additional_hooks)
    314 
    315   def reset_export_strategies(self, new_export_strategies=None):
    316     """Resets the export strategies with the `new_export_strategies`.
    317 
    318     Args:
    319       new_export_strategies: A new list of `ExportStrategy`s, or a single one,
    320         or None.
    321 
    322     Returns:
    323       The old export strategies.
    324     """
    325     old_export_strategies = self._export_strategies
    326     self._set_export_strategies(new_export_strategies)
    327     return old_export_strategies
    328 
    329   def train(self, delay_secs=None):
    330     """Fit the estimator using the training data.
    331 
    332     Train the estimator for `self._train_steps` steps, after waiting for
    333     `delay_secs` seconds. If `self._train_steps` is `None`, train forever.
    334 
    335     Args:
    336       delay_secs: Start training after this many seconds.
    337 
    338     Returns:
    339       The trained estimator.
    340     """
    341     start = time.time()
    342 
    343     # Start the server, if needed. It's important to start the server before
    344     # we (optionally) sleep for the case where no device_filters are set.
    345     # Otherwise, the servers will wait to connect to each other before starting
    346     # to train. We might as well start as soon as we can.
    347     config = self._estimator.config
    348     if isinstance(config, run_config.RunConfig):
    349       if (config.cluster_spec and config.master and
    350           config.environment == run_config.Environment.LOCAL):
    351         logging.warn("ClusterSpec and master are provided, but environment is "
    352                      "set to 'local'. Set environment to 'cloud' if you intend "
    353                      "to use the distributed runtime.")
    354       if (config.environment != run_config.Environment.LOCAL and
    355           config.environment != run_config.Environment.GOOGLE and
    356           config.cluster_spec and config.master):
    357         self._start_server()
    358     elif config.cluster_spec and config.master:
    359       raise ValueError(
    360           "For distributed runtime, Experiment class only works with"
    361           "tf.contrib.learn.RunConfig for now, but provided {}".format(
    362               type(config)))
    363 
    364     extra_hooks = []
    365     if delay_secs is None:
    366       task_id = self._estimator.config.task_id or 0
    367       if self._delay_workers_by_global_step:
    368         # Wait 5500 global steps for the second worker. Each worker waits more
    369         # then previous one but with a diminishing number of steps.
    370         extra_hooks.append(
    371             basic_session_run_hooks.GlobalStepWaiterHook(
    372                 int(8000.0 * math.log(task_id + 1))))
    373         delay_secs = 0
    374       else:
    375         # Wait 5 secs more for each new worker up to 60 secs.
    376         delay_secs = min(60, task_id * 5)
    377 
    378     if delay_secs > 0:
    379       elapsed_secs = time.time() - start
    380       remaining = delay_secs - elapsed_secs
    381       logging.info("Waiting %d secs before starting training.", remaining)
    382       time.sleep(delay_secs)
    383 
    384     return self._call_train(
    385         input_fn=self._train_input_fn,
    386         max_steps=self._train_steps,
    387         hooks=self._train_monitors + extra_hooks,
    388         saving_listeners=self._saving_listeners)
    389 
    390   def evaluate(self, delay_secs=None, name=None):
    391     """Evaluate on the evaluation data.
    392 
    393     Runs evaluation on the evaluation data and returns the result. Runs for
    394     `self._eval_steps` steps, or if it's `None`, then run until input is
    395     exhausted or another exception is raised. Start the evaluation after
    396     `delay_secs` seconds, or if it's `None`, defaults to using
    397     `self._eval_delay_secs` seconds.
    398 
    399     Args:
    400       delay_secs: Start evaluating after this many seconds. If `None`, defaults
    401         to using `self._eval_delays_secs`.
    402       name: Gives the name to the evauation for the case multiple evaluation is
    403         run for the same experiment.
    404 
    405     Returns:
    406       The result of the `evaluate` call to the `Estimator`.
    407     """
    408     if delay_secs is None:
    409       delay_secs = self._eval_delay_secs
    410 
    411     if delay_secs:
    412       logging.info("Waiting %d secs before starting eval.", delay_secs)
    413       time.sleep(delay_secs)
    414 
    415     return self._call_evaluate(
    416         input_fn=self._eval_input_fn,
    417         steps=self._eval_steps,
    418         metrics=self._eval_metrics,
    419         name=(name or "one_pass"),
    420         hooks=self._eval_hooks)
    421 
    422   @deprecated(
    423       "2016-10-23",
    424       "local_run will be renamed to train_and_evaluate and the new default "
    425       "behavior will be to run evaluation every time there is a new "
    426       "checkpoint.")
    427   def local_run(self):
    428     with _new_attr_context(self, "_min_eval_frequency"):
    429       self._min_eval_frequency = self._local_eval_frequency
    430       return self.train_and_evaluate()
    431 
    432   # TODO(xiejw): Allow continuous_eval_predicate_fn to be passed via constructor
    433   # once stopping all jobs is implemented.
    434   def _continuous_eval(self,
    435                        input_fn,
    436                        name,
    437                        delay_secs,
    438                        throttle_delay_secs,
    439                        evaluate_checkpoint_only_once=True,
    440                        continuous_eval_predicate_fn=None,
    441                        export=True):
    442     """Run continuous eval.
    443 
    444     Runs infinite eval on the evaluation data set. This function starts
    445     evaluating after `delay_secs` seconds and then runs no more than one
    446     evaluation (with `self._eval_steps` steps each time) per
    447     `throttle_delay_secs`. If `train_steps` is not None, will return after
    448     global_step reaches `train_steps`.
    449 
    450     Args:
    451       input_fn: The input to use for this eval.
    452       name: A string appended to the folder name of evaluation results.
    453       delay_secs: Start evaluating after this many seconds. If None, defaults to
    454         self._eval_delay_secs.
    455       throttle_delay_secs: Do not re-evaluate unless the last evaluation was
    456         started at least this many seconds ago. If None, defaults to
    457         self._continuous_eval_throttle_secs.
    458       evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints
    459         that have already been evaluated. Default is `True`.
    460       continuous_eval_predicate_fn: A predicate function determining whether to
    461         continue eval after each iteration. A `predicate_fn` has one of the
    462         following signatures:
    463           * (eval_results) -> boolean
    464           * (eval_results, checkpoint_path) -> boolean
    465         Where `eval_results` is the dictionary of metric evaluations and
    466         checkpoint_path is the path to the checkpoint containing the parameters
    467         on which that evaluation was based.
    468         At the beginning of evaluation, the passed `eval_results` will be None
    469         so it's expected that the predicate function handles that gracefully.
    470         When `predicate_fn` is not specified, continuous eval will run in an
    471         infinite loop (if `train_steps` is None). or exit once global step
    472         reaches `train_steps`.
    473 
    474       export: Whether to export from this step. Default is 'True'.
    475 
    476     Raises:
    477       ValueError: if `continuous_eval_predicate_fn` is neither None nor
    478         callable.
    479     """
    480     if continuous_eval_predicate_fn is not None:
    481       if not callable(continuous_eval_predicate_fn):
    482         raise ValueError(
    483             "`continuous_eval_predicate_fn` must be a callable, or None.")
    484       predicate_fn = _get_standardized_predicate_fn(
    485           continuous_eval_predicate_fn)
    486     else:
    487       predicate_fn = None
    488 
    489     if delay_secs is None:
    490       delay_secs = self._eval_delay_secs
    491     if throttle_delay_secs is None:
    492       throttle_delay_secs = self._continuous_eval_throttle_secs
    493 
    494     if delay_secs:
    495       logging.info("Waiting %f secs before starting eval.", delay_secs)
    496       time.sleep(delay_secs)
    497 
    498     previous_path = None
    499     eval_result = None
    500     last_warning_time = 0
    501     while (not predicate_fn or predicate_fn(
    502         eval_result, checkpoint_path=previous_path if eval_result else None)):
    503       # Exit if we have already reached number of steps to train.
    504       if self._has_training_stopped(eval_result):
    505         logging.info("Exiting continuous eval, global_step=%s >= "
    506                      "train_step=%s", eval_result[ops.GraphKeys.GLOBAL_STEP],
    507                      self._train_steps)
    508         return
    509 
    510       start = time.time()
    511 
    512       error_msg = None
    513       latest_path = saver.latest_checkpoint(self._estimator.model_dir)
    514       if not latest_path:
    515         error_msg = ("Estimator is not fitted yet. "
    516                      "Will start an evaluation when a checkpoint is ready.")
    517       elif evaluate_checkpoint_only_once and latest_path == previous_path:
    518         error_msg = "No new checkpoint ready for evaluation."
    519 
    520       if error_msg:
    521         # Print warning message every 10 mins.
    522         eval_result = {}
    523         if time.time() - last_warning_time > 600:
    524           logging.warning(error_msg)
    525           last_warning_time = time.time()
    526       else:
    527         eval_result = self._call_evaluate(
    528             input_fn=input_fn,
    529             steps=self._eval_steps,
    530             metrics=self._eval_metrics,
    531             name=name,
    532             checkpoint_path=latest_path,
    533             hooks=self._eval_hooks)
    534         # Ensure eval result is not None for next round of evaluation.
    535         if not eval_result:
    536           eval_result = {}
    537 
    538         if export:
    539           self._maybe_export(eval_result, checkpoint_path=latest_path)
    540 
    541         # Clear warning timer and update last evaluated checkpoint
    542         last_warning_time = 0
    543         previous_path = latest_path
    544 
    545       duration = time.time() - start
    546       if duration < throttle_delay_secs:
    547         difference = throttle_delay_secs - duration
    548         logging.info("Waiting %f secs before starting next eval run.",
    549                      difference)
    550         time.sleep(difference)
    551 
    552   def _has_training_stopped(self, eval_result):
    553     """Determines whether the training has stopped."""
    554     if not eval_result:
    555       return False
    556 
    557     global_step = eval_result.get(ops.GraphKeys.GLOBAL_STEP)
    558     return global_step and self._train_steps and (global_step >=
    559                                                   self._train_steps)
    560 
    561   def continuous_eval(self,
    562                       delay_secs=None,
    563                       throttle_delay_secs=None,
    564                       evaluate_checkpoint_only_once=True,
    565                       continuous_eval_predicate_fn=None,
    566                       name="continuous"):
    567     self._continuous_eval(
    568         self._eval_input_fn,
    569         name=name,
    570         delay_secs=delay_secs,
    571         throttle_delay_secs=throttle_delay_secs,
    572         evaluate_checkpoint_only_once=evaluate_checkpoint_only_once,
    573         continuous_eval_predicate_fn=continuous_eval_predicate_fn)
    574 
    575   def continuous_eval_on_train_data(self,
    576                                     delay_secs=None,
    577                                     throttle_delay_secs=None,
    578                                     continuous_eval_predicate_fn=None,
    579                                     name="continuous_on_train_data"):
    580     self._continuous_eval(
    581         self._train_input_fn,
    582         name=name,
    583         delay_secs=delay_secs,
    584         throttle_delay_secs=throttle_delay_secs,
    585         continuous_eval_predicate_fn=continuous_eval_predicate_fn,
    586         export=False)
    587 
    588   def train_and_evaluate(self):
    589     """Interleaves training and evaluation.
    590 
    591     The frequency of evaluation is controlled by the constructor arg
    592     `min_eval_frequency`. When this parameter is 0, evaluation happens
    593     only after training has completed. Note that evaluation cannot happen
    594     more frequently than checkpoints are taken. If no new snapshots are
    595     available when evaluation is supposed to occur, then evaluation doesn't
    596     happen for another `min_eval_frequency` steps (assuming a checkpoint is
    597     available at that point). Thus, settings `min_eval_frequency` to 1 means
    598     that the model will be evaluated everytime there is a new checkpoint.
    599 
    600     This is particular useful for a "Master" task in the cloud, whose
    601     responsibility it is to take checkpoints, evaluate those checkpoints,
    602     and write out summaries. Participating in training as the supervisor
    603     allows such a task to accomplish the first and last items, while
    604     performing evaluation allows for the second.
    605 
    606     Returns:
    607       The result of the `evaluate` call to the `Estimator` as well as the
    608       export results using the specified `ExportStrategy`.
    609     """
    610     # The directory to which evaluation summaries are written are determined
    611     # by adding a suffix to 'eval'; that suffix is the 'name' parameter to
    612     # the various evaluate(...) methods. By setting it to None, we force
    613     # the directory name to simply be 'eval'.
    614     eval_dir_suffix = None
    615 
    616     # We set every_n_steps to 1, but evaluation only occurs when a new
    617     # snapshot is available. If, by the time we finish evaluation
    618     # there is a new snapshot, then we just evaluate again. Otherwise,
    619     # we keep training until one becomes available.
    620     with _new_attr_context(self, "_train_monitors"):
    621       self._train_monitors = self._train_monitors or []
    622       config = self._estimator.config
    623       intermediate_export = self._checkpoint_and_export and (
    624           config.save_checkpoints_secs or config.save_checkpoints_steps)
    625       if intermediate_export:
    626         # Create a partially specified evaluate function with the desired
    627         # arguments. This will be executed by the _EvalAndExportListener,
    628         # which will specify the latest checkpoint path.
    629         eval_fn = functools.partial(
    630             self._call_evaluate,
    631             input_fn=self._eval_input_fn,
    632             steps=self._eval_steps,
    633             metrics=self._eval_metrics,
    634             hooks=self._eval_hooks)
    635 
    636         export_listener = _EvalAndExportListener(
    637             eval_fn=eval_fn,
    638             export_fn=self._maybe_export,
    639             model_dir=self._estimator.model_dir)
    640 
    641         saver_hook = basic_session_run_hooks.CheckpointSaverHook(
    642             checkpoint_dir=self._estimator.model_dir,
    643             save_secs=config.save_checkpoints_secs,
    644             save_steps=config.save_checkpoints_steps,
    645             listeners=[export_listener])
    646         self._train_monitors += [saver_hook]
    647       else:
    648         if self._min_eval_frequency:
    649           self._train_monitors += [
    650               monitors.ValidationMonitor(
    651                   input_fn=self._eval_input_fn,
    652                   eval_steps=self._eval_steps,
    653                   metrics=self._eval_metrics,
    654                   every_n_steps=self._min_eval_frequency,
    655                   name=eval_dir_suffix,
    656                   hooks=self._eval_hooks)
    657           ]
    658       self.train(delay_secs=0)
    659 
    660     # If the checkpoint_and_export flag and appropriate estimator configuration
    661     # parameters are set, then model evaluations and exports are done during the
    662     # training process. In particular, this will always occur at the end of
    663     # training, so we return the most recent results to avoid performing a
    664     # duplicate evaluation and model export.
    665     if intermediate_export:
    666       return export_listener.eval_result, export_listener.export_results
    667     else:
    668       eval_result = self._call_evaluate(
    669           input_fn=self._eval_input_fn,
    670           steps=self._eval_steps,
    671           metrics=self._eval_metrics,
    672           name=eval_dir_suffix,
    673           hooks=self._eval_hooks)
    674       export_results = self._maybe_export(eval_result)
    675       return eval_result, export_results
    676 
    677   @experimental
    678   def continuous_train_and_eval(self, continuous_eval_predicate_fn=None):
    679     """Interleaves training and evaluation.
    680 
    681     The frequency of evaluation is controlled by the `train_steps_per_iteration`
    682     (via constructor). The model will be first trained for
    683     `train_steps_per_iteration`, and then be evaluated in turns.
    684 
    685     This method is intended for single machine usage.
    686 
    687     This differs from `train_and_evaluate` as follows:
    688 
    689       1. The procedure will have train and evaluation in turns. The model
    690       will be trained for a number of steps (usually smaller than `train_steps`
    691       if provided) and then be evaluated.  `train_and_evaluate` will train the
    692       model for `train_steps` (no small training iterations).
    693 
    694       2. Due to the different approach this schedule takes, it leads to two
    695       differences in resource control. First, the resources (e.g., memory) used
    696       by training will be released before evaluation (`train_and_evaluate` takes
    697       double resources). Second, more checkpoints will be saved as a checkpoint
    698       is generated at the end of each training iteration.
    699 
    700       3. As the estimator.train starts from scratch (new graph, new states for
    701       input, etc) at each iteration, it is recommended to have the
    702       `train_steps_per_iteration` larger. It is also recommended to shuffle your
    703       input.
    704 
    705     Args:
    706       continuous_eval_predicate_fn: A predicate function determining whether to
    707         continue eval after each iteration. A `predicate_fn` has one of the
    708         following signatures:
    709           * (eval_results) -> boolean
    710           * (eval_results, checkpoint_path) -> boolean
    711         Where `eval_results` is the dictionary of metric evaluations and
    712         checkpoint_path is the path to the checkpoint containing the parameters
    713         on which that evaluation was based.
    714         At the beginning of evaluation, the passed `eval_results` and
    715         `checkpoint_path` will be None so it's expected that the predicate
    716         function handles that gracefully.
    717         When `predicate_fn` is not specified, continuous eval will run in an
    718         infinite loop (if `train_steps` is None). or exit once global step
    719         reaches `train_steps`.
    720 
    721     Returns:
    722       A tuple of the result of the `evaluate` call to the `Estimator` and the
    723       export results using the specified `ExportStrategy`.
    724 
    725     Raises:
    726       ValueError: if `continuous_eval_predicate_fn` is neither None nor
    727         callable.
    728     """
    729 
    730     if continuous_eval_predicate_fn is not None:
    731       if not callable(continuous_eval_predicate_fn):
    732         raise ValueError(
    733             "`continuous_eval_predicate_fn` must be a callable, or None.")
    734       predicate_fn = _get_standardized_predicate_fn(
    735           continuous_eval_predicate_fn)
    736     else:
    737       predicate_fn = None
    738 
    739     export_results = None
    740     latest_checkpoint = None
    741     eval_result = None
    742 
    743     # Set the default value for train_steps_per_iteration, which will be
    744     # overridden by other settings.
    745     train_steps_per_iteration = 1000
    746     if self._train_steps_per_iteration is not None:
    747       train_steps_per_iteration = self._train_steps_per_iteration
    748     elif self._train_steps is not None:
    749       train_steps_per_iteration = int(self._train_steps / 10)
    750 
    751     while (not predicate_fn or predicate_fn(
    752         eval_result, checkpoint_path=latest_checkpoint
    753         if eval_result else None)):
    754 
    755       if self._has_training_stopped(eval_result):
    756         # Exits once max steps of training is satisfied.
    757         logging.info("Stop training model as max steps reached")
    758         break
    759 
    760       logging.info("Training model for %s steps", train_steps_per_iteration)
    761       self._call_train(
    762           input_fn=self._train_input_fn,
    763           steps=train_steps_per_iteration,
    764           hooks=self._train_monitors,
    765           saving_listeners=self._saving_listeners)
    766 
    767       logging.info("Evaluating model now.")
    768       latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir)
    769       eval_result = self._call_evaluate(
    770           input_fn=self._eval_input_fn,
    771           steps=self._eval_steps,
    772           metrics=self._eval_metrics,
    773           name="one_pass",
    774           checkpoint_path=latest_checkpoint,
    775           hooks=self._eval_hooks)
    776       export_results = self._maybe_export(eval_result)
    777 
    778     return eval_result, export_results
    779 
    780   def _maybe_export(self, eval_result, checkpoint_path=None):
    781     """Export the Estimator using export_fn, if defined."""
    782     export_dir_base = os.path.join(
    783         compat.as_bytes(self._estimator.model_dir), compat.as_bytes("export"))
    784 
    785     export_results = []
    786     for strategy in self._export_strategies:
    787       export_results.append(
    788           strategy.export(
    789               self._estimator,
    790               os.path.join(
    791                   compat.as_bytes(export_dir_base),
    792                   compat.as_bytes(strategy.name)),
    793               checkpoint_path=checkpoint_path,
    794               eval_result=eval_result))
    795 
    796     return export_results
    797 
    798   def run_std_server(self):
    799     """Starts a TensorFlow server and joins the serving thread.
    800 
    801     Typically used for parameter servers.
    802 
    803     Raises:
    804       ValueError: if not enough information is available in the estimator's
    805         config to create a server.
    806     """
    807     self._start_server().join()
    808 
    809   def test(self):
    810     """Tests training, evaluating and exporting the estimator for a single step.
    811 
    812     Returns:
    813       The result of the `evaluate` call to the `Estimator`.
    814     """
    815     self._call_train(
    816         input_fn=self._train_input_fn,
    817         steps=1,
    818         hooks=self._train_monitors,
    819         saving_listeners=self._saving_listeners)
    820 
    821     eval_result = self._call_evaluate(
    822         input_fn=self._eval_input_fn,
    823         steps=1,
    824         metrics=self._eval_metrics,
    825         name="one_pass")
    826     _ = self._maybe_export(eval_result)
    827 
    828     return eval_result
    829 
    830   def _start_server(self):
    831     """Creates, starts, and returns a server_lib.Server."""
    832     config = self._estimator.config
    833     if (not config.cluster_spec or not config.task_type or not config.master or
    834         config.task_id is None):
    835       raise ValueError("Could not start server; be sure to specify "
    836                        "cluster_spec, task_type, master, and task in "
    837                        "RunConfig or set the TF_CONFIG environment variable.")
    838     server = server_lib.Server(
    839         config.cluster_spec,
    840         job_name=config.task_type,
    841         task_index=config.task_id,
    842         config=config.tf_config,
    843         start=False)
    844     server.start()
    845     return server
    846 
    847   def _call_train(
    848       self,
    849       _sentinel=None,  # pylint: disable=invalid-name,
    850       input_fn=None,
    851       steps=None,
    852       hooks=None,
    853       max_steps=None,
    854       saving_listeners=None):
    855     if _sentinel is not None:
    856       raise ValueError("_call_train should be called with keyword args only")
    857 
    858     # Estimator in core cannot work with monitors. We need to convert them
    859     # to hooks. For Estimator in contrib, it is converted internally. So, it is
    860     # safe to convert for both cases.
    861     hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator)
    862     if self._core_estimator_used:
    863       return self._estimator.train(
    864           input_fn=input_fn,
    865           steps=steps,
    866           max_steps=max_steps,
    867           hooks=hooks,
    868           saving_listeners=saving_listeners)
    869     else:
    870       return self._estimator.fit(
    871           input_fn=input_fn, steps=steps, max_steps=max_steps, monitors=hooks)
    872 
    873   def _call_evaluate(
    874       self,
    875       _sentinel=None,  # pylint: disable=invalid-name,
    876       input_fn=None,
    877       steps=None,
    878       metrics=None,
    879       name=None,
    880       checkpoint_path=None,
    881       hooks=None):
    882     if _sentinel is not None:
    883       raise ValueError("_call_evaluate should be called with keyword args only")
    884 
    885     if self._core_estimator_used:
    886       if metrics is not None:
    887         raise ValueError(
    888             "`eval_metrics` must be `None` with `tf.estimator.Estimator`")
    889       return self._estimator.evaluate(
    890           input_fn=input_fn,
    891           steps=steps,
    892           name=name,
    893           checkpoint_path=checkpoint_path,
    894           hooks=hooks)
    895     else:
    896       return self._estimator.evaluate(
    897           input_fn=input_fn,
    898           steps=steps,
    899           metrics=metrics,
    900           name=name,
    901           checkpoint_path=checkpoint_path,
    902           hooks=hooks)
    903 
    904 
    905 @contextlib.contextmanager
    906 def _new_attr_context(obj, attr):
    907   """Creates a new context in which an object's attribute can be changed.
    908 
    909   This creates a context in which an object's attribute can be changed.
    910   Once the context is exited, the attribute reverts to its original value.
    911 
    912   Args:
    913     obj: An object whose attribute to restore at the end of the context.
    914     attr: An attribute to remember and restore at the end of the context.
    915 
    916   Yields:
    917     Context.
    918 
    919   Example:
    920     my_obj.x = 1
    921     with _new_attr_context(my_obj, "x"):
    922       my_obj.x = 2
    923       print(my_obj.x)
    924     print(my_obj.x)
    925   """
    926   saved = getattr(obj, attr)
    927   try:
    928     yield
    929   finally:
    930     setattr(obj, attr, saved)
    931 
    932 
    933 def _is_gcs(model_dir):
    934   return model_dir and model_dir.startswith("gs://")
    935