1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Experiment class collecting information needed for a single training run.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import contextlib 22 import functools 23 import math 24 import os 25 import time 26 27 from tensorflow.contrib.framework import deprecated 28 from tensorflow.contrib.framework import deprecated_args 29 from tensorflow.contrib.framework.python.framework import experimental 30 from tensorflow.contrib.learn.python.learn import evaluable 31 from tensorflow.contrib.learn.python.learn import export_strategy 32 from tensorflow.contrib.learn.python.learn import monitors 33 from tensorflow.contrib.learn.python.learn import trainable 34 from tensorflow.contrib.learn.python.learn.estimators import run_config 35 from tensorflow.contrib.tpu.python.tpu import tpu_estimator 36 from tensorflow.python.estimator import estimator as core_estimator 37 from tensorflow.python.estimator import util as estimator_util 38 from tensorflow.python.framework import ops 39 from tensorflow.python.platform import tf_logging as logging 40 from tensorflow.python.training import basic_session_run_hooks 41 from tensorflow.python.training import saver 42 from tensorflow.python.training import server_lib 43 from tensorflow.python.util import compat 44 45 __all__ = ["Experiment"] 46 47 48 def _get_standardized_predicate_fn(predicate_fn): 49 pred_fn_args = estimator_util.fn_args(predicate_fn) 50 if "checkpoint_path" not in pred_fn_args: 51 # pylint: disable=unused-argument 52 def _pred_fn_wrapper(eval_results, checkpoint_path): 53 return predicate_fn(eval_results) 54 55 return _pred_fn_wrapper 56 else: 57 return predicate_fn 58 59 60 class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): 61 """Listener that evaluates and exports a model after creating a checkpoint. 62 63 The `EvalAndExportListener` waits for the associated `CheckpointSaverHook` 64 to save a checkpoint. It then uses the provided `eval_fn` and `export_fn` to 65 first evaluate the model using the newly-created checkpoint, and then export 66 the model according to the `export_strategies` provided in the `Experiment`. 67 68 This listener is experimental and may be changed or removed in the future. 69 """ 70 71 def __init__(self, eval_fn, export_fn, model_dir): 72 """Initializes an `EvalAndExportListener`. 73 74 Args: 75 eval_fn: function which evaluates the model with the following signature: 76 `(name, checkpoint_path) -> eval_result` 77 export_fn: function which exports the model according to a set of export 78 strategies. Has the following signature: 79 `(eval_result, checkpoint_path) -> export_results` 80 model_dir: directory which contains estimator parameters and checkpoints. 81 """ 82 self._eval_fn = eval_fn 83 self._export_fn = export_fn 84 self._model_dir = model_dir 85 self._latest_path = None 86 self._eval_result = None 87 self._export_results = None 88 89 def after_save(self, session, global_step_value): 90 """Evaluates and exports the model after a checkpoint is created.""" 91 # Load and cache the path of the most recent checkpoint to avoid duplicate 92 # searches on GCS. 93 logging.info("Checking for checkpoint in %s", self._model_dir) 94 latest_path = saver.latest_checkpoint(self._model_dir) 95 96 if not latest_path: 97 logging.warning("Skipping evaluation and export since model has not been " 98 "saved yet.") 99 elif latest_path == self._latest_path: 100 logging.warning("Skipping evaluation due to same latest checkpoint %s.", 101 latest_path) 102 else: 103 self._latest_path = latest_path 104 self._eval_result = self._eval_fn( 105 name="intermediate_export", checkpoint_path=latest_path) 106 self._export_results = self._export_fn( 107 self._eval_result, checkpoint_path=latest_path) 108 109 @property 110 def eval_result(self): 111 return self._eval_result 112 113 @property 114 def export_results(self): 115 return self._export_results 116 117 118 class Experiment(object): 119 """Experiment is a class containing all information needed to train a model. 120 121 After an experiment is created (by passing an Estimator and inputs for 122 training and evaluation), an Experiment instance knows how to invoke training 123 and eval loops in a sensible fashion for distributed training. 124 """ 125 126 # TODO(ispir): remove delay_workers_by_global_step and make global step based 127 # waiting as only behavior. 128 @deprecated_args( 129 "2016-10-23", 130 "local_eval_frequency is deprecated as local_run will be renamed to " 131 "train_and_evaluate. Use min_eval_frequency and call train_and_evaluate " 132 "instead. Note, however, that the default for min_eval_frequency is 1, " 133 "meaning models will be evaluated every time a new checkpoint is " 134 "available. In contrast, the default for local_eval_frequency is None, " 135 "resulting in evaluation occurring only after training has completed. " 136 "min_eval_frequency is ignored when calling the deprecated local_run.", 137 "local_eval_frequency") 138 def __init__(self, 139 estimator, 140 train_input_fn, 141 eval_input_fn, 142 eval_metrics=None, 143 train_steps=None, 144 eval_steps=100, 145 train_monitors=None, 146 eval_hooks=None, 147 local_eval_frequency=None, 148 eval_delay_secs=120, 149 continuous_eval_throttle_secs=60, 150 min_eval_frequency=None, 151 delay_workers_by_global_step=False, 152 export_strategies=None, 153 train_steps_per_iteration=None, 154 checkpoint_and_export=False, 155 saving_listeners=None): 156 """Constructor for `Experiment`. 157 158 Creates an Experiment instance. None of the functions passed to this 159 constructor are executed at construction time. They are stored and used 160 when a method is executed which requires it. 161 162 Args: 163 estimator: Object implementing Estimator interface, which could be a 164 combination of @{tf.contrib.learn.Trainable} and 165 @{tf.contrib.learn.Evaluable} (deprecated), or 166 @{tf.estimator.Estimator}. 167 train_input_fn: function, returns features and labels for training. 168 eval_input_fn: function, returns features and labels for evaluation. If 169 `eval_steps` is `None`, this should be configured only to produce for a 170 finite number of batches (generally, 1 epoch over the evaluation data). 171 eval_metrics: `dict` of string, metric function. If `None`, default set 172 is used. This should be `None` if the `estimator` is 173 @{tf.estimator.Estimator}. If metrics are provided they will be 174 *appended* to the default set. 175 train_steps: Perform this many steps of training. `None`, the default, 176 means train forever. 177 eval_steps: `evaluate` runs until input is exhausted (or another exception 178 is raised), or for `eval_steps` steps, if specified. 179 train_monitors: A list of monitors to pass to the `Estimator`'s `fit` 180 function. 181 eval_hooks: A list of `SessionRunHook` hooks to pass to the 182 `Estimator`'s `evaluate` function. 183 local_eval_frequency: (applies only to local_run) Frequency of running 184 eval in steps. If `None`, runs evaluation only at the end of training. 185 eval_delay_secs: Start evaluating after waiting for this many seconds. 186 continuous_eval_throttle_secs: Do not re-evaluate unless the last 187 evaluation was started at least this many seconds ago for 188 continuous_eval(). 189 min_eval_frequency: (applies only to train_and_evaluate). the minimum 190 number of steps between evaluations. Of course, evaluation does not 191 occur if no new snapshot is available, hence, this is the minimum. 192 If 0, the evaluation will only happen after training. 193 If None, defaults to 1, unless model_dir is on GCS, in which case the 194 default is 1000. 195 delay_workers_by_global_step: if `True` delays training workers 196 based on global step instead of time. 197 export_strategies: Iterable of `ExportStrategy`s, or a single one, or 198 `None`. 199 train_steps_per_iteration: (applies only to continuous_train_and_eval). 200 Perform this many (integer) number of train steps for each 201 training-evaluation iteration. With a small value, the model will be 202 evaluated more frequently with more checkpoints saved. If `None`, will 203 use a default value (which is smaller than `train_steps` if provided). 204 checkpoint_and_export: (applies only to train_and_evaluate). If `True`, 205 performs intermediate model checkpoints and exports during the training 206 process, rather than only once model training is complete. This 207 parameter is experimental and may be changed or removed in the future. 208 Setting this parameter leads to the following: the value of 209 `min_eval_frequency` will be ignored, and the number of steps between 210 evaluations and exports will instead be determined by the Estimator 211 configuration parameters `save_checkpoints_secs` and 212 `save_checkpoints_steps`. Also, this parameter leads to the creation of 213 a default `CheckpointSaverHook` instead of a `ValidationMonitor`, so the 214 provided `train_monitors` will need to be adjusted accordingly. 215 saving_listeners: list of `CheckpointSaverListener` objects. Used by 216 tf.estimator.Estimator for callbacks that run immediately before or 217 after checkpoint savings. 218 219 Raises: 220 ValueError: if `estimator` does not implement Estimator interface, 221 or if export_strategies has the wrong type. 222 """ 223 if isinstance(estimator, core_estimator.Estimator): 224 self._core_estimator_used = True 225 if eval_metrics is not None: 226 raise ValueError( 227 "`eval_metrics` must be `None` with `tf.estimator.Estimator`. " 228 "Use `eval_metric_ops` in `tf.estimator.EstimatorSpec` instead.") 229 else: 230 self._core_estimator_used = False 231 if not isinstance(estimator, evaluable.Evaluable): 232 raise ValueError( 233 "`estimator` must implement `tf.contrib.learn.Evaluable` " 234 "or `tf.estimator.Estimator`.") 235 if not isinstance(estimator, trainable.Trainable): 236 raise ValueError( 237 "`estimator` must implement `tf.contrib.learn.Trainable`" 238 "or `tf.estimator.`Estimator`.") 239 if saving_listeners is not None: 240 raise ValueError("`saving_listeners` must be `None` with " 241 "`tf.contrib.learn.Estimator`.") 242 243 if isinstance(estimator, tpu_estimator.TPUEstimator): 244 logging.warn( 245 "`Experiment` class cannot work with `tf.contrib.tpu.TPUEstimator`. " 246 "Please call `TPUEstimator` train/evaluate directly. \n" 247 "Details: `Experiment` class is designed for between-graph " 248 "distributed training, while `TPUEstimator` is working in in-graph " 249 "distributed mode. Use with care.") 250 251 super(Experiment, self).__init__() 252 # Immutable fields. 253 self._estimator = estimator 254 self._train_input_fn = train_input_fn 255 self._eval_input_fn = eval_input_fn 256 self._eval_metrics = eval_metrics 257 self._train_steps = train_steps 258 self._eval_steps = eval_steps 259 self._local_eval_frequency = local_eval_frequency 260 self._eval_delay_secs = eval_delay_secs 261 self._continuous_eval_throttle_secs = continuous_eval_throttle_secs 262 self._checkpoint_and_export = checkpoint_and_export 263 self._saving_listeners = saving_listeners 264 # Using 1 on a non-cached file system requires a lot of overhead to 265 # read the checkpoint state file. This is particular bad on GCS, so 266 # we use a different default. This is a temporary band-aid, to be 267 # fixed holistically later (b/36498507). 268 default_min_eval_frequency = 1000 if _is_gcs(estimator.model_dir) else 1 269 self._min_eval_frequency = min_eval_frequency if ( 270 min_eval_frequency is not None) else default_min_eval_frequency 271 self._delay_workers_by_global_step = delay_workers_by_global_step 272 self._train_monitors = train_monitors[:] if train_monitors else [] 273 self._eval_hooks = eval_hooks[:] if eval_hooks else [] 274 self._set_export_strategies(export_strategies) 275 276 self._train_steps_per_iteration = train_steps_per_iteration 277 if (self._train_steps_per_iteration is not None and 278 not isinstance(self._train_steps_per_iteration, int)): 279 raise ValueError("`train_steps_per_iteration` must be an integer.") 280 281 @property 282 def estimator(self): 283 return self._estimator 284 285 @property 286 def eval_metrics(self): 287 return self._eval_metrics 288 289 @property 290 def train_steps(self): 291 return self._train_steps 292 293 @property 294 def eval_steps(self): 295 return self._eval_steps 296 297 def _set_export_strategies(self, values): # pylint: disable=missing-docstring 298 export_strategies = [] 299 if values: 300 if isinstance(values, export_strategy.ExportStrategy): 301 export_strategies.append(values) 302 else: 303 for value in values: 304 if not isinstance(value, export_strategy.ExportStrategy): 305 raise ValueError("`export_strategies` must be an ExportStrategy," 306 " an iterable of ExportStrategy, or `None`," 307 " found %s." % value) 308 export_strategies.append(value) 309 self._export_strategies = tuple(export_strategies) 310 311 def extend_train_hooks(self, additional_hooks): 312 """Extends the hooks for training.""" 313 self._train_monitors.extend(additional_hooks) 314 315 def reset_export_strategies(self, new_export_strategies=None): 316 """Resets the export strategies with the `new_export_strategies`. 317 318 Args: 319 new_export_strategies: A new list of `ExportStrategy`s, or a single one, 320 or None. 321 322 Returns: 323 The old export strategies. 324 """ 325 old_export_strategies = self._export_strategies 326 self._set_export_strategies(new_export_strategies) 327 return old_export_strategies 328 329 def train(self, delay_secs=None): 330 """Fit the estimator using the training data. 331 332 Train the estimator for `self._train_steps` steps, after waiting for 333 `delay_secs` seconds. If `self._train_steps` is `None`, train forever. 334 335 Args: 336 delay_secs: Start training after this many seconds. 337 338 Returns: 339 The trained estimator. 340 """ 341 start = time.time() 342 343 # Start the server, if needed. It's important to start the server before 344 # we (optionally) sleep for the case where no device_filters are set. 345 # Otherwise, the servers will wait to connect to each other before starting 346 # to train. We might as well start as soon as we can. 347 config = self._estimator.config 348 if isinstance(config, run_config.RunConfig): 349 if (config.cluster_spec and config.master and 350 config.environment == run_config.Environment.LOCAL): 351 logging.warn("ClusterSpec and master are provided, but environment is " 352 "set to 'local'. Set environment to 'cloud' if you intend " 353 "to use the distributed runtime.") 354 if (config.environment != run_config.Environment.LOCAL and 355 config.environment != run_config.Environment.GOOGLE and 356 config.cluster_spec and config.master): 357 self._start_server() 358 elif config.cluster_spec and config.master: 359 raise ValueError( 360 "For distributed runtime, Experiment class only works with" 361 "tf.contrib.learn.RunConfig for now, but provided {}".format( 362 type(config))) 363 364 extra_hooks = [] 365 if delay_secs is None: 366 task_id = self._estimator.config.task_id or 0 367 if self._delay_workers_by_global_step: 368 # Wait 5500 global steps for the second worker. Each worker waits more 369 # then previous one but with a diminishing number of steps. 370 extra_hooks.append( 371 basic_session_run_hooks.GlobalStepWaiterHook( 372 int(8000.0 * math.log(task_id + 1)))) 373 delay_secs = 0 374 else: 375 # Wait 5 secs more for each new worker up to 60 secs. 376 delay_secs = min(60, task_id * 5) 377 378 if delay_secs > 0: 379 elapsed_secs = time.time() - start 380 remaining = delay_secs - elapsed_secs 381 logging.info("Waiting %d secs before starting training.", remaining) 382 time.sleep(delay_secs) 383 384 return self._call_train( 385 input_fn=self._train_input_fn, 386 max_steps=self._train_steps, 387 hooks=self._train_monitors + extra_hooks, 388 saving_listeners=self._saving_listeners) 389 390 def evaluate(self, delay_secs=None, name=None): 391 """Evaluate on the evaluation data. 392 393 Runs evaluation on the evaluation data and returns the result. Runs for 394 `self._eval_steps` steps, or if it's `None`, then run until input is 395 exhausted or another exception is raised. Start the evaluation after 396 `delay_secs` seconds, or if it's `None`, defaults to using 397 `self._eval_delay_secs` seconds. 398 399 Args: 400 delay_secs: Start evaluating after this many seconds. If `None`, defaults 401 to using `self._eval_delays_secs`. 402 name: Gives the name to the evauation for the case multiple evaluation is 403 run for the same experiment. 404 405 Returns: 406 The result of the `evaluate` call to the `Estimator`. 407 """ 408 if delay_secs is None: 409 delay_secs = self._eval_delay_secs 410 411 if delay_secs: 412 logging.info("Waiting %d secs before starting eval.", delay_secs) 413 time.sleep(delay_secs) 414 415 return self._call_evaluate( 416 input_fn=self._eval_input_fn, 417 steps=self._eval_steps, 418 metrics=self._eval_metrics, 419 name=(name or "one_pass"), 420 hooks=self._eval_hooks) 421 422 @deprecated( 423 "2016-10-23", 424 "local_run will be renamed to train_and_evaluate and the new default " 425 "behavior will be to run evaluation every time there is a new " 426 "checkpoint.") 427 def local_run(self): 428 with _new_attr_context(self, "_min_eval_frequency"): 429 self._min_eval_frequency = self._local_eval_frequency 430 return self.train_and_evaluate() 431 432 # TODO(xiejw): Allow continuous_eval_predicate_fn to be passed via constructor 433 # once stopping all jobs is implemented. 434 def _continuous_eval(self, 435 input_fn, 436 name, 437 delay_secs, 438 throttle_delay_secs, 439 evaluate_checkpoint_only_once=True, 440 continuous_eval_predicate_fn=None, 441 export=True): 442 """Run continuous eval. 443 444 Runs infinite eval on the evaluation data set. This function starts 445 evaluating after `delay_secs` seconds and then runs no more than one 446 evaluation (with `self._eval_steps` steps each time) per 447 `throttle_delay_secs`. If `train_steps` is not None, will return after 448 global_step reaches `train_steps`. 449 450 Args: 451 input_fn: The input to use for this eval. 452 name: A string appended to the folder name of evaluation results. 453 delay_secs: Start evaluating after this many seconds. If None, defaults to 454 self._eval_delay_secs. 455 throttle_delay_secs: Do not re-evaluate unless the last evaluation was 456 started at least this many seconds ago. If None, defaults to 457 self._continuous_eval_throttle_secs. 458 evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints 459 that have already been evaluated. Default is `True`. 460 continuous_eval_predicate_fn: A predicate function determining whether to 461 continue eval after each iteration. A `predicate_fn` has one of the 462 following signatures: 463 * (eval_results) -> boolean 464 * (eval_results, checkpoint_path) -> boolean 465 Where `eval_results` is the dictionary of metric evaluations and 466 checkpoint_path is the path to the checkpoint containing the parameters 467 on which that evaluation was based. 468 At the beginning of evaluation, the passed `eval_results` will be None 469 so it's expected that the predicate function handles that gracefully. 470 When `predicate_fn` is not specified, continuous eval will run in an 471 infinite loop (if `train_steps` is None). or exit once global step 472 reaches `train_steps`. 473 474 export: Whether to export from this step. Default is 'True'. 475 476 Raises: 477 ValueError: if `continuous_eval_predicate_fn` is neither None nor 478 callable. 479 """ 480 if continuous_eval_predicate_fn is not None: 481 if not callable(continuous_eval_predicate_fn): 482 raise ValueError( 483 "`continuous_eval_predicate_fn` must be a callable, or None.") 484 predicate_fn = _get_standardized_predicate_fn( 485 continuous_eval_predicate_fn) 486 else: 487 predicate_fn = None 488 489 if delay_secs is None: 490 delay_secs = self._eval_delay_secs 491 if throttle_delay_secs is None: 492 throttle_delay_secs = self._continuous_eval_throttle_secs 493 494 if delay_secs: 495 logging.info("Waiting %f secs before starting eval.", delay_secs) 496 time.sleep(delay_secs) 497 498 previous_path = None 499 eval_result = None 500 last_warning_time = 0 501 while (not predicate_fn or predicate_fn( 502 eval_result, checkpoint_path=previous_path if eval_result else None)): 503 # Exit if we have already reached number of steps to train. 504 if self._has_training_stopped(eval_result): 505 logging.info("Exiting continuous eval, global_step=%s >= " 506 "train_step=%s", eval_result[ops.GraphKeys.GLOBAL_STEP], 507 self._train_steps) 508 return 509 510 start = time.time() 511 512 error_msg = None 513 latest_path = saver.latest_checkpoint(self._estimator.model_dir) 514 if not latest_path: 515 error_msg = ("Estimator is not fitted yet. " 516 "Will start an evaluation when a checkpoint is ready.") 517 elif evaluate_checkpoint_only_once and latest_path == previous_path: 518 error_msg = "No new checkpoint ready for evaluation." 519 520 if error_msg: 521 # Print warning message every 10 mins. 522 eval_result = {} 523 if time.time() - last_warning_time > 600: 524 logging.warning(error_msg) 525 last_warning_time = time.time() 526 else: 527 eval_result = self._call_evaluate( 528 input_fn=input_fn, 529 steps=self._eval_steps, 530 metrics=self._eval_metrics, 531 name=name, 532 checkpoint_path=latest_path, 533 hooks=self._eval_hooks) 534 # Ensure eval result is not None for next round of evaluation. 535 if not eval_result: 536 eval_result = {} 537 538 if export: 539 self._maybe_export(eval_result, checkpoint_path=latest_path) 540 541 # Clear warning timer and update last evaluated checkpoint 542 last_warning_time = 0 543 previous_path = latest_path 544 545 duration = time.time() - start 546 if duration < throttle_delay_secs: 547 difference = throttle_delay_secs - duration 548 logging.info("Waiting %f secs before starting next eval run.", 549 difference) 550 time.sleep(difference) 551 552 def _has_training_stopped(self, eval_result): 553 """Determines whether the training has stopped.""" 554 if not eval_result: 555 return False 556 557 global_step = eval_result.get(ops.GraphKeys.GLOBAL_STEP) 558 return global_step and self._train_steps and (global_step >= 559 self._train_steps) 560 561 def continuous_eval(self, 562 delay_secs=None, 563 throttle_delay_secs=None, 564 evaluate_checkpoint_only_once=True, 565 continuous_eval_predicate_fn=None, 566 name="continuous"): 567 self._continuous_eval( 568 self._eval_input_fn, 569 name=name, 570 delay_secs=delay_secs, 571 throttle_delay_secs=throttle_delay_secs, 572 evaluate_checkpoint_only_once=evaluate_checkpoint_only_once, 573 continuous_eval_predicate_fn=continuous_eval_predicate_fn) 574 575 def continuous_eval_on_train_data(self, 576 delay_secs=None, 577 throttle_delay_secs=None, 578 continuous_eval_predicate_fn=None, 579 name="continuous_on_train_data"): 580 self._continuous_eval( 581 self._train_input_fn, 582 name=name, 583 delay_secs=delay_secs, 584 throttle_delay_secs=throttle_delay_secs, 585 continuous_eval_predicate_fn=continuous_eval_predicate_fn, 586 export=False) 587 588 def train_and_evaluate(self): 589 """Interleaves training and evaluation. 590 591 The frequency of evaluation is controlled by the constructor arg 592 `min_eval_frequency`. When this parameter is 0, evaluation happens 593 only after training has completed. Note that evaluation cannot happen 594 more frequently than checkpoints are taken. If no new snapshots are 595 available when evaluation is supposed to occur, then evaluation doesn't 596 happen for another `min_eval_frequency` steps (assuming a checkpoint is 597 available at that point). Thus, settings `min_eval_frequency` to 1 means 598 that the model will be evaluated everytime there is a new checkpoint. 599 600 This is particular useful for a "Master" task in the cloud, whose 601 responsibility it is to take checkpoints, evaluate those checkpoints, 602 and write out summaries. Participating in training as the supervisor 603 allows such a task to accomplish the first and last items, while 604 performing evaluation allows for the second. 605 606 Returns: 607 The result of the `evaluate` call to the `Estimator` as well as the 608 export results using the specified `ExportStrategy`. 609 """ 610 # The directory to which evaluation summaries are written are determined 611 # by adding a suffix to 'eval'; that suffix is the 'name' parameter to 612 # the various evaluate(...) methods. By setting it to None, we force 613 # the directory name to simply be 'eval'. 614 eval_dir_suffix = None 615 616 # We set every_n_steps to 1, but evaluation only occurs when a new 617 # snapshot is available. If, by the time we finish evaluation 618 # there is a new snapshot, then we just evaluate again. Otherwise, 619 # we keep training until one becomes available. 620 with _new_attr_context(self, "_train_monitors"): 621 self._train_monitors = self._train_monitors or [] 622 config = self._estimator.config 623 intermediate_export = self._checkpoint_and_export and ( 624 config.save_checkpoints_secs or config.save_checkpoints_steps) 625 if intermediate_export: 626 # Create a partially specified evaluate function with the desired 627 # arguments. This will be executed by the _EvalAndExportListener, 628 # which will specify the latest checkpoint path. 629 eval_fn = functools.partial( 630 self._call_evaluate, 631 input_fn=self._eval_input_fn, 632 steps=self._eval_steps, 633 metrics=self._eval_metrics, 634 hooks=self._eval_hooks) 635 636 export_listener = _EvalAndExportListener( 637 eval_fn=eval_fn, 638 export_fn=self._maybe_export, 639 model_dir=self._estimator.model_dir) 640 641 saver_hook = basic_session_run_hooks.CheckpointSaverHook( 642 checkpoint_dir=self._estimator.model_dir, 643 save_secs=config.save_checkpoints_secs, 644 save_steps=config.save_checkpoints_steps, 645 listeners=[export_listener]) 646 self._train_monitors += [saver_hook] 647 else: 648 if self._min_eval_frequency: 649 self._train_monitors += [ 650 monitors.ValidationMonitor( 651 input_fn=self._eval_input_fn, 652 eval_steps=self._eval_steps, 653 metrics=self._eval_metrics, 654 every_n_steps=self._min_eval_frequency, 655 name=eval_dir_suffix, 656 hooks=self._eval_hooks) 657 ] 658 self.train(delay_secs=0) 659 660 # If the checkpoint_and_export flag and appropriate estimator configuration 661 # parameters are set, then model evaluations and exports are done during the 662 # training process. In particular, this will always occur at the end of 663 # training, so we return the most recent results to avoid performing a 664 # duplicate evaluation and model export. 665 if intermediate_export: 666 return export_listener.eval_result, export_listener.export_results 667 else: 668 eval_result = self._call_evaluate( 669 input_fn=self._eval_input_fn, 670 steps=self._eval_steps, 671 metrics=self._eval_metrics, 672 name=eval_dir_suffix, 673 hooks=self._eval_hooks) 674 export_results = self._maybe_export(eval_result) 675 return eval_result, export_results 676 677 @experimental 678 def continuous_train_and_eval(self, continuous_eval_predicate_fn=None): 679 """Interleaves training and evaluation. 680 681 The frequency of evaluation is controlled by the `train_steps_per_iteration` 682 (via constructor). The model will be first trained for 683 `train_steps_per_iteration`, and then be evaluated in turns. 684 685 This method is intended for single machine usage. 686 687 This differs from `train_and_evaluate` as follows: 688 689 1. The procedure will have train and evaluation in turns. The model 690 will be trained for a number of steps (usually smaller than `train_steps` 691 if provided) and then be evaluated. `train_and_evaluate` will train the 692 model for `train_steps` (no small training iterations). 693 694 2. Due to the different approach this schedule takes, it leads to two 695 differences in resource control. First, the resources (e.g., memory) used 696 by training will be released before evaluation (`train_and_evaluate` takes 697 double resources). Second, more checkpoints will be saved as a checkpoint 698 is generated at the end of each training iteration. 699 700 3. As the estimator.train starts from scratch (new graph, new states for 701 input, etc) at each iteration, it is recommended to have the 702 `train_steps_per_iteration` larger. It is also recommended to shuffle your 703 input. 704 705 Args: 706 continuous_eval_predicate_fn: A predicate function determining whether to 707 continue eval after each iteration. A `predicate_fn` has one of the 708 following signatures: 709 * (eval_results) -> boolean 710 * (eval_results, checkpoint_path) -> boolean 711 Where `eval_results` is the dictionary of metric evaluations and 712 checkpoint_path is the path to the checkpoint containing the parameters 713 on which that evaluation was based. 714 At the beginning of evaluation, the passed `eval_results` and 715 `checkpoint_path` will be None so it's expected that the predicate 716 function handles that gracefully. 717 When `predicate_fn` is not specified, continuous eval will run in an 718 infinite loop (if `train_steps` is None). or exit once global step 719 reaches `train_steps`. 720 721 Returns: 722 A tuple of the result of the `evaluate` call to the `Estimator` and the 723 export results using the specified `ExportStrategy`. 724 725 Raises: 726 ValueError: if `continuous_eval_predicate_fn` is neither None nor 727 callable. 728 """ 729 730 if continuous_eval_predicate_fn is not None: 731 if not callable(continuous_eval_predicate_fn): 732 raise ValueError( 733 "`continuous_eval_predicate_fn` must be a callable, or None.") 734 predicate_fn = _get_standardized_predicate_fn( 735 continuous_eval_predicate_fn) 736 else: 737 predicate_fn = None 738 739 export_results = None 740 latest_checkpoint = None 741 eval_result = None 742 743 # Set the default value for train_steps_per_iteration, which will be 744 # overridden by other settings. 745 train_steps_per_iteration = 1000 746 if self._train_steps_per_iteration is not None: 747 train_steps_per_iteration = self._train_steps_per_iteration 748 elif self._train_steps is not None: 749 train_steps_per_iteration = int(self._train_steps / 10) 750 751 while (not predicate_fn or predicate_fn( 752 eval_result, checkpoint_path=latest_checkpoint 753 if eval_result else None)): 754 755 if self._has_training_stopped(eval_result): 756 # Exits once max steps of training is satisfied. 757 logging.info("Stop training model as max steps reached") 758 break 759 760 logging.info("Training model for %s steps", train_steps_per_iteration) 761 self._call_train( 762 input_fn=self._train_input_fn, 763 steps=train_steps_per_iteration, 764 hooks=self._train_monitors, 765 saving_listeners=self._saving_listeners) 766 767 logging.info("Evaluating model now.") 768 latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir) 769 eval_result = self._call_evaluate( 770 input_fn=self._eval_input_fn, 771 steps=self._eval_steps, 772 metrics=self._eval_metrics, 773 name="one_pass", 774 checkpoint_path=latest_checkpoint, 775 hooks=self._eval_hooks) 776 export_results = self._maybe_export(eval_result) 777 778 return eval_result, export_results 779 780 def _maybe_export(self, eval_result, checkpoint_path=None): 781 """Export the Estimator using export_fn, if defined.""" 782 export_dir_base = os.path.join( 783 compat.as_bytes(self._estimator.model_dir), compat.as_bytes("export")) 784 785 export_results = [] 786 for strategy in self._export_strategies: 787 export_results.append( 788 strategy.export( 789 self._estimator, 790 os.path.join( 791 compat.as_bytes(export_dir_base), 792 compat.as_bytes(strategy.name)), 793 checkpoint_path=checkpoint_path, 794 eval_result=eval_result)) 795 796 return export_results 797 798 def run_std_server(self): 799 """Starts a TensorFlow server and joins the serving thread. 800 801 Typically used for parameter servers. 802 803 Raises: 804 ValueError: if not enough information is available in the estimator's 805 config to create a server. 806 """ 807 self._start_server().join() 808 809 def test(self): 810 """Tests training, evaluating and exporting the estimator for a single step. 811 812 Returns: 813 The result of the `evaluate` call to the `Estimator`. 814 """ 815 self._call_train( 816 input_fn=self._train_input_fn, 817 steps=1, 818 hooks=self._train_monitors, 819 saving_listeners=self._saving_listeners) 820 821 eval_result = self._call_evaluate( 822 input_fn=self._eval_input_fn, 823 steps=1, 824 metrics=self._eval_metrics, 825 name="one_pass") 826 _ = self._maybe_export(eval_result) 827 828 return eval_result 829 830 def _start_server(self): 831 """Creates, starts, and returns a server_lib.Server.""" 832 config = self._estimator.config 833 if (not config.cluster_spec or not config.task_type or not config.master or 834 config.task_id is None): 835 raise ValueError("Could not start server; be sure to specify " 836 "cluster_spec, task_type, master, and task in " 837 "RunConfig or set the TF_CONFIG environment variable.") 838 server = server_lib.Server( 839 config.cluster_spec, 840 job_name=config.task_type, 841 task_index=config.task_id, 842 config=config.tf_config, 843 start=False) 844 server.start() 845 return server 846 847 def _call_train( 848 self, 849 _sentinel=None, # pylint: disable=invalid-name, 850 input_fn=None, 851 steps=None, 852 hooks=None, 853 max_steps=None, 854 saving_listeners=None): 855 if _sentinel is not None: 856 raise ValueError("_call_train should be called with keyword args only") 857 858 # Estimator in core cannot work with monitors. We need to convert them 859 # to hooks. For Estimator in contrib, it is converted internally. So, it is 860 # safe to convert for both cases. 861 hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator) 862 if self._core_estimator_used: 863 return self._estimator.train( 864 input_fn=input_fn, 865 steps=steps, 866 max_steps=max_steps, 867 hooks=hooks, 868 saving_listeners=saving_listeners) 869 else: 870 return self._estimator.fit( 871 input_fn=input_fn, steps=steps, max_steps=max_steps, monitors=hooks) 872 873 def _call_evaluate( 874 self, 875 _sentinel=None, # pylint: disable=invalid-name, 876 input_fn=None, 877 steps=None, 878 metrics=None, 879 name=None, 880 checkpoint_path=None, 881 hooks=None): 882 if _sentinel is not None: 883 raise ValueError("_call_evaluate should be called with keyword args only") 884 885 if self._core_estimator_used: 886 if metrics is not None: 887 raise ValueError( 888 "`eval_metrics` must be `None` with `tf.estimator.Estimator`") 889 return self._estimator.evaluate( 890 input_fn=input_fn, 891 steps=steps, 892 name=name, 893 checkpoint_path=checkpoint_path, 894 hooks=hooks) 895 else: 896 return self._estimator.evaluate( 897 input_fn=input_fn, 898 steps=steps, 899 metrics=metrics, 900 name=name, 901 checkpoint_path=checkpoint_path, 902 hooks=hooks) 903 904 905 @contextlib.contextmanager 906 def _new_attr_context(obj, attr): 907 """Creates a new context in which an object's attribute can be changed. 908 909 This creates a context in which an object's attribute can be changed. 910 Once the context is exited, the attribute reverts to its original value. 911 912 Args: 913 obj: An object whose attribute to restore at the end of the context. 914 attr: An attribute to remember and restore at the end of the context. 915 916 Yields: 917 Context. 918 919 Example: 920 my_obj.x = 1 921 with _new_attr_context(my_obj, "x"): 922 my_obj.x = 2 923 print(my_obj.x) 924 print(my_obj.x) 925 """ 926 saved = getattr(obj, attr) 927 try: 928 yield 929 finally: 930 setattr(obj, attr, saved) 931 932 933 def _is_gcs(model_dir): 934 return model_dir and model_dir.startswith("gs://") 935