1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Contains functions for evaluation and summarization of metrics.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import time 22 import math 23 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.framework import ops 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import init_ops 28 from tensorflow.python.ops import state_ops 29 from tensorflow.python.ops import variable_scope 30 from tensorflow.python.platform import tf_logging as logging 31 from tensorflow.python.training import basic_session_run_hooks 32 from tensorflow.python.training import monitored_session 33 from tensorflow.python.training import session_run_hook 34 35 36 def _get_or_create_eval_step(): 37 """Gets or creates the eval step `Tensor`. 38 39 Returns: 40 A `Tensor` representing a counter for the evaluation step. 41 42 Raises: 43 ValueError: If multiple `Tensors` have been added to the 44 `tf.GraphKeys.EVAL_STEP` collection. 45 """ 46 graph = ops.get_default_graph() 47 eval_steps = graph.get_collection(ops.GraphKeys.EVAL_STEP) 48 if len(eval_steps) == 1: 49 return eval_steps[0] 50 elif len(eval_steps) > 1: 51 raise ValueError('Multiple tensors added to tf.GraphKeys.EVAL_STEP') 52 else: 53 counter = variable_scope.get_variable( 54 'eval_step', 55 shape=[], 56 dtype=dtypes.int64, 57 initializer=init_ops.zeros_initializer(), 58 trainable=False, 59 collections=[ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.EVAL_STEP]) 60 return counter 61 62 63 def _get_latest_eval_step_value(update_ops): 64 """Gets the eval step `Tensor` value after running `update_ops`. 65 66 Args: 67 update_ops: A list of `Tensors` or a dictionary of names to `Tensors`, 68 which are run before reading the eval step value. 69 70 Returns: 71 A `Tensor` representing the value for the evaluation step. 72 """ 73 if isinstance(update_ops, dict): 74 update_ops = list(update_ops.values()) 75 76 with ops.control_dependencies(update_ops): 77 return array_ops.identity(_get_or_create_eval_step().read_value()) 78 79 80 class _StopAfterNEvalsHook(session_run_hook.SessionRunHook): 81 """Run hook used by the evaluation routines to run the `eval_ops` N times.""" 82 83 def __init__(self, num_evals, log_progress=True): 84 """Constructs the run hook. 85 86 Args: 87 num_evals: The number of evaluations to run for. if set to None, will 88 iterate the dataset until all inputs are exhausted. 89 log_progress: Whether to log evaluation progress, defaults to True. 90 """ 91 # The number of evals to run for. 92 self._num_evals = num_evals 93 self._evals_completed = None 94 self._log_progress = log_progress 95 # Reduce logging frequency if there are 20 or more evaluations. 96 self._log_frequency = (1 if (num_evals is None or num_evals < 20) 97 else math.floor(num_evals / 10.)) 98 99 def _set_evals_completed_tensor(self, updated_eval_step): 100 self._evals_completed = updated_eval_step 101 102 def before_run(self, run_context): 103 return session_run_hook.SessionRunArgs({ 104 'evals_completed': self._evals_completed 105 }) 106 107 def after_run(self, run_context, run_values): 108 evals_completed = run_values.results['evals_completed'] 109 if self._log_progress: 110 if self._num_evals is None: 111 logging.info('Evaluation [%d]', evals_completed) 112 else: 113 if ((evals_completed % self._log_frequency) == 0 or 114 (self._num_evals == evals_completed)): 115 logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals) 116 if self._num_evals is not None and evals_completed >= self._num_evals: 117 run_context.request_stop() 118 119 120 def _evaluate_once(checkpoint_path, 121 master='', 122 scaffold=None, 123 eval_ops=None, 124 feed_dict=None, 125 final_ops=None, 126 final_ops_feed_dict=None, 127 hooks=None, 128 config=None): 129 """Evaluates the model at the given checkpoint path. 130 131 During a single evaluation, the `eval_ops` is run until the session is 132 interrupted or requested to finish. This is typically requested via a 133 `tf.contrib.training.StopAfterNEvalsHook` which results in `eval_ops` running 134 the requested number of times. 135 136 Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of 137 `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is 138 evaluated a single time after `eval_ops` has finished running and the fetched 139 values of `final_ops` are returned. If `final_ops` is left as `None`, then 140 `None` is returned. 141 142 One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record 143 summaries after the `eval_ops` have run. If `eval_ops` is `None`, the 144 summaries run immediately after the model checkpoint has been restored. 145 146 Note that `evaluate_once` creates a local variable used to track the number of 147 evaluations run via `tf.contrib.training.get_or_create_eval_step`. 148 Consequently, if a custom local init op is provided via a `scaffold`, the 149 caller should ensure that the local init op also initializes the eval step. 150 151 Args: 152 checkpoint_path: The path to a checkpoint to use for evaluation. 153 master: The BNS address of the TensorFlow master. 154 scaffold: An tf.train.Scaffold instance for initializing variables and 155 restoring variables. Note that `scaffold.init_fn` is used by the function 156 to restore the checkpoint. If you supply a custom init_fn, then it must 157 also take care of restoring the model from its checkpoint. 158 eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names 159 to `Tensors`, which is run until the session is requested to stop, 160 commonly done by a `tf.contrib.training.StopAfterNEvalsHook`. 161 feed_dict: The feed dictionary to use when executing the `eval_ops`. 162 final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names 163 to `Tensors`. 164 final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`. 165 hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the 166 evaluation loop. 167 config: An instance of `tf.ConfigProto` that will be used to 168 configure the `Session`. If left as `None`, the default will be used. 169 170 Returns: 171 The fetched values of `final_ops` or `None` if `final_ops` is `None`. 172 """ 173 eval_step = _get_or_create_eval_step() 174 175 # Prepare the run hooks. 176 hooks = list(hooks or []) 177 178 if eval_ops is not None: 179 update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True) 180 181 if isinstance(eval_ops, dict): 182 eval_ops['update_eval_step'] = update_eval_step 183 elif isinstance(eval_ops, (tuple, list)): 184 eval_ops = list(eval_ops) + [update_eval_step] 185 else: 186 eval_ops = [eval_ops, update_eval_step] 187 188 eval_step_value = _get_latest_eval_step_value(eval_ops) 189 190 for h in hooks: 191 if isinstance(h, _StopAfterNEvalsHook): 192 h._set_evals_completed_tensor(eval_step_value) # pylint: disable=protected-access 193 194 logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S', 195 time.gmtime())) 196 197 # Prepare the session creator. 198 session_creator = monitored_session.ChiefSessionCreator( 199 scaffold=scaffold, 200 checkpoint_filename_with_path=checkpoint_path, 201 master=master, 202 config=config) 203 204 final_ops_hook = basic_session_run_hooks.FinalOpsHook( 205 final_ops, final_ops_feed_dict) 206 hooks.append(final_ops_hook) 207 208 with monitored_session.MonitoredSession( 209 session_creator=session_creator, hooks=hooks) as session: 210 if eval_ops is not None: 211 while not session.should_stop(): 212 session.run(eval_ops, feed_dict) 213 214 logging.info('Finished evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S', 215 time.gmtime())) 216 return final_ops_hook.final_ops_values 217