Home | History | Annotate | Download | only in training
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 # http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Contains functions for evaluation and summarization of metrics."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import time
     22 import math
     23 
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import init_ops
     28 from tensorflow.python.ops import state_ops
     29 from tensorflow.python.ops import variable_scope
     30 from tensorflow.python.platform import tf_logging as logging
     31 from tensorflow.python.training import basic_session_run_hooks
     32 from tensorflow.python.training import monitored_session
     33 from tensorflow.python.training import session_run_hook
     34 
     35 
     36 def _get_or_create_eval_step():
     37   """Gets or creates the eval step `Tensor`.
     38 
     39   Returns:
     40     A `Tensor` representing a counter for the evaluation step.
     41 
     42   Raises:
     43     ValueError: If multiple `Tensors` have been added to the
     44       `tf.GraphKeys.EVAL_STEP` collection.
     45   """
     46   graph = ops.get_default_graph()
     47   eval_steps = graph.get_collection(ops.GraphKeys.EVAL_STEP)
     48   if len(eval_steps) == 1:
     49     return eval_steps[0]
     50   elif len(eval_steps) > 1:
     51     raise ValueError('Multiple tensors added to tf.GraphKeys.EVAL_STEP')
     52   else:
     53     counter = variable_scope.get_variable(
     54         'eval_step',
     55         shape=[],
     56         dtype=dtypes.int64,
     57         initializer=init_ops.zeros_initializer(),
     58         trainable=False,
     59         collections=[ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.EVAL_STEP])
     60     return counter
     61 
     62 
     63 def _get_latest_eval_step_value(update_ops):
     64   """Gets the eval step `Tensor` value after running `update_ops`.
     65 
     66   Args:
     67     update_ops: A list of `Tensors` or a dictionary of names to `Tensors`,
     68         which are run before reading the eval step value.
     69 
     70   Returns:
     71     A `Tensor` representing the value for the evaluation step.
     72   """
     73   if isinstance(update_ops, dict):
     74     update_ops = list(update_ops.values())
     75 
     76   with ops.control_dependencies(update_ops):
     77     return array_ops.identity(_get_or_create_eval_step().read_value())
     78 
     79 
     80 class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
     81   """Run hook used by the evaluation routines to run the `eval_ops` N times."""
     82 
     83   def __init__(self, num_evals, log_progress=True):
     84     """Constructs the run hook.
     85 
     86     Args:
     87       num_evals: The number of evaluations to run for. if set to None, will
     88         iterate the dataset until all inputs are exhausted.
     89       log_progress: Whether to log evaluation progress, defaults to True.
     90     """
     91     # The number of evals to run for.
     92     self._num_evals = num_evals
     93     self._evals_completed = None
     94     self._log_progress = log_progress
     95     # Reduce logging frequency if there are 20 or more evaluations.
     96     self._log_frequency = (1 if (num_evals is None or num_evals < 20)
     97                            else math.floor(num_evals / 10.))
     98 
     99   def _set_evals_completed_tensor(self, updated_eval_step):
    100     self._evals_completed = updated_eval_step
    101 
    102   def before_run(self, run_context):
    103     return session_run_hook.SessionRunArgs({
    104         'evals_completed': self._evals_completed
    105     })
    106 
    107   def after_run(self, run_context, run_values):
    108     evals_completed = run_values.results['evals_completed']
    109     if self._log_progress:
    110       if self._num_evals is None:
    111         logging.info('Evaluation [%d]', evals_completed)
    112       else:
    113         if ((evals_completed % self._log_frequency) == 0 or
    114             (self._num_evals == evals_completed)):
    115           logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
    116     if self._num_evals is not None and evals_completed >= self._num_evals:
    117       run_context.request_stop()
    118 
    119 
    120 def _evaluate_once(checkpoint_path,
    121                    master='',
    122                    scaffold=None,
    123                    eval_ops=None,
    124                    feed_dict=None,
    125                    final_ops=None,
    126                    final_ops_feed_dict=None,
    127                    hooks=None,
    128                    config=None):
    129   """Evaluates the model at the given checkpoint path.
    130 
    131   During a single evaluation, the `eval_ops` is run until the session is
    132   interrupted or requested to finish. This is typically requested via a
    133   `tf.contrib.training.StopAfterNEvalsHook` which results in `eval_ops` running
    134   the requested number of times.
    135 
    136   Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
    137   `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is
    138   evaluated a single time after `eval_ops` has finished running and the fetched
    139   values of `final_ops` are returned. If `final_ops` is left as `None`, then
    140   `None` is returned.
    141 
    142   One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record
    143   summaries after the `eval_ops` have run. If `eval_ops` is `None`, the
    144   summaries run immediately after the model checkpoint has been restored.
    145 
    146   Note that `evaluate_once` creates a local variable used to track the number of
    147   evaluations run via `tf.contrib.training.get_or_create_eval_step`.
    148   Consequently, if a custom local init op is provided via a `scaffold`, the
    149   caller should ensure that the local init op also initializes the eval step.
    150 
    151   Args:
    152     checkpoint_path: The path to a checkpoint to use for evaluation.
    153     master: The BNS address of the TensorFlow master.
    154     scaffold: An tf.train.Scaffold instance for initializing variables and
    155       restoring variables. Note that `scaffold.init_fn` is used by the function
    156       to restore the checkpoint. If you supply a custom init_fn, then it must
    157       also take care of restoring the model from its checkpoint.
    158     eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
    159       to `Tensors`, which is run until the session is requested to stop,
    160       commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
    161     feed_dict: The feed dictionary to use when executing the `eval_ops`.
    162     final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
    163       to `Tensors`.
    164     final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
    165     hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
    166       evaluation loop.
    167     config: An instance of `tf.ConfigProto` that will be used to
    168       configure the `Session`. If left as `None`, the default will be used.
    169 
    170   Returns:
    171     The fetched values of `final_ops` or `None` if `final_ops` is `None`.
    172   """
    173   eval_step = _get_or_create_eval_step()
    174 
    175   # Prepare the run hooks.
    176   hooks = list(hooks or [])
    177 
    178   if eval_ops is not None:
    179     update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)
    180 
    181     if isinstance(eval_ops, dict):
    182       eval_ops['update_eval_step'] = update_eval_step
    183     elif isinstance(eval_ops, (tuple, list)):
    184       eval_ops = list(eval_ops) + [update_eval_step]
    185     else:
    186       eval_ops = [eval_ops, update_eval_step]
    187 
    188     eval_step_value = _get_latest_eval_step_value(eval_ops)
    189 
    190     for h in hooks:
    191       if isinstance(h, _StopAfterNEvalsHook):
    192         h._set_evals_completed_tensor(eval_step_value)  # pylint: disable=protected-access
    193 
    194   logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
    195                                                          time.gmtime()))
    196 
    197   # Prepare the session creator.
    198   session_creator = monitored_session.ChiefSessionCreator(
    199       scaffold=scaffold,
    200       checkpoint_filename_with_path=checkpoint_path,
    201       master=master,
    202       config=config)
    203 
    204   final_ops_hook = basic_session_run_hooks.FinalOpsHook(
    205       final_ops, final_ops_feed_dict)
    206   hooks.append(final_ops_hook)
    207 
    208   with monitored_session.MonitoredSession(
    209       session_creator=session_creator, hooks=hooks) as session:
    210     if eval_ops is not None:
    211       while not session.should_stop():
    212         session.run(eval_ops, feed_dict)
    213 
    214   logging.info('Finished evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
    215                                                          time.gmtime()))
    216   return final_ops_hook.final_ops_values
    217