Home | History | Annotate | Download | only in estimator
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Base Estimator class."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import copy
     23 import os
     24 import tempfile
     25 
     26 import numpy as np
     27 import six
     28 
     29 from google.protobuf import message
     30 from tensorflow.core.framework import summary_pb2
     31 from tensorflow.core.protobuf import config_pb2
     32 from tensorflow.python.client import session as tf_session
     33 from tensorflow.python.data.ops import dataset_ops
     34 from tensorflow.python.eager import context
     35 from tensorflow.python.estimator import model_fn as model_fn_lib
     36 from tensorflow.python.estimator import run_config
     37 from tensorflow.python.estimator import util
     38 from tensorflow.python.estimator import warm_starting_util
     39 from tensorflow.python.estimator.export.export import build_all_signature_defs
     40 from tensorflow.python.estimator.export.export import get_temp_export_dir
     41 from tensorflow.python.estimator.export.export import get_timestamped_export_dir
     42 from tensorflow.python.framework import ops
     43 from tensorflow.python.framework import random_seed
     44 from tensorflow.python.ops import control_flow_ops
     45 from tensorflow.python.ops import metrics as metrics_lib
     46 from tensorflow.python.platform import gfile
     47 from tensorflow.python.platform import tf_logging as logging
     48 from tensorflow.python.saved_model import builder as saved_model_builder
     49 from tensorflow.python.saved_model import tag_constants
     50 from tensorflow.python.summary import summary
     51 from tensorflow.python.summary.writer import writer_cache
     52 from tensorflow.python.training import evaluation
     53 from tensorflow.python.training import monitored_session
     54 from tensorflow.python.training import saver
     55 from tensorflow.python.training import training
     56 from tensorflow.python.training import training_util
     57 from tensorflow.python.util import compat
     58 from tensorflow.python.util import compat_internal
     59 from tensorflow.python.util import nest
     60 from tensorflow.python.util.tf_export import tf_export
     61 
     62 
     63 _VALID_MODEL_FN_ARGS = set(
     64     ['features', 'labels', 'mode', 'params', 'self', 'config'])
     65 
     66 
     67 @tf_export('estimator.Estimator')
     68 class Estimator(object):
     69   """Estimator class to train and evaluate TensorFlow models.
     70 
     71   The `Estimator` object wraps a model which is specified by a `model_fn`,
     72   which, given inputs and a number of other parameters, returns the ops
     73   necessary to perform training, evaluation, or predictions.
     74 
     75   All outputs (checkpoints, event files, etc.) are written to `model_dir`, or a
     76   subdirectory thereof. If `model_dir` is not set, a temporary directory is
     77   used.
     78 
     79   The `config` argument can be passed `RunConfig` object containing information
     80   about the execution environment. It is passed on to the `model_fn`, if the
     81   `model_fn` has a parameter named "config" (and input functions in the same
     82   manner). If the `config` parameter is not passed, it is instantiated by the
     83   `Estimator`. Not passing config means that defaults useful for local execution
     84   are used. `Estimator` makes config available to the model (for instance, to
     85   allow specialization based on the number of workers available), and also uses
     86   some of its fields to control internals, especially regarding checkpointing.
     87 
     88   The `params` argument contains hyperparameters. It is passed to the
     89   `model_fn`, if the `model_fn` has a parameter named "params", and to the input
     90   functions in the same manner. `Estimator` only passes params along, it does
     91   not inspect it. The structure of `params` is therefore entirely up to the
     92   developer.
     93 
     94   None of `Estimator`'s methods can be overridden in subclasses (its
     95   constructor enforces this). Subclasses should use `model_fn` to configure
     96   the base class, and may add methods implementing specialized functionality.
     97 
     98   @compatibility(eager)
     99   Estimators are not compatible with eager execution.
    100   @end_compatibility
    101   """
    102 
    103   def __init__(self, model_fn, model_dir=None, config=None, params=None,
    104                warm_start_from=None):
    105     """Constructs an `Estimator` instance.
    106 
    107     See @{$estimators} for more information. To warm-start an `Estimator`:
    108 
    109     ```python
    110     estimator = tf.estimator.DNNClassifier(
    111         feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    112         hidden_units=[1024, 512, 256],
    113         warm_start_from="/path/to/checkpoint/dir")
    114     ```
    115 
    116     For more details on warm-start configuration, see
    117     @{tf.estimator.WarmStartSettings$WarmStartSettings}.
    118 
    119     Args:
    120       model_fn: Model function. Follows the signature:
    121 
    122         * Args:
    123 
    124           * `features`: This is the first item returned from the `input_fn`
    125                  passed to `train`, `evaluate`, and `predict`. This should be a
    126                  single `Tensor` or `dict` of same.
    127           * `labels`: This is the second item returned from the `input_fn`
    128                  passed to `train`, `evaluate`, and `predict`. This should be a
    129                  single `Tensor` or `dict` of same (for multi-head models). If
    130                  mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If
    131                  the `model_fn`'s signature does not accept `mode`, the
    132                  `model_fn` must still be able to handle `labels=None`.
    133           * `mode`: Optional. Specifies if this training, evaluation or
    134                  prediction. See `ModeKeys`.
    135           * `params`: Optional `dict` of hyperparameters.  Will receive what
    136                  is passed to Estimator in `params` parameter. This allows
    137                  to configure Estimators from hyper parameter tuning.
    138           * `config`: Optional configuration object. Will receive what is passed
    139                  to Estimator in `config` parameter, or the default `config`.
    140                  Allows updating things in your model_fn based on configuration
    141                  such as `num_ps_replicas`, or `model_dir`.
    142 
    143         * Returns:
    144           `EstimatorSpec`
    145 
    146       model_dir: Directory to save model parameters, graph and etc. This can
    147         also be used to load checkpoints from the directory into a estimator to
    148         continue training a previously saved model. If `PathLike` object, the
    149         path will be resolved. If `None`, the model_dir in `config` will be used
    150         if set. If both are set, they must be same. If both are `None`, a
    151         temporary directory will be used.
    152       config: Configuration object.
    153       params: `dict` of hyper parameters that will be passed into `model_fn`.
    154               Keys are names of parameters, values are basic python types.
    155       warm_start_from: Optional string filepath to a checkpoint to warm-start
    156                        from, or a `tf.estimator.WarmStartSettings` object to
    157                        fully configure warm-starting.  If the string filepath is
    158                        provided instead of a `WarmStartSettings`, then all
    159                        variables are warm-started, and it is assumed that
    160                        vocabularies and Tensor names are unchanged.
    161 
    162     Raises:
    163       RuntimeError: If eager execution is enabled.
    164       ValueError: parameters of `model_fn` don't match `params`.
    165       ValueError: if this is called via a subclass and if that class overrides
    166         a member of `Estimator`.
    167     """
    168     if context.in_eager_mode():
    169       raise RuntimeError(
    170           'Estimators are not supported when eager execution is enabled.')
    171 
    172     Estimator._assert_members_are_not_overridden(self)
    173 
    174     if config is None:
    175       self._config = run_config.RunConfig()
    176       logging.info('Using default config.')
    177     else:
    178       if not isinstance(config, run_config.RunConfig):
    179         raise ValueError(
    180             'config must be an instance of RunConfig, but provided %s.' %
    181             config)
    182       self._config = config
    183 
    184     # Model directory.
    185     model_dir = compat_internal.path_to_str(model_dir)
    186     if (model_dir is not None) and (self._config.model_dir is not None):
    187       if model_dir != self._config.model_dir:
    188         # TODO(alanyee): remove this suppression after it is no longer needed
    189         # pylint: disable=g-doc-exception
    190         raise ValueError(
    191             "model_dir are set both in constructor and RunConfig, but with "
    192             "different values. In constructor: '{}', in RunConfig: "
    193             "'{}' ".format(model_dir, self._config.model_dir))
    194         # pylint: enable=g-doc-exception
    195 
    196     self._model_dir = model_dir or self._config.model_dir
    197     if self._model_dir is None:
    198       self._model_dir = tempfile.mkdtemp()
    199       logging.warning('Using temporary folder as model directory: %s',
    200                       self._model_dir)
    201     if self._config.model_dir is None:
    202       self._config = self._config.replace(model_dir=self._model_dir)
    203     logging.info('Using config: %s', str(vars(self._config)))
    204 
    205     if self._config.session_config is None:
    206       self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    207     else:
    208       self._session_config = self._config.session_config
    209 
    210     self._device_fn = _get_replica_device_setter(self._config)
    211 
    212     if model_fn is None:
    213       raise ValueError('model_fn must be provided to Estimator.')
    214     _verify_model_fn_args(model_fn, params)
    215     self._model_fn = model_fn
    216     self._params = copy.deepcopy(params or {})
    217 
    218     # pylint: disable=protected-access
    219     self._warm_start_settings = (
    220         warm_starting_util._get_default_warm_start_settings(warm_start_from))
    221     # pylint: enable=protected-access
    222 
    223   @property
    224   def model_dir(self):
    225     return self._model_dir
    226 
    227   @property
    228   def config(self):
    229     return copy.deepcopy(self._config)
    230 
    231   @property
    232   def params(self):
    233     return copy.deepcopy(self._params)
    234 
    235   @property
    236   def model_fn(self):
    237     """Returns the model_fn which is bound to self.params.
    238 
    239     Returns:
    240       The model_fn with following signature:
    241         `def model_fn(features, labels, mode, config)`
    242     """
    243 
    244     def public_model_fn(features, labels, mode, config):
    245       return self._call_model_fn(features, labels, mode, config)
    246 
    247     return public_model_fn
    248 
    249   # TODO(ispir): support a list of names
    250   def get_variable_value(self, name):
    251     """Returns value of the variable given by name.
    252 
    253     Args:
    254       name: string or a list of string, name of the tensor.
    255 
    256     Returns:
    257       Numpy array - value of the tensor.
    258 
    259     Raises:
    260       ValueError: If the Estimator has not produced a checkpoint yet.
    261     """
    262     _check_checkpoint_available(self.model_dir)
    263     return training.load_variable(self.model_dir, name)
    264 
    265   def get_variable_names(self):
    266     """Returns list of all variable names in this model.
    267 
    268     Returns:
    269       List of names.
    270 
    271     Raises:
    272       ValueError: If the Estimator has not produced a checkpoint yet.
    273     """
    274     _check_checkpoint_available(self.model_dir)
    275     return [name for name, _ in training.list_variables(self.model_dir)]
    276 
    277   def latest_checkpoint(self):
    278     """Finds the filename of latest saved checkpoint file in `model_dir`.
    279 
    280     Returns:
    281       The full path to the latest checkpoint or `None` if no checkpoint was
    282       found.
    283     """
    284     return saver.latest_checkpoint(self.model_dir)
    285 
    286   def train(self,
    287             input_fn,
    288             hooks=None,
    289             steps=None,
    290             max_steps=None,
    291             saving_listeners=None):
    292     """Trains a model given training data input_fn.
    293 
    294     Args:
    295       input_fn: A function that provides input data for training as minibatches.
    296         See @{$get_started/premade_estimators#create_input_functions} for more
    297         information. The function should construct and return one of
    298         the following:
    299 
    300           * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
    301             tuple (features, labels) with same constraints as below.
    302           * A tuple (features, labels): Where features is a `Tensor` or a
    303             dictionary of string feature name to `Tensor` and labels is a
    304             `Tensor` or a dictionary of string label name to `Tensor`. Both
    305             features and labels are consumed by `model_fn`. They should satisfy
    306             the expectation of `model_fn` from inputs.
    307 
    308       hooks: List of `SessionRunHook` subclass instances. Used for callbacks
    309         inside the training loop.
    310       steps: Number of steps for which to train model. If `None`, train forever
    311         or train until input_fn generates the `OutOfRange` error or
    312         `StopIteration` exception. 'steps' works incrementally. If you call two
    313         times train(steps=10) then training occurs in total 20 steps. If
    314         `OutOfRange` or `StopIteration` occurs in the middle, training stops
    315         before 20 steps. If you don't want to have incremental behavior please
    316         set `max_steps` instead. If set, `max_steps` must be `None`.
    317       max_steps: Number of total steps for which to train model. If `None`,
    318         train forever or train until input_fn generates the `OutOfRange` error
    319         or `StopIteration` exception. If set, `steps` must be `None`. If
    320         `OutOfRange` or `StopIteration` occurs in the middle, training stops
    321         before `max_steps` steps.
    322         Two calls to `train(steps=100)` means 200 training
    323         iterations. On the other hand, two calls to `train(max_steps=100)` means
    324         that the second call will not do any iteration since first call did
    325         all 100 steps.
    326       saving_listeners: list of `CheckpointSaverListener` objects. Used for
    327         callbacks that run immediately before or after checkpoint savings.
    328 
    329     Returns:
    330       `self`, for chaining.
    331 
    332     Raises:
    333       ValueError: If both `steps` and `max_steps` are not `None`.
    334       ValueError: If either `steps` or `max_steps` is <= 0.
    335     """
    336     if (steps is not None) and (max_steps is not None):
    337       raise ValueError('Can not provide both steps and max_steps.')
    338     if steps is not None and steps <= 0:
    339       raise ValueError('Must specify steps > 0, given: {}'.format(steps))
    340     if max_steps is not None and max_steps <= 0:
    341       raise ValueError(
    342           'Must specify max_steps > 0, given: {}'.format(max_steps))
    343 
    344     if max_steps is not None:
    345       start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
    346       if max_steps <= start_step:
    347         logging.info('Skipping training since max_steps has already saved.')
    348         return self
    349 
    350     hooks = _check_hooks_type(hooks)
    351     hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))
    352 
    353     saving_listeners = _check_listeners_type(saving_listeners)
    354     loss = self._train_model(input_fn, hooks, saving_listeners)
    355     logging.info('Loss for final step: %s.', loss)
    356     return self
    357 
    358   def _convert_train_steps_to_hooks(self, steps, max_steps):
    359     if steps is not None or max_steps is not None:
    360       return [training.StopAtStepHook(steps, max_steps)]
    361     else:
    362       return []
    363 
    364   def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
    365                name=None):
    366     """Evaluates the model given evaluation data input_fn.
    367 
    368     For each step, calls `input_fn`, which returns one batch of data.
    369     Evaluates until:
    370     - `steps` batches are processed, or
    371     - `input_fn` raises an end-of-input exception (`OutOfRangeError` or
    372     `StopIteration`).
    373 
    374     Args:
    375       input_fn: A function that constructs the input data for evaluation.
    376         See @{$get_started/premade_estimators#create_input_functions} for more
    377         information. The function should construct and return one of
    378         the following:
    379 
    380           * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
    381             tuple (features, labels) with same constraints as below.
    382           * A tuple (features, labels): Where features is a `Tensor` or a
    383             dictionary of string feature name to `Tensor` and labels is a
    384             `Tensor` or a dictionary of string label name to `Tensor`. Both
    385             features and labels are consumed by `model_fn`. They should satisfy
    386             the expectation of `model_fn` from inputs.
    387 
    388       steps: Number of steps for which to evaluate model. If `None`, evaluates
    389         until `input_fn` raises an end-of-input exception.
    390       hooks: List of `SessionRunHook` subclass instances. Used for callbacks
    391         inside the evaluation call.
    392       checkpoint_path: Path of a specific checkpoint to evaluate. If `None`, the
    393         latest checkpoint in `model_dir` is used.
    394       name: Name of the evaluation if user needs to run multiple evaluations on
    395         different data sets, such as on training data vs test data. Metrics for
    396         different evaluations are saved in separate folders, and appear
    397         separately in tensorboard.
    398 
    399     Returns:
    400       A dict containing the evaluation metrics specified in `model_fn` keyed by
    401       name, as well as an entry `global_step` which contains the value of the
    402       global step for which this evaluation was performed.
    403 
    404     Raises:
    405       ValueError: If `steps <= 0`.
    406       ValueError: If no model has been trained, namely `model_dir`, or the
    407         given `checkpoint_path` is empty.
    408     """
    409     hooks = _check_hooks_type(hooks)
    410     hooks.extend(self._convert_eval_steps_to_hooks(steps))
    411 
    412     return self._evaluate_model(
    413         input_fn=input_fn,
    414         hooks=hooks,
    415         checkpoint_path=checkpoint_path,
    416         name=name)
    417 
    418   def _convert_eval_steps_to_hooks(self, steps):
    419     if steps is None:
    420       return []
    421 
    422     if steps <= 0:
    423       raise ValueError('Must specify steps > 0, given: {}'.format(steps))
    424     return [evaluation._StopAfterNEvalsHook(num_evals=steps)]  # pylint: disable=protected-access
    425 
    426   def predict(self,
    427               input_fn,
    428               predict_keys=None,
    429               hooks=None,
    430               checkpoint_path=None,
    431               yield_single_examples=True):
    432     """Yields predictions for given features.
    433 
    434     Args:
    435       input_fn: A function that constructs the features. Prediction continues
    436         until `input_fn` raises an end-of-input exception (`OutOfRangeError` or
    437         `StopIteration`).
    438         See @{$get_started/premade_estimators#create_input_functions} for more
    439         information. The function should construct and return one of
    440         the following:
    441 
    442           * A 'tf.data.Dataset' object: Outputs of `Dataset` object must have
    443             same constraints as below.
    444           * features: A `Tensor` or a dictionary of string feature name to
    445             `Tensor`. features are consumed by `model_fn`. They should satisfy
    446             the expectation of `model_fn` from inputs.
    447           * A tuple, in which case the first item is extracted as features.
    448 
    449       predict_keys: list of `str`, name of the keys to predict. It is used if
    450         the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
    451         then rest of the predictions will be filtered from the dictionary. If
    452         `None`, returns all.
    453       hooks: List of `SessionRunHook` subclass instances. Used for callbacks
    454         inside the prediction call.
    455       checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
    456         latest checkpoint in `model_dir` is used.
    457       yield_single_examples: If False, yield the whole batch as returned by the
    458         model_fn instead of decomposing the batch into individual elements. This
    459         is useful if model_fn return some tensor with first dimension not
    460         equal to the batch size
    461 
    462     Yields:
    463       Evaluated values of `predictions` tensors.
    464 
    465     Raises:
    466       ValueError: Could not find a trained model in model_dir.
    467       ValueError: if batch length of predictions are not same and
    468         yield_single_examples is True.
    469       ValueError: If there is a conflict between `predict_keys` and
    470         `predictions`. For example if `predict_keys` is not `None` but
    471         `EstimatorSpec.predictions` is not a `dict`.
    472     """
    473     hooks = _check_hooks_type(hooks)
    474     # Check that model has been trained.
    475     if not checkpoint_path:
    476       checkpoint_path = saver.latest_checkpoint(self._model_dir)
    477     if not checkpoint_path:
    478       raise ValueError('Could not find trained model in model_dir: {}.'.format(
    479           self._model_dir))
    480 
    481     with ops.Graph().as_default() as g:
    482       random_seed.set_random_seed(self._config.tf_random_seed)
    483       self._create_and_assert_global_step(g)
    484       features, input_hooks = self._get_features_from_input_fn(
    485           input_fn, model_fn_lib.ModeKeys.PREDICT)
    486       estimator_spec = self._call_model_fn(
    487           features, None, model_fn_lib.ModeKeys.PREDICT, self.config)
    488       predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
    489       all_hooks = list(input_hooks)
    490       all_hooks.extend(hooks)
    491       all_hooks.extend(list(estimator_spec.prediction_hooks or []))
    492       with training.MonitoredSession(
    493           session_creator=training.ChiefSessionCreator(
    494               checkpoint_filename_with_path=checkpoint_path,
    495               master=self._config.master,
    496               scaffold=estimator_spec.scaffold,
    497               config=self._session_config),
    498           hooks=all_hooks) as mon_sess:
    499         while not mon_sess.should_stop():
    500           preds_evaluated = mon_sess.run(predictions)
    501           if not yield_single_examples:
    502             yield preds_evaluated
    503           elif not isinstance(predictions, dict):
    504             for pred in preds_evaluated:
    505               yield pred
    506           else:
    507             for i in range(self._extract_batch_length(preds_evaluated)):
    508               yield {
    509                   key: value[i]
    510                   for key, value in six.iteritems(preds_evaluated)
    511               }
    512 
    513   def _assert_members_are_not_overridden(self):
    514     """Asserts members of `Estimator` are not overridden."""
    515     allowed_overrides = set([
    516         '_call_input_fn', '_create_global_step',
    517         '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
    518         '_tf_api_names'
    519     ])
    520     estimator_members = set([m for m in Estimator.__dict__.keys()
    521                              if not m.startswith('__')])
    522     subclass_members = set(self.__class__.__dict__.keys())
    523     common_members = estimator_members & subclass_members - allowed_overrides
    524     overridden_members = [
    525         m for m in common_members
    526         if Estimator.__dict__[m] != self.__class__.__dict__[m]]
    527     if overridden_members:
    528       raise ValueError(
    529           'Subclasses of Estimator cannot override members of Estimator. '
    530           '{} does override {}'.format(self.__class__, overridden_members))
    531 
    532   def export_savedmodel(
    533       self, export_dir_base, serving_input_receiver_fn,
    534       assets_extra=None,
    535       as_text=False,
    536       checkpoint_path=None,
    537       strip_default_attrs=False):
    538     # pylint: disable=line-too-long
    539     """Exports inference graph as a SavedModel into given dir.
    540 
    541     For a detailed guide, see
    542     @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.
    543 
    544     This method builds a new graph by first calling the
    545     serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
    546     this `Estimator`'s model_fn to generate the model graph based on those
    547     features. It restores the given checkpoint (or, lacking that, the most
    548     recent checkpoint) into this graph in a fresh session.  Finally it creates
    549     a timestamped export directory below the given export_dir_base, and writes
    550     a `SavedModel` into it containing a single `MetaGraphDef` saved from this
    551     session.
    552 
    553     The exported `MetaGraphDef` will provide one `SignatureDef` for each
    554     element of the export_outputs dict returned from the model_fn, named using
    555     the same keys.  One of these keys is always
    556     signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    557     signature will be served when a serving request does not specify one.
    558     For each signature, the outputs are provided by the corresponding
    559     `ExportOutput`s, and the inputs are always the input receivers provided by
    560     the serving_input_receiver_fn.
    561 
    562     Extra assets may be written into the SavedModel via the assets_extra
    563     argument.  This should be a dict, where each key gives a destination path
    564     (including the filename) relative to the assets.extra directory.  The
    565     corresponding value gives the full path of the source file to be copied.
    566     For example, the simple case of copying a single file without renaming it
    567     is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
    568 
    569     Args:
    570       export_dir_base: A string containing a directory in which to create
    571         timestamped subdirectories containing exported SavedModels.
    572       serving_input_receiver_fn: A function that takes no argument and
    573         returns a `ServingInputReceiver`.
    574       assets_extra: A dict specifying how to populate the assets.extra directory
    575         within the exported SavedModel, or `None` if no extra assets are needed.
    576       as_text: whether to write the SavedModel proto in text format.
    577       checkpoint_path: The checkpoint path to export.  If `None` (the default),
    578         the most recent checkpoint found within the model directory is chosen.
    579       strip_default_attrs: Boolean. If `True`, default-valued attributes will be
    580         removed from the NodeDefs. For a detailed guide, see
    581         [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    582 
    583     Returns:
    584       The string path to the exported directory.
    585 
    586     Raises:
    587       ValueError: if no serving_input_receiver_fn is provided, no export_outputs
    588           are provided, or no checkpoint can be found.
    589     """
    590     # pylint: enable=line-too-long
    591     if serving_input_receiver_fn is None:
    592       raise ValueError('serving_input_receiver_fn must be defined.')
    593 
    594     with ops.Graph().as_default() as g:
    595       self._create_and_assert_global_step(g)
    596       random_seed.set_random_seed(self._config.tf_random_seed)
    597       serving_input_receiver = serving_input_receiver_fn()
    598 
    599       # Call the model_fn and collect the export_outputs.
    600       estimator_spec = self._call_model_fn(
    601           features=serving_input_receiver.features,
    602           labels=None,
    603           mode=model_fn_lib.ModeKeys.PREDICT,
    604           config=self.config)
    605 
    606       # Build the SignatureDefs from receivers and all outputs
    607       signature_def_map = build_all_signature_defs(
    608           serving_input_receiver.receiver_tensors,
    609           estimator_spec.export_outputs,
    610           serving_input_receiver.receiver_tensors_alternatives)
    611 
    612       if not checkpoint_path:
    613         # Locate the latest checkpoint
    614         checkpoint_path = saver.latest_checkpoint(self._model_dir)
    615       if not checkpoint_path:
    616         raise ValueError("Couldn't find trained model at %s." % self._model_dir)
    617 
    618       export_dir = get_timestamped_export_dir(export_dir_base)
    619       temp_export_dir = get_temp_export_dir(export_dir)
    620 
    621       # TODO(soergel): Consider whether MonitoredSession makes sense here
    622       with tf_session.Session(config=self._session_config) as session:
    623 
    624         saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
    625             sharded=True)
    626         saver_for_restore.restore(session, checkpoint_path)
    627 
    628         # pylint: disable=protected-access
    629         local_init_op = (
    630             estimator_spec.scaffold.local_init_op or
    631             monitored_session.Scaffold._default_local_init_op())
    632         # pylint: enable=protected-access
    633 
    634         # Perform the export
    635         builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
    636         builder.add_meta_graph_and_variables(
    637             session, [tag_constants.SERVING],
    638             signature_def_map=signature_def_map,
    639             assets_collection=ops.get_collection(
    640                 ops.GraphKeys.ASSET_FILEPATHS),
    641             legacy_init_op=local_init_op,
    642             strip_default_attrs=strip_default_attrs)
    643         builder.save(as_text)
    644 
    645       # Add the extra assets
    646       if assets_extra:
    647         assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
    648                                          compat.as_bytes('assets.extra'))
    649         for dest_relative, source in assets_extra.items():
    650           dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
    651                                        compat.as_bytes(dest_relative))
    652           dest_path = os.path.dirname(dest_absolute)
    653           gfile.MakeDirs(dest_path)
    654           gfile.Copy(source, dest_absolute)
    655 
    656       gfile.Rename(temp_export_dir, export_dir)
    657       return export_dir
    658 
    659   def _get_features_from_input_fn(self, input_fn, mode):
    660     """Extracts the `features` from return values of `input_fn`."""
    661     result = self._call_input_fn(input_fn, mode)
    662     input_hooks = []
    663     if isinstance(result, dataset_ops.Dataset):
    664       iterator = result.make_initializable_iterator()
    665       input_hooks.append(_DatasetInitializerHook(iterator))
    666       result = iterator.get_next()
    667     if isinstance(result, (list, tuple)):
    668       # Unconditionally drop the label (the second element of result).
    669       result = result[0]
    670 
    671     if not _has_dataset_or_queue_runner(result):
    672       logging.warning('Input graph does not use tf.data.Dataset or contain a '
    673                       'QueueRunner. That means predict yields forever. '
    674                       'This is probably a mistake.')
    675     return result, input_hooks
    676 
    677   def _get_features_and_labels_from_input_fn(self, input_fn, mode):
    678     """Extracts the `features` and labels from return values of `input_fn`."""
    679     result = self._call_input_fn(input_fn, mode)
    680     input_hooks = []
    681     if isinstance(result, dataset_ops.Dataset):
    682       iterator = result.make_initializable_iterator()
    683       input_hooks.append(_DatasetInitializerHook(iterator))
    684       result = iterator.get_next()
    685     if isinstance(result, (list, tuple)):
    686       if len(result) != 2:
    687         raise ValueError(
    688             'input_fn should return (features, labels) as a len 2 tuple.')
    689       return result[0], result[1], input_hooks
    690     return result, None, input_hooks
    691 
    692   def _extract_batch_length(self, preds_evaluated):
    693     """Extracts batch length of predictions."""
    694     batch_length = None
    695     for key, value in six.iteritems(preds_evaluated):
    696       batch_length = batch_length or value.shape[0]
    697       if value.shape[0] != batch_length:
    698         raise ValueError('Batch length of predictions should be same. %s has '
    699                          'different batch length then others.' % key)
    700     return batch_length
    701 
    702   def _extract_keys(self, predictions, predict_keys):
    703     """Extracts `predict_keys` from `predictions`."""
    704     if not predict_keys:
    705       return predictions
    706     if not isinstance(predictions, dict):
    707       raise ValueError(
    708           'predict_keys argument is not valid in case of non-dict predictions.')
    709     existing_keys = predictions.keys()
    710     predictions = {
    711         key: value
    712         for key, value in six.iteritems(predictions) if key in predict_keys
    713     }
    714     if not predictions:
    715       raise ValueError('Expected to run at least one output from %s, '
    716                        'provided %s.' % (existing_keys, predict_keys))
    717     return predictions
    718 
    719   def _create_global_step(self, graph):
    720     """Creates the global step tensor in graph.
    721 
    722     The global step tensor must be an integer type with name 'global_step' and
    723     be added to the collection ${tf.GraphKeys.GLOBAL_STEP}.
    724 
    725     Args:
    726       graph: The graph in which to create the global step tensor.
    727 
    728     Returns:
    729       The global step `Tensor`.
    730     """
    731     return training.create_global_step(graph)
    732 
    733   def _create_and_assert_global_step(self, graph):
    734     """Creates and asserts properties of the global step.
    735 
    736     Args:
    737       graph: The graph in which to create the global step tensor.
    738 
    739     Returns:
    740       The global step `Tensor`.
    741     """
    742     step = self._create_global_step(graph)
    743     assert step == training.get_global_step()
    744     assert step.dtype.is_integer
    745     return step
    746 
    747   def _call_input_fn(self, input_fn, mode):
    748     """Calls the input function.
    749 
    750     Args:
    751       input_fn: The input function.
    752       mode: ModeKeys
    753 
    754     Returns:
    755       Either features or (features, labels) where features and labels are:
    756         features - `Tensor` or dictionary of string feature name to `Tensor`.
    757         labels - `Tensor` or dictionary of `Tensor` with labels.
    758 
    759     Raises:
    760       ValueError: if input_fn takes invalid arguments.
    761     """
    762     input_fn_args = util.fn_args(input_fn)
    763     kwargs = {}
    764     if 'mode' in input_fn_args:
    765       kwargs['mode'] = mode
    766     if 'params' in input_fn_args:
    767       kwargs['params'] = self.params
    768     if 'config' in input_fn_args:
    769       kwargs['config'] = self.config
    770     with ops.device('/cpu:0'):
    771       return input_fn(**kwargs)
    772 
    773   def _call_model_fn(self, features, labels, mode, config):
    774     """Calls model function.
    775 
    776     Args:
    777       features: features dict.
    778       labels: labels dict.
    779       mode: ModeKeys
    780       config: RunConfig
    781 
    782     Returns:
    783       An `EstimatorSpec` object.
    784 
    785     Raises:
    786       ValueError: if model_fn returns invalid objects.
    787     """
    788     model_fn_args = util.fn_args(self._model_fn)
    789     kwargs = {}
    790     if 'labels' in model_fn_args:
    791       kwargs['labels'] = labels
    792     else:
    793       if labels is not None:
    794         raise ValueError(
    795             'model_fn does not take labels, but input_fn returns labels.')
    796     if 'mode' in model_fn_args:
    797       kwargs['mode'] = mode
    798     if 'params' in model_fn_args:
    799       kwargs['params'] = self.params
    800     if 'config' in model_fn_args:
    801       kwargs['config'] = config
    802 
    803     logging.info('Calling model_fn.')
    804     model_fn_results = self._model_fn(features=features, **kwargs)
    805     logging.info('Done calling model_fn.')
    806 
    807     if not isinstance(model_fn_results, model_fn_lib.EstimatorSpec):
    808       raise ValueError('model_fn should return an EstimatorSpec.')
    809 
    810     return model_fn_results
    811 
  def _train_model(self, input_fn, hooks, saving_listeners):
    """Builds the training graph and runs the training loop.

    Within a fresh graph this creates the global step, obtains features and
    labels from `input_fn`, calls the model function in TRAIN mode, optionally
    warm-starts, assembles the session hooks and a default saver, and then
    runs `train_op` in a `MonitoredTrainingSession` until it asks to stop.

    Args:
      input_fn: The input function providing features and labels.
      hooks: Iterable of hooks appended to the worker hooks.
      saving_listeners: Listeners to attach to the first
        `CheckpointSaverHook`; may be empty/falsy.

    Returns:
      The loss value from the last training step that was run, or `None` if
      the session requested a stop before any step ran.

    Raises:
      ValueError: if `saving_listeners` are given but there is no
        `CheckpointSaverHook` to attach them to.
    """
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, model_fn_lib.ModeKeys.TRAIN))
      worker_hooks.extend(input_hooks)
      estimator_spec = self._call_model_fn(
          features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)

      if self._warm_start_settings:
        logging.info('Warm-starting with WarmStartSettings: %s' %
                     (self._warm_start_settings,))
        # pylint: disable=protected-access
        warm_starting_util._warm_start(self._warm_start_settings)
        # pylint: enable=protected-access
      # Check if the user created a loss summary, and add one if they didn't.
      # We assume here that the summary is called 'loss'. If it is not, we will
      # make another one with the name 'loss' to ensure it shows up in the right
      # graph in TensorBoard.
      if not any([x.op.name == 'loss'
                  for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
        summary.scalar('loss', estimator_spec.loss)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      # Hook order: input hooks, caller hooks, the built-in NaN/logging hooks,
      # then the hooks supplied by the EstimatorSpec.
      worker_hooks.extend(hooks)
      worker_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      worker_hooks.extend(estimator_spec.training_hooks)

      # Register a default Saver only when neither the scaffold nor the SAVERS
      # collection already provides one.
      if not (estimator_spec.scaffold.saver or
              ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True))

      chief_hooks = []
      # Look for an existing CheckpointSaverHook among worker and chief hooks.
      all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
      saver_hooks = [
          h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
      # Create a default CheckpointSaverHook only when checkpointing is
      # configured and no such hook was supplied.
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        if not saver_hooks:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=estimator_spec.scaffold)
          ]
          saver_hooks = [chief_hooks[0]]
      if saving_listeners:
        if not saver_hooks:
          raise ValueError(
              'There should be a CheckpointSaverHook to use saving_listeners. '
              'Please set one of the RunConfig.save_checkpoints_steps or '
              'RunConfig.save_checkpoints_secs.')
        else:
          # It is expected to have one CheckpointSaverHook. If multiple, we pick
          # up the first one to add listener.
          saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=estimator_spec.scaffold,
          hooks=worker_hooks,
          chief_only_hooks=(
              tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._session_config,
          log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
        # `loss` stays None if the session asks to stop before the first step.
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      return loss
    904 
    905   def _evaluate_model(self,
    906                       input_fn,
    907                       hooks=None,
    908                       checkpoint_path=None,
    909                       name=''):
    910     """Evaluates the model using the training.evaluation library."""
    911     # Check that model has been trained (if nothing has been set explicitly).
    912     if not checkpoint_path:
    913       latest_path = saver.latest_checkpoint(self._model_dir)
    914       if not latest_path:
    915         raise ValueError('Could not find trained model in model_dir: {}.'.
    916                          format(self._model_dir))
    917       checkpoint_path = latest_path
    918 
    919     # Setup output directory.
    920     eval_dir = os.path.join(self._model_dir, 'eval' if not name else
    921                             'eval_' + name)
    922 
    923     with ops.Graph().as_default() as g:
    924       random_seed.set_random_seed(self._config.tf_random_seed)
    925       global_step_tensor = self._create_and_assert_global_step(g)
    926       features, labels, input_hooks = (
    927           self._get_features_and_labels_from_input_fn(
    928               input_fn, model_fn_lib.ModeKeys.EVAL))
    929       estimator_spec = self._call_model_fn(
    930           features, labels, model_fn_lib.ModeKeys.EVAL, self.config)
    931 
    932       if model_fn_lib.LOSS_METRIC_KEY in estimator_spec.eval_metric_ops:
    933         raise ValueError(
    934             'Metric with name "%s" is not allowed, because Estimator ' % (
    935                 model_fn_lib.LOSS_METRIC_KEY) +
    936             'already defines a default metric with the same name.')
    937       estimator_spec.eval_metric_ops[
    938           model_fn_lib.LOSS_METRIC_KEY] = metrics_lib.mean(estimator_spec.loss)
    939 
    940       update_op, eval_dict = _extract_metric_update_ops(
    941           estimator_spec.eval_metric_ops)
    942 
    943       if ops.GraphKeys.GLOBAL_STEP in eval_dict:
    944         raise ValueError(
    945             'Metric with name `global_step` is not allowed, because Estimator '
    946             'already defines a default metric with the same name.')
    947       eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor
    948 
    949       all_hooks = list(input_hooks)
    950       all_hooks.extend(hooks)
    951       all_hooks.extend(list(estimator_spec.evaluation_hooks or []))
    952 
    953       eval_results = evaluation._evaluate_once(  # pylint: disable=protected-access
    954           checkpoint_path=checkpoint_path,
    955           master=self._config.evaluation_master,
    956           scaffold=estimator_spec.scaffold,
    957           eval_ops=update_op,
    958           final_ops=eval_dict,
    959           hooks=all_hooks,
    960           config=self._session_config)
    961 
    962       _write_dict_to_summary(
    963           output_dir=eval_dir,
    964           dictionary=eval_results,
    965           current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])
    966 
    967     return eval_results
    968 
    969 
    970 def _check_checkpoint_available(model_dir):
    971   latest_path = saver.latest_checkpoint(model_dir)
    972   if not latest_path:
    973     raise ValueError(
    974         'Could not find trained model in model_dir: {}.'.format(model_dir))
    975 
    976 
    977 def _check_hooks_type(hooks):
    978   """Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
    979   hooks = list(hooks or [])
    980   for h in hooks:
    981     if not isinstance(h, training.SessionRunHook):
    982       raise TypeError('Hooks must be a SessionRunHook, given: {}'.format(h))
    983   return hooks
    984 
    985 
    986 def _check_listeners_type(saving_listeners):
    987   """Check listeners type."""
    988   listeners = list(saving_listeners or [])
    989   for l in listeners:
    990     if not isinstance(l, training.CheckpointSaverListener):
    991       raise TypeError(
    992           'saving_listeners must be a list of CheckpointSaverListener, '
    993           'given: {}'.format(l))
    994   return listeners
    995 
    996 
    997 def _get_replica_device_setter(config):
    998   """Creates a replica device setter if required as a default device_fn.
    999 
   1000   `Estimator` uses ReplicaDeviceSetter as a default device placer. It sets the
   1001   distributed related arguments such as number of ps_replicas based on given
   1002   config.
   1003 
   1004   Args:
   1005     config: A `RunConfig` instance.
   1006 
   1007   Returns:
   1008     A replica device setter, or None.
   1009   """
   1010   ps_ops = [
   1011       'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
   1012       'MutableHashTableV2', 'MutableHashTableOfTensors',
   1013       'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
   1014       'MutableDenseHashTableV2', 'VarHandleOp'
   1015   ]
   1016 
   1017   if config.task_type:
   1018     worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
   1019   else:
   1020     worker_device = '/job:worker'
   1021 
   1022   if config.num_ps_replicas > 0:
   1023     return training.replica_device_setter(
   1024         ps_tasks=config.num_ps_replicas,
   1025         worker_device=worker_device,
   1026         merge_devices=True,
   1027         ps_ops=ps_ops,
   1028         cluster=config.cluster_spec)
   1029   else:
   1030     return None
   1031 
   1032 
   1033 def _verify_model_fn_args(model_fn, params):
   1034   """Verifies model fn arguments."""
   1035   args = set(util.fn_args(model_fn))
   1036   if 'features' not in args:
   1037     raise ValueError('model_fn (%s) must include features argument.' % model_fn)
   1038   if params is not None and 'params' not in args:
   1039     raise ValueError('model_fn (%s) does not include params argument, '
   1040                      'but params (%s) is passed to Estimator.' % (model_fn,
   1041                                                                   params))
   1042   if params is None and 'params' in args:
   1043     logging.warning('Estimator\'s model_fn (%s) includes params '
   1044                     'argument, but params are not passed to Estimator.',
   1045                     model_fn)
   1046   non_valid_args = list(args - _VALID_MODEL_FN_ARGS)
   1047   if non_valid_args:
   1048     raise ValueError('model_fn (%s) has following not expected args: %s' %
   1049                      (model_fn, non_valid_args))
   1050 
   1051 
   1052 def _load_global_step_from_checkpoint_dir(checkpoint_dir):
   1053   try:
   1054     checkpoint_reader = training.NewCheckpointReader(
   1055         training.latest_checkpoint(checkpoint_dir))
   1056     return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
   1057   except:  # pylint: disable=bare-except
   1058     return 0
   1059 
   1060 
   1061 def _extract_metric_update_ops(eval_dict):
   1062   """Separate update operations from metric value operations."""
   1063   update_ops = []
   1064   value_ops = {}
   1065   # Sort metrics lexicographically so graph is identical every time.
   1066   for name, metric_ops in sorted(six.iteritems(eval_dict)):
   1067     value_ops[name] = metric_ops[0]
   1068     update_ops.append(metric_ops[1])
   1069 
   1070   if update_ops:
   1071     update_op = control_flow_ops.group(*update_ops)
   1072   else:
   1073     update_op = None
   1074 
   1075   return update_op, value_ops
   1076 
   1077 
   1078 def _dict_to_str(dictionary):
   1079   """Get a `str` representation of a `dict`.
   1080 
   1081   Args:
   1082     dictionary: The `dict` to be represented as `str`.
   1083 
   1084   Returns:
   1085     A `str` representing the `dictionary`.
   1086   """
   1087   return ', '.join('%s = %s' % (k, v)
   1088                    for k, v in sorted(six.iteritems(dictionary)))
   1089 
   1090 
   1091 def _write_dict_to_summary(output_dir,
   1092                            dictionary,
   1093                            current_global_step):
   1094   """Writes a `dict` into summary file in given output directory.
   1095 
   1096   Args:
   1097     output_dir: `str`, directory to write the summary file in.
   1098     dictionary: the `dict` to be written to summary file.
   1099     current_global_step: `int`, the current global step.
   1100   """
   1101   logging.info('Saving dict for global step %d: %s', current_global_step,
   1102                _dict_to_str(dictionary))
   1103   summary_writer = writer_cache.FileWriterCache.get(output_dir)
   1104   summary_proto = summary_pb2.Summary()
   1105   for key in dictionary:
   1106     if dictionary[key] is None:
   1107       continue
   1108     if key == 'global_step':
   1109       continue
   1110     if (isinstance(dictionary[key], np.float32) or
   1111         isinstance(dictionary[key], float)):
   1112       summary_proto.value.add(tag=key, simple_value=float(dictionary[key]))
   1113     elif (isinstance(dictionary[key], np.int64) or
   1114           isinstance(dictionary[key], np.int32) or
   1115           isinstance(dictionary[key], int)):
   1116       summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))
   1117     elif isinstance(dictionary[key], six.binary_type):
   1118       try:
   1119         summ = summary_pb2.Summary.FromString(dictionary[key])
   1120         for i, _ in enumerate(summ.value):
   1121           summ.value[i].tag = key
   1122         summary_proto.value.extend(summ.value)
   1123       except message.DecodeError:
   1124         logging.warn('Skipping summary for %s, cannot parse string to Summary.',
   1125                      key)
   1126         continue
   1127     else:
   1128       logging.warn(
   1129           'Skipping summary for %s, must be a float, np.float32, np.int64, '
   1130           'np.int32 or int or a serialized string of Summary.', key)
   1131   summary_writer.add_summary(summary_proto, current_global_step)
   1132   summary_writer.flush()
   1133 
   1134 
   1135 def _has_dataset_or_queue_runner(maybe_tensor):
   1136   """Returns True if TF dataset or QueueRunner has been used."""
   1137   # Check TF dataset first. Here, we use a simple algorithm to check the top
   1138   # level Tensors only, which should be sufficient for most users.
   1139   tensors = [x for x in nest.flatten(maybe_tensor) if isinstance(x, ops.Tensor)]
   1140   if any([t.op.type == 'IteratorGetNext' for t in tensors]):
   1141     return True
   1142 
   1143   # Now, check queue.
   1144   return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)
   1145 
   1146 
   1147 class _DatasetInitializerHook(training.SessionRunHook):
   1148 
   1149   def __init__(self, iterator):
   1150     self._iterator = iterator
   1151 
   1152   def begin(self):
   1153     self._initializer = self._iterator.initializer
   1154 
   1155   def after_create_session(self, session, coord):
   1156     del coord
   1157     session.run(self._initializer)
   1158