# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPUEstimator class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import os
import signal
import sys
import threading
import time

import numpy as np
import six
from six.moves import queue as Queue  # pylint: disable=redefined-builtin
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result
from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest as data_nest
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2 as contrib_summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.tpu import _tpu_estimator_embedding
from tensorflow.python.tpu import error_handling
from tensorflow.python.tpu import functional as tpu_functional
from tensorflow.python.tpu import session_support
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_context
from tensorflow.python.tpu import tpu_embedding_gradient
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu import util as util_lib
from tensorflow.python.tpu._tpu_estimator_embedding import AdagradParameters  # pylint: disable=unused-import
from tensorflow.python.tpu._tpu_estimator_embedding import AdamParameters  # pylint: disable=unused-import
from tensorflow.python.tpu._tpu_estimator_embedding import StochasticGradientDescentParameters  # pylint: disable=unused-import
from tensorflow.python.tpu._tpu_estimator_embedding import EmbeddingConfigSpec  # pylint: disable=unused-import
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import evaluation
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect

_INITIAL_LOSS = 1e7
_ZERO_LOSS = 0.
_TPU_ESTIMATOR = 'tpu_estimator'
_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
_BATCH_SIZE_KEY = 'batch_size'
_CTX_KEY = 'context'
_USE_TPU_KEY = 'use_tpu'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'
_KEY_WHEN_PREDICTIONS_IS_A_TENSOR = '_key_when_predictions_is_a_tensor'

# Ideally _USE_TPU_KEY should be reserved as well. However, there are already
# models that make use of this key, so reserving it now would break them. In
# the long run, we would like to mitigate this by migrating models off of
# _USE_TPU_KEY.
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]

# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is
# only used for per-core based deployments. For per-host based pipelines, if a
# user returns a Dataset instance it will be automatically wrapped in a
# tf.while_loop (this can be disabled by returning features and labels
# explicitly).
_WRAP_INPUT_FN_INTO_WHILE_LOOP = False

ops.register_proto_function(
    '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR),
    proto_type=variable_pb2.VariableDef,
    to_proto=resource_variable_ops._to_proto_fn,  # pylint: disable=protected-access
    from_proto=resource_variable_ops._from_proto_fn)  # pylint: disable=protected-access


def _is_iterable(obj):
  """A Python 2 and 3 compatible util to check whether `obj` is iterable."""
  try:
    iter(obj)
    return True
  except TypeError:
    return False


class CatchInvalidHostcallFunctions(control_flow_ops.XLAControlFlowContext):

  def AddOp(self, op):
    if op.type in [
        'AudioSummary', 'AudioSummaryV2', 'HistogramSummary', 'ImageSummary',
        'MergeSummary', 'ScalarSummary', 'TensorSummary', 'TensorSummaryV2'
    ]:
      raise ValueError('Use tf.contrib.summary inside of host_calls.')


def _create_global_step(graph):
  graph = graph or ops.get_default_graph()
  if training.get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    return variable_scope.get_variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        use_resource=True,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])


def _create_or_get_iterations_per_loop():
  """Creates or gets the iterations_per_loop variable.

  In TPUEstimator, the user-provided computation, the model_fn, is wrapped
  inside a tf.while_loop for peak performance. The number of iterations of the
  loop is specified by this variable, which adjusts its value on the CPU after
  each TPU program execution and before the next TPU execution.

  The purpose of using a variable, rather than a constant, is to allow
  TPUEstimator to adapt the TPU training iterations according to the final
  steps specified by users. For example, if the user sets iterations_per_loop
  to 4 in TPUConfig and steps to 10 in TPUEstimator.train(), the
  iterations_per_loop variable will have the following values before each TPU
  training loop:

      - 1st TPU execution: iterations_per_loop = 4
      - 2nd TPU execution: iterations_per_loop = 4
      - 3rd TPU execution: iterations_per_loop = 2

  As model_fn increases the global step once per train_op invocation, the
  global step is 10 after all TPU executions, matching the steps=10 passed in
  by users.

  Returns:
    A TF non-trainable resource variable.

  Raises:
    RuntimeError: If multiple iterations_per_loop variables were found.
  """
  graph = ops.get_default_graph()
  collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR)
  iter_vars = graph.get_collection(collection_name)
  if len(iter_vars) == 1:
    return iter_vars[0]
  elif len(iter_vars) > 1:
    raise RuntimeError('Multiple iterations_per_loop_var in collection.')

  with ops.colocate_with(training_util.get_global_step()):
    with variable_scope.variable_scope(
        _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE):
      return variable_scope.get_variable(
          _ITERATIONS_PER_LOOP_VAR,
          initializer=init_ops.zeros_initializer(),
          shape=[],
          dtype=dtypes.int32,
          trainable=False,
          collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
          use_resource=True)


def _sync_variables_ops(ctx):
  """Create variables synchronization ops.

  Gets the variables back from TPU nodes. This means the variables updated
  by TPU will now be *synced* to host memory.
  In BROADCAST mode, we skip this sync since the variables are usually too
  big to transmit via RPC.

  Args:
    ctx: A `_InternalTPUContext` instance with mode.

  Returns:
    A list of sync ops.
  """

  if not ctx.is_input_broadcast_with_iterators():
    return [
        array_ops.check_numerics(v.read_value(),
                                 'Gradient for %s is NaN' % v.name).op
        for v in variables.trainable_variables()
    ]
  else:
    return [control_flow_ops.no_op()]


def _increase_eval_step_op(iterations_per_loop):
  """Returns an op to increase the eval step for TPU evaluation.

  Args:
    iterations_per_loop: Tensor. The number of eval steps running in the TPU
      system before returning to the CPU host for each `Session.run`.

  Returns:
    An operation
  """
  eval_step = evaluation._get_or_create_eval_step()  # pylint: disable=protected-access
  # Estimator's evaluation loop already increments the eval step by 1 per
  # `Session.run`, so we add the remaining difference here.
  return state_ops.assign_add(
      eval_step,
      math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype),
      use_locking=True)


def _extract_key_names(tensor_or_dict):
  if isinstance(tensor_or_dict, dict):
    return sorted(tensor_or_dict.keys())
  return []


class _SIGNAL(object):
  """Signal used to control the thread of infeed/outfeed.

  All preserved signals must be negative numbers. Positive numbers are used to
  indicate the number of iterations for the next training/evaluation loop.
  """
  NEXT_BATCH = -1
  STOP = -2


class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
  """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.

  See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and
  `export_outputs`.

  For evaluation, `eval_metrics` is a tuple of `metric_fn` and `tensors`, where
  `metric_fn` runs on CPU to generate metrics and `tensors` represents the
  `Tensor`s transferred from the TPU system to the CPU host and passed to
  `metric_fn`. To be precise, TPU evaluation expects a slightly different
  signature from `tf.estimator.Estimator`. While
  `EstimatorSpec.eval_metric_ops` expects a dict,
  `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. The
  `tensors` could be a list of `Tensor`s or a dict of names to `Tensor`s. The
  `tensors` usually specify the model logits, which are transferred back from
  the TPU system to the CPU host. All tensors must be batch-major, i.e., the
  batch size is the first dimension. Once all tensors are available at the CPU
  host from all shards, they are concatenated (on CPU) and passed as
  positional arguments to the `metric_fn` if `tensors` is a list, or as
  keyword arguments if `tensors` is a dict. `metric_fn` takes the `tensors`
  and returns a dict from metric string name to the result of calling a metric
  function, namely a `(metric_tensor, update_op)` tuple. See `TPUEstimator`
  for an MNIST example of how to specify `eval_metrics`.

  `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This
  function should not capture any Tensors in `model_fn`.

  `host_call` is a tuple of a `function` and a list or dictionary of `tensors`
  to pass to that function; the function returns a list of Tensors.
  `host_call` currently works for train() and evaluate(). The Tensors returned
  by the function are executed on the CPU on every step, so there is
  communication overhead when sending tensors from TPU to CPU. To reduce the
  overhead, try reducing the size of the tensors. The `tensors` are
  concatenated along their major (batch) dimension, and so must be >= rank 1.
  The `host_call` is useful for writing summaries with
  `tf.contrib.summary.create_file_writer`.
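
  As a minimal illustrative sketch of `eval_metrics` (this example is not from
  the original docs; `build_model` is a hypothetical helper and the
  `tf.contrib.tpu` export path is assumed):

  ```python
  def model_fn(features, labels, mode, params):
    logits, loss, train_op = build_model(features, labels)  # placeholder

    def metric_fn(labels, logits):
      # Runs on the CPU host on the concatenated, batch-major tensors.
      return {
          'accuracy': tf.metrics.accuracy(
              labels=labels, predictions=tf.argmax(logits, axis=-1)),
      }

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metrics=(metric_fn, {'labels': labels, 'logits': logits}))
  ```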
  """

  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metrics=None,
              export_outputs=None,
              scaffold_fn=None,
              host_call=None,
              training_hooks=None,
              evaluation_hooks=None,
              prediction_hooks=None):
    """Creates a validated `TPUEstimatorSpec` instance."""
    host_calls = {}
    if eval_metrics is not None:
      host_calls['eval_metrics'] = eval_metrics
    if host_call is not None:
      host_calls['host_call'] = host_call
    _OutfeedHostCall.validate(host_calls)

    training_hooks = tuple(training_hooks or [])
    evaluation_hooks = tuple(evaluation_hooks or [])
    prediction_hooks = tuple(prediction_hooks or [])

    for hook in training_hooks + evaluation_hooks + prediction_hooks:
      if not isinstance(hook, session_run_hook.SessionRunHook):
        raise TypeError('All hooks must be SessionRunHook instances, given: {}'
                        .format(hook))

    return super(TPUEstimatorSpec, cls).__new__(
        cls,
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        export_outputs=export_outputs,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
        training_hooks=training_hooks,
        evaluation_hooks=evaluation_hooks,
        prediction_hooks=prediction_hooks)

  def as_estimator_spec(self):
    """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
    host_calls = {}
    if self.eval_metrics is not None:
      host_calls['eval_metrics'] = self.eval_metrics
    if self.host_call is not None:
      host_calls['host_call'] = self.host_call
    host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
    eval_metric_ops = None
    if self.eval_metrics is not None:
      eval_metric_ops = host_call_ret['eval_metrics']
    hooks = None
    if self.host_call is not None:
      hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
    loss = self.loss
    if tensor_tracer.TensorTracer.is_enabled() \
        and self.train_op is not None:
      tt = tensor_tracer.TensorTracer()
      loss = tt.trace_cpu(ops.get_default_graph(), loss, self.train_op)

    hooks = tuple(hooks or [])
    scaffold = self.scaffold_fn() if self.scaffold_fn else None
    return model_fn_lib.EstimatorSpec(
        mode=self.mode,
        predictions=self.predictions,
        loss=loss,
        train_op=self.train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=self.export_outputs,
        scaffold=scaffold,
        training_hooks=self.training_hooks + hooks,
        evaluation_hooks=self.evaluation_hooks + hooks,
        prediction_hooks=self.prediction_hooks + hooks)


class _OpQueueContext(object):
  """Manages the work queue and thread for an infeed/outfeed thread."""

  def __init__(self, name, target, args):
    self._name = name
    self._queue = Queue.Queue()
    args = (self,) + args
    self._thread = threading.Thread(name=name, target=target, args=args)
    self._thread.daemon = True
    self._thread.start()

  def stop(self):
    self._queue.put(_SIGNAL.STOP)

  def send_next_batch_signal(self, iterations):
    self._queue.put(iterations)

  def read_iteration_counts(self):
    while True:
      iterations = self._queue.get(block=True)
      logging.debug('%s read iterations %s', self._name, iterations)
      if iterations == _SIGNAL.STOP:
        logging.info('%s received shutdown signal, stopping.', self._name)
        return
      yield iterations

  def join(self):
    logging.info('Shutting down %s thread.', self._name)
    self.stop()
    self._thread.join()
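

# Illustrative sketch only (not part of the original module): how a driver,
# such as the session hooks below, typically interacts with an
# _OpQueueContext. The function name and arguments are hypothetical and
# nothing in this file calls it.
def _example_drive_queue_context(controller, iterations, num_loops):
  """Signals `iterations` units of work `num_loops` times, then shuts down."""
  for _ in range(num_loops):
    # The worker thread picks this value up via read_iteration_counts().
    controller.send_next_batch_signal(iterations)
  # join() enqueues _SIGNAL.STOP (via stop()) and then joins the worker thread.
  controller.join()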


class _OpSignalOnceQueueContext(_OpQueueContext):
  """Manages the work queue and thread for an infeed/outfeed thread.

  This subclass only signals once.
  """

  def __init__(self, name, target, args):
    super(_OpSignalOnceQueueContext, self).__init__(name, target, args)
    self._has_signaled = False

  def send_next_batch_signal(self, iterations):
    if not self._has_signaled:
      self._queue.put(iterations)
      self._has_signaled = True


class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
  """A Session hook setting up the TPU initialization, infeed, and outfeed.

  This hook does two major things:
  1. initialize and shutdown the TPU system.
  2. launch and join the threads for infeed enqueue and (optional) outfeed
     dequeue.
  """

  def __init__(self,
               ctx,
               enqueue_ops,
               dequeue_ops,
               tpu_compile_op,
               run_infeed_loop_on_coordinator=True,
               rendezvous=None,
               master=None,
               session_config=None,
               tpu_init_ops=None):
    self._master_job = ctx.master_job
    self._enqueue_ops = enqueue_ops
    self._dequeue_ops = dequeue_ops
    self._rendezvous = rendezvous
    self._master = master
    self._session_config = session_config
    self._init_ops = list(tpu_init_ops or [])
    if ctx.embedding_config is None:
      self._embedding_layer_config = None
    else:
      self._embedding_layer_config = (
          ctx.embedding_config.tpu_embedding.config_proto)
    self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator
    self._initial_infeed_sleep_secs = (
        ctx.config.tpu_config.initial_infeed_sleep_secs)

    self._feed_error = None
    self._finished = False
    # When using model parallelism, the TPU is pre-initialized at startup to
    # fetch mesh information. We skip re-initializing it here to avoid
    # suspected issues due to the mesh layout changing on the second
    # initialization.
    self._should_initialize_tpu = not ctx.model_parallelism_enabled
    self._tpu_compile_op = tpu_compile_op

  def begin(self):
    logging.info('TPU job name %s', self._master_job)
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
    if self._should_initialize_tpu:
      self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
    else:
      self._finalize_ops = []

    summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
    self._init_ops.extend(summary_writer_init_ops)
    # Get all the writer resources from the initializer, so we know what to
    # flush.
    for op in summary_writer_init_ops:
      self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))

  def _run_infeed(self, queue_ctx, session):
    logging.info('Starting infeed thread controller.')
    if self._initial_infeed_sleep_secs:
      logging.info('Infeed thread sleeping for %d seconds.',
                   self._initial_infeed_sleep_secs)
      time.sleep(self._initial_infeed_sleep_secs)
      logging.info('Infeed thread starting after sleep')

    with self._rendezvous.catch_errors(source='infeed', session=session):
      if self._run_infeed_loop_on_coordinator:
        for count, steps in enumerate(queue_ctx.read_iteration_counts()):
          for i in xrange(steps):
            logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
            session.run(self._enqueue_ops)
      else:
        for _ in queue_ctx.read_iteration_counts():
          session.run(self._enqueue_ops)
      logging.info('Infeed thread finished, shutting down.')

  def _run_outfeed(self, queue_ctx, session):
    logging.info('Starting outfeed thread controller.')
    with self._rendezvous.catch_errors(source='outfeed', session=session):
      for count, steps in enumerate(queue_ctx.read_iteration_counts()):
        for i in xrange(steps):
          logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
          session.run(self._dequeue_ops)
      logging.info('Outfeed thread finished, shutting down.')

  def _create_infeed_controller(self, name, target, args):
    return _OpQueueContext(name=name, target=target, args=args)

  def _assertCompilationSucceeded(self, result, coord):
    proto = tpu_compilation_result.CompilationResultProto()
    proto.ParseFromString(result)
    if proto.status_error_message:
      logging.error('Compilation failed: {}'.format(proto.status_error_message))
      coord.request_stop()
    else:
      logging.info('Compilation succeeded')

  def after_create_session(self, session, coord):
    if self._should_initialize_tpu:
      logging.info('Init TPU system')
      start = time.time()
      with ops.Graph().as_default():
        with tf_session.Session(
            self._master, config=self._session_config) as sess:
          sess.run(
              tpu.initialize_system(
                  job=self._master_job,
                  embedding_config=self._embedding_layer_config))
      logging.info('Initialized TPU in %d seconds', time.time() - start)

    session.run(self._init_ops,
                options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))

    if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1':
      logging.info('Compiling user program: this may take a while...')
      self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord)

    self._infeed_controller = self._create_infeed_controller(
        name='InfeedController', target=self._run_infeed, args=(session,))

    self._outfeed_controller = _OpQueueContext(
        name='OutfeedController', target=self._run_outfeed, args=(session,))

    # Enable the worker watchdog to terminate workers on coordinator exit.
    watchdog_timeout = int(os.environ.get('TF_TPU_WATCHDOG_TIMEOUT', '0'))
    if watchdog_timeout > 0:
      session_support.start_worker_watchdog(session,
                                            shutdown_timeout=watchdog_timeout)

  def before_run(self, run_context):
    self._feed_error = None

    iterations = run_context.session.run(self._iterations_per_loop_var)

    logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
    self._infeed_controller.send_next_batch_signal(iterations)

    logging.info('Dequeue next (%d) batch(es) of data from outfeed.',
                 iterations)
    self._outfeed_controller.send_next_batch_signal(iterations)

  def end(self, session):
    self._finished = True
    logging.info('Stop infeed thread controller')
    self._infeed_controller.join()
    self._rendezvous.record_done('infeed')

    logging.info('Stop output thread controller')
    self._outfeed_controller.join()
    self._rendezvous.record_done('outfeed')

    logging.info('Shutdown TPU system.')
    session.run(self._finalize_ops)


class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook):

  def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op,
               rendezvous=None, master=None, session_config=None):
    super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__(
        ctx,
        enqueue_ops,
        dequeue_ops,
        tpu_compile_op=tpu_compile_op,
        run_infeed_loop_on_coordinator=False,
        rendezvous=rendezvous,
        master=master,
        session_config=session_config)

  def _create_infeed_controller(self, name, target, args):
    return _OpSignalOnceQueueContext(name=name, target=target, args=args)


class _TPUStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step.

  This hook is similar to `session_run_hook._StopAfterNEvalsHook` with the
  following differences for TPU training:

  1. This hook sets the variable for iterations_per_loop, which is used by
     `TPUInfeedOutfeedSessionHook` to control the iterations for
     infeed/outfeed. As the hook execution order is not guaranteed, the
     variable update is handled in `after_create_session` and `after_run`,
     since `TPUInfeedOutfeedSessionHook` reads the variable value in
     `before_run`.

  2. For each training loop (session.run), the global step could be increased
     multiple times on TPU. The global step tensor value is explicitly read
     again in `after_run` to ensure the latest value is retrieved and a race
     condition avoided.
  """

  def __init__(self, iterations, num_steps=None, last_step=None):
    """Initializes a `_TPUStopAtStepHook`.

    Args:
      iterations: The number of iterations to run the optimizer per training
        loop.
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError('One of num_steps or last_step must be specified.')
    if num_steps is not None and last_step is not None:
      raise ValueError('Only one of num_steps or last_step can be specified.')
    self._num_steps = num_steps
    self._last_step = last_step
    self._iterations = iterations

  def _next_iterations(self, global_step, last_step):
    gap = last_step - global_step
    return min(gap, self._iterations)

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError('Global step should be created.')

    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = global_step + self._num_steps

    iterations = self._next_iterations(global_step, self._last_step)

    self._iterations_per_loop_var.load(iterations, session=session)

  def after_run(self, run_context, run_values):
    # The global step cannot be retrieved via SessionRunArgs and before_run
    # due to a race condition.
    global_step = run_context.session.run(self._global_step_tensor)
    if global_step >= self._last_step:
      run_context.request_stop()
    else:
      iterations = self._next_iterations(global_step, self._last_step)
      self._iterations_per_loop_var.load(
          iterations, session=run_context.session)


class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
  """Hook that sets the number of eval iterations per loop."""

  def __init__(self, num_steps):
    """Initializes a `_SetEvalIterationsHook`.

    Args:
      num_steps: Number of steps to execute.
    """
    self._num_steps = num_steps

  def begin(self):
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    self._iterations_per_loop_var.load(self._num_steps, session=session)


class _StoppingPredictHook(session_run_hook.SessionRunHook):
  """Hook that requests stop according to the stopping signal in prediction."""

  def __init__(self, scalar_stopping_signal):
    self._scalar_stopping_signal = scalar_stopping_signal

  def begin(self):
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    # This is not necessary, as we do not run the infeed enqueue and outfeed
    # dequeue in side threads for the prediction model. But it makes
    # TPUInfeedOutfeedSessionHook print a nicer message.
    self._iterations_per_loop_var.load(1, session=session)

  def before_run(self, run_context):
    return session_run_hook.SessionRunArgs(self._scalar_stopping_signal)

  def after_run(self, run_context, run_values):
    _ = run_context
    scalar_stopping_signal = run_values.results
    if _StopSignals.should_stop(scalar_stopping_signal):
      # NOTE(xiejw): In prediction, stopping signals are inserted for each
      # batch, and we append one more batch to signal the system that it
      # should stop. The data flow might look like
      #
      #  batch   0: images, labels, stop = 0  (user provided)
      #  batch   1: images, labels, stop = 0  (user provided)
      #  ...
      #  batch  99: images, labels, stop = 0  (user provided)
      #  batch 100: images, labels, stop = 1  (TPUEstimator appended)
      #
      # where the final batch (id = 100) is appended by TPUEstimator, so we
      # should drop it before returning the predictions to the user.
      # To achieve that, we throw an OutOfRangeError in after_run. Once the
      # Monitored Session sees this error in SessionRunHook.after_run, the
      # "current" prediction, i.e., the batch with id=100, is discarded
      # immediately.
      raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')


def generate_per_core_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, host_device, host_id):
  """Generates infeed enqueue ops for per-core input_fn on a single host."""
  captured_infeed_queue = _CapturedObject()
  tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """A fn returning the enqueue_ops."""
    num_cores_per_host = ctx.num_of_cores_per_host
    per_host_sharded_inputs = []
    for core_ordinal in range(num_cores_per_host):
      with ops.name_scope('ordinal_%d' % (core_ordinal)):
        user_context = tpu_context.TPUContext(
            internal_ctx=ctx,
            input_device=host_device,
            invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal)
        inputs = _Inputs.from_input_fn(input_fn(user_context))
        if inputs.is_dataset:
          raise TypeError(
              '`input_fn` returning `Dataset` is not yet supported in '
              'per-Core input pipeline deployment. Please set '
              'TPUConfig.per_host_input_for_training to True or return '
              '`features` and `labels` from `input_fn`')
        features, labels = inputs.features_and_labels()

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        flattened_inputs = (
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels))
        per_host_sharded_inputs.append(flattened_inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]))
    captured_infeed_queue.capture(infeed_queue)

    per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
        per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)
    return per_host_enqueue_ops

  return enqueue_ops_fn, captured_infeed_queue


def generate_per_host_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id):
  """Generates infeed enqueue ops for per-host input_fn on a single host."""
  captured_infeed_queue = _CapturedObject()

  dataset_initializer = None

  with ops.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device, invocation_index=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      if not is_dataset:
        raise TypeError(
            'For mode PREDICT, `input_fn` must return `Dataset` instead of '
            '`features` and `labels`.')
      if batch_axis is not None:
        raise TypeError('For mode PREDICT, batch_axis is not supported yet.')
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)

    if is_dataset:
      dataset_initializer = inputs.dataset_initializer()

  tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """A fn returning the TPU infeed enqueue ops.

    By providing it as a fn, it can be invoked inside the tf.while_loop such
    that the input pipeline for multiple iterations can be executed by one
    Session.run call.

    Returns:
      list of dict of ops.
    """
    with ops.device(device):
      num_of_replicas_per_host = ctx.num_of_replicas_per_host
      # Convert user input to features and labels. If the user returns a
      # dataset, it is initialized and the features and labels are extracted
      # via `dataset.iterator.get_next()`.
      features, labels = inputs.features_and_labels()
      signals = inputs.signals()

      inputs_structure_recorder.validate_and_record_structure(features, labels)
      unsharded_tensor_list = (
          inputs_structure_recorder.flatten_features_and_labels(
              features, labels, signals))

      infeed_queue = tpu_feed.InfeedQueue(
          tuple_types=[t.dtype for t in unsharded_tensor_list],
          tuple_shapes=[t.shape for t in unsharded_tensor_list],
          shard_dimensions=batch_axis)
      captured_infeed_queue.capture(infeed_queue)
      infeed_queue.set_number_of_shards(num_of_replicas_per_host)
      per_host_enqueue_ops = (
          infeed_queue.split_inputs_and_generate_enqueue_ops(
              unsharded_tensor_list,
              placement_function=lambda x: device,
              tpu_ordinal_function=tpu_ordinal_function_impl))
      if signals is None:
        return per_host_enqueue_ops
      else:
        return {
            'ops': per_host_enqueue_ops,
            'signals': signals,
        }

  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer


def generate_per_host_v2_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, device, host_id):
  """Generates infeed enqueue ops for per-host input_fn on a single host."""
  captured_infeed_queue = _CapturedObject()
  dataset_initializer = None

  with ops.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device, invocation_index=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if not is_dataset:
      raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
                      'input pipeline configuration.')

    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True,
          num_invocations_per_step=ctx.num_of_replicas_per_host)

    dataset_initializer = inputs.dataset_initializer()

  tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """Generates the per_host enqueue ops."""
    control_deps = []
    per_host_sharded_inputs = []
    sparse_features_list = []
    num_replicas_per_host = ctx.num_of_replicas_per_host
    cached_signals = None
    with ops.device(device):
      if not inputs.is_dataset:
        raise TypeError('`input_fn` must return a `Dataset` for this mode.')
      for _ in range(num_replicas_per_host):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(control_deps):
          features, labels = inputs.features_and_labels()  # Calls get_next()
          signals = inputs.signals()

          # All the replicas share replica 0's stopping signal.
          # This avoids inconsistent state among different model replicas.
          if cached_signals:
            signals['stopping'] = cached_signals['stopping']
          else:
            cached_signals = signals

        features, labels, sparse_features = (
            _tpu_estimator_embedding.split_inputs(ctx, features, labels))
        sparse_features_list.append(sparse_features)

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        flattened_inputs = (
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels, signals))
        control_deps.extend(flattened_inputs)
        per_host_sharded_inputs.append(flattened_inputs)

      if inputs_structure_recorder.flattened_input_dims:
        input_partition_dims = inputs_structure_recorder.flattened_input_dims
        if signals:
          input_partition_dims += [None] * len(signals)
        # pylint: disable=protected-access
        infeed_queue = tpu_feed._PartitionedInfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]),
            host_id=host_id,
            input_partition_dims=input_partition_dims,
            device_assignment=ctx.device_assignment)
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs)
      else:
        infeed_queue = tpu_feed.InfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]))
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs,
            tpu_ordinal_function=tpu_ordinal_function_impl)
      captured_infeed_queue.capture(infeed_queue)

      if ctx.embedding_config:
        per_host_enqueue_ops.extend(
            ctx.embedding_config.tpu_embedding.generate_enqueue_ops(
                sparse_features_list))

      if signals is None:
        return per_host_enqueue_ops
      else:
        return {
            'ops': per_host_enqueue_ops,
            'signals': signals,
        }

  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer


def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
                                      num_hosts):
  """Generates infeed enqueue ops for one input_fn on all the hosts."""
  captured_infeed_queue = _CapturedObject()
  dataset_initializer = None
  device_0 = ctx.tpu_host_placement_function(host_id=0)
  with ops.device(device_0):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device_0, invocation_index=0)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      if not is_dataset:
        raise TypeError(
            'For mode PREDICT, `input_fn` must return `Dataset` instead of '
            '`features` and `labels`.')

      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)

    if is_dataset:
      dataset_initializer = inputs.dataset_initializer()
  num_replicas_per_host = ctx.num_of_replicas_per_host

  def tpu_ordinal_function_impl(replica_id):
    if ctx.device_assignment:
      return ctx.device_assignment.tpu_ordinal(replica=replica_id)
    else:
      return replica_id % num_replicas_per_host

  def device_function_impl(replica_id):
    return ctx.tpu_host_placement_function(replica_id=replica_id)

  def enqueue_ops_fn():
    """Generates enqueue ops for all the hosts."""
    broadcasted_inputs = []
    flattened_inputs = None  # Cache result from input_fn.
    signals = None
    num_replicas = ctx.num_replicas
    core_id = 0
    for host_id in xrange(num_hosts):
      with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
        for _ in xrange(ctx.num_of_replicas_per_host):
          # Note: input_fn is only called once at host 0 for the first
          # replica. The features and labels returned from that invocation
          # are broadcasted to the other replicas (including the replicas on
          # other hosts).
          if flattened_inputs is None:
            features, labels = inputs.features_and_labels()  # Calls get_next()
            signals = inputs.signals()

            inputs_structure_recorder.validate_and_record_structure(
                features, labels)
            flattened_inputs = (
                inputs_structure_recorder.flatten_features_and_labels(
                    features, labels, signals))
            if (ctx.config.tpu_config.eval_training_input_configuration is
                tpu_config.InputPipelineConfig.SLICED):
              input_slices = [
                  array_ops.split(x, num_replicas) for x in flattened_inputs
              ]
          if (ctx.config.tpu_config.eval_training_input_configuration is
              tpu_config.InputPipelineConfig.SLICED):
            # For each core, slice out this core's part of flattened_inputs.
            broadcasted_inputs.append([x[core_id] for x in input_slices])
            core_id += 1
          else:
            broadcasted_inputs.append(flattened_inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(broadcasted_inputs[0]))
    captured_infeed_queue.capture(infeed_queue)
    enqueue_ops = infeed_queue.generate_enqueue_ops(
        broadcasted_inputs,
        tpu_ordinal_function=tpu_ordinal_function_impl,
        placement_function=device_function_impl)

    if signals is None:
      return enqueue_ops
    else:
      return {
          'ops': enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer


class _InputPipeline(object):
  """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.

  `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from
  the call site. To be precise, based on the configuration in
  `_InternalTPUContext`, it invokes `input_fn` for all cores (usually
  multi-host TPU training) or for one host (usually for single-host TPU
  evaluation), and sends all `features` and `labels` returned by `input_fn` to
  the TPU infeed. For per-core invocation, `features` and `labels` are piped
  to infeed directly, one tuple for each core. For per-host invocation,
  `features` and `labels` are split at the host (with respect to `batch_axis`)
  and piped to all cores accordingly.

  In addition, flatten/unflatten are handled by `_InputPipeline` as well.
  Model inputs returned by the `input_fn` can have one of the following forms:

  1. features
  2. (features, labels)
  3. ((arbitrarily nested structure of features), labels)

  Internally, form 1 is reformed to `(features, None)` as features and labels
  are passed separately to the underlying methods. For TPU training,
  TPUEstimator may expect multiple `features` and `labels` tuples, one for
  each core.

  TPUEstimator allows various different structures for inputs (namely
  `features` and `labels`). Both `features` and `labels` can be any nested
  structure supported by TF nest (namely, dict, tuples, namedtuples or any
  nested structure of such of Tensors). `labels` could be `None` as well.
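
  As an illustrative sketch of form 2 above (not from the original docs; the
  feature name 'x' and the constant dataset are placeholders only):

  ```python
  def input_fn(params):
    batch_size = params['batch_size']  # populated by TPUEstimator
    features = {'x': tf.ones([batch_size, 8])}
    labels = tf.zeros([batch_size], dtype=tf.int32)
    return tf.data.Dataset.from_tensors((features, labels)).repeat()
  ```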

  These are flattened before they are passed to the infeed/outfeed library
  as that expects flattened lists.
  """

  class InputsStructureRecorder(object):
    """The recorder to record inputs structure."""

    def __init__(self, input_partition_dims=None):
      # Holds the structure of inputs
      self._feature_structure = {}
      self._flattened_input_dims = None

      if input_partition_dims:
        # This should have been validated in TPUConfig.
        assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.'
        if len(input_partition_dims) == 2:
          self._feature_dims, self._label_dims = input_partition_dims
        else:
          self._feature_dims = input_partition_dims[0]
          self._label_dims = None

        assert self._feature_dims is not None, ('input_partition_dims[0] must '
                                                'not be None')
      else:
        self._feature_dims = None
        self._label_dims = None

      # Internal state.
      self._initialized = False

    @property
    def flattened_input_dims(self):
      assert self._initialized, 'InputsStructureRecorder is not initialized.'
      return self._flattened_input_dims

    def has_labels(self):
      return 'labels' in self._feature_structure

    def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims,
                            label_dims_names, label_names, has_labels):
      """Flatten input dims with the same order as flattened input tensors."""
      flattened_input_dims = []
      if feature_dims_names:
        # We need a fixed ordering for matching the tensors in features.
        flattened_input_dims.extend(
            [feature_dims[name] for name in feature_dims_names])
      else:
        flattened_input_dims.append(feature_dims)

      if label_dims_names:
        # We need a fixed ordering for matching the tensors in labels.
        flattened_input_dims.extend(
            [label_dims[name] for name in label_dims_names])
      else:
        if label_names:
          num_tensors_in_label = len(label_names)
        else:
          num_tensors_in_label = int(has_labels)
        # Setting `None` in input_partition_dims[1] will apply `None` to
        # all the tensors in labels, regardless of internal structure.
        flattened_input_dims.extend([label_dims] * num_tensors_in_label)

      return flattened_input_dims

    def validate_and_record_structure(self, features, labels):
      """Validates and records the structure of `features` and `labels`."""
      # Extract structure.
      has_labels = labels is not None
      feature_names = _extract_key_names(features)
      label_names = _extract_key_names(labels)

      if not self._initialized:
        # Record structure.
        self._initialized = True
        if self._feature_dims is not None:
          feature_dims_names = _extract_key_names(self._feature_dims)
          if feature_dims_names != feature_names:
            raise ValueError(
                'TPUConfig.input_partition_dims[0] mismatched feature'
                ' keys. Expected {}, got {}'.format(feature_names,
                                                    feature_dims_names))

          label_dims_names = _extract_key_names(self._label_dims)
          if self._label_dims is not None and label_dims_names != label_names:
            raise ValueError(
                'TPUConfig.input_partition_dims[1] mismatched label'
                ' keys. Expected {}, got {}'.format(label_names,
                                                    label_dims_names))

          self._flattened_input_dims = self._flatten_input_dims(
              self._feature_dims, feature_dims_names, self._label_dims,
              label_dims_names, label_names, has_labels)

    def flatten_features_and_labels(self, features, labels, signals=None):
      """Flattens the `features` and `labels` to a single tensor list."""
      self._feature_structure['features'] = features
      if labels is not None:
        self._feature_structure['labels'] = labels
      if signals is not None:
        self._feature_structure['signals'] = signals
      return data_nest.flatten(self._feature_structure)

    def unflatten_features_and_labels(self, flattened_inputs):
      """Restores the flattened inputs to original features and labels form.

      Args:
        flattened_inputs: Flattened inputs for each shard.

      Returns:
        A tuple of (`features`, `labels`), where `labels` could be None.
        Each one, if present, should have identical structure (single tensor
        vs dict) as the one returned by input_fn.

      Raises:
        ValueError: If the number of expected tensors from `flattened_inputs`
          mismatches the recorded structure.
      """

      unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure,
                                                      flattened_inputs)
      return _Inputs(
          unflattened_inputs['features'],
          unflattened_inputs.get('labels'),
          signals=unflattened_inputs.get('signals'))

  def __init__(self, input_fn, batch_axis, ctx):
    """Constructor.

    Args:
      input_fn: input fn for train or eval.
      batch_axis: A python tuple of int values describing how each tensor
        produced by the Estimator `input_fn` should be split across the TPU
        compute shards.
      ctx: A `_InternalTPUContext` instance with mode.

    Raises:
      ValueError: If both `sharded_features` and `num_cores` are `None`.
    """
    self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder(
        ctx.input_partition_dims)

    self._sharded_per_core = ctx.is_input_sharded_per_core()
    self._input_fn = input_fn
    self._infeed_queue = None
    self._ctx = ctx
    self._batch_axis = batch_axis

  def generate_infeed_enqueue_ops_and_dequeue_fn(self):
    """Generates infeed enqueue ops and dequeue_fn."""
    # While tf.while_loop is called, the body function, which invokes
    # `enqueue_fn` passed in, is called to construct the graph. So, input_fn
    # structure is recorded.
    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = (
        self._invoke_input_fn_and_record_structure())

    self._validate_input_pipeline()

    def dequeue_fn():
      """dequeue_fn is used by TPU to retrieve the tensors."""
      # In the model-parallel case, both the host-side and device-side
      # computations must agree on the core on which infeed takes place. We
      # choose to perform infeed on logical core 0 of each replica.
      values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
      # The unflatten process uses the structure information recorded above.
      return self._inputs_structure_recorder.unflatten_features_and_labels(
          values)

    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)

  def _invoke_input_fn_and_record_structure(self):
    """Deploys the input pipeline and records the input structure."""
    enqueue_ops = []
    infeed_queues = []
    all_dataset_initializers = []
    num_hosts = self._ctx.num_hosts
    tpu_host_placement_fn = self._ctx.tpu_host_placement_function

    run_infeed_loop_on_coordinator = True

    if self._sharded_per_core:
      # Per-Core input pipeline deployment.
      # Invoke the input pipeline for each core and place it on the
      # corresponding host.
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
        with ops.device(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            enqueue_ops_fn, captured_infeed_queue = (
                generate_per_core_enqueue_ops_fn_for_host(
                    self._ctx, self._input_fn, self._inputs_structure_recorder,
                    host_device, host_id))

            if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
              run_infeed_loop_on_coordinator = False
              enqueue_ops.append(
                  _wrap_computation_in_while_loop(
                      device=host_device, op_fn=enqueue_ops_fn))
            else:
              enqueue_ops.append(enqueue_ops_fn())
            # Infeed_queue_getter must be called after enqueue_ops_fn is
            # called.
            infeed_queues.append(captured_infeed_queue.get())

    elif self._ctx.is_input_broadcast_with_iterators():
      # Only call input_fn on host 0.
      host_device = tpu_host_placement_fn(host_id=0)
      enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
          generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,
                                            self._inputs_structure_recorder,
                                            num_hosts))
      if dataset_initializer:
        all_dataset_initializers.append(dataset_initializer)
        run_infeed_loop_on_coordinator = False
        wrap_fn = (
            _wrap_computation_in_while_loop
            if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
            _wrap_computation_in_while_loop_with_stopping_signals)
        enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
      else:
        enqueue_ops.append(enqueue_ops_fn())
      infeed_queues.append(captured_infeed_queue.get())
    else:
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
        with ops.device(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            if self._ctx.is_input_per_host_with_iterators():
              enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
                  generate_per_host_v2_enqueue_ops_fn_for_host(
                      self._ctx, self._input_fn,
                      self._inputs_structure_recorder, host_device, host_id))
            else:
              enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
                  generate_per_host_enqueue_ops_fn_for_host(
                      self._ctx, self._input_fn,
                      self._inputs_structure_recorder, self._batch_axis,
                      host_device, host_id))

            # NOTE(xiejw): We dispatch here based on the return type of the
            # user's `input_fn`.
            #
            # 1. If input_fn returns a Dataset instance, we initialize the
            # iterator outside of tf.while_loop, and call iterator.get_next
            # inside tf.while_loop. This should always be safe.
            #
            # 2. If input_fn returns (features, labels), it is too late to
            # wrap them inside tf.while_loop, as resource initialization
            # cannot be handled in TF control flow properly. In this case, we
            # will use a Python loop to enqueue the data into the TPU system.
            # This may be slow compared to the previous case.
            if dataset_initializer:
              all_dataset_initializers.append(dataset_initializer)
              run_infeed_loop_on_coordinator = False
              wrap_fn = (
                  _wrap_computation_in_while_loop
                  if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
                  _wrap_computation_in_while_loop_with_stopping_signals)
              enqueue_ops.append(
                  wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
            else:
              enqueue_ops.append(enqueue_ops_fn())
            infeed_queues.append(captured_infeed_queue.get())
    # infeed_queue is used to generate dequeue ops. The only thing it uses for
    # dequeue is dtypes and shapes. So, any one can be used. Here, grab the
    # first one.
    self._infeed_queue = infeed_queues[0]
    return enqueue_ops, [
        util_lib.MultiHostDatasetInitializerHook(all_dataset_initializers)
    ], run_infeed_loop_on_coordinator

  def _validate_input_pipeline(self):
    """Validates the input pipeline.

    Performs some sanity checks to log user-friendly information. We should
    error out to give users a better error message. But, if
    _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
    user code, so we only log a warning.

    Raises:
      RuntimeError: If the validation failed.
    """
    if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
      err_msg = ('Input pipeline contains one or more QueueRunners. '
                 'It could be slow and not scalable. Please consider '
                 'converting your input pipeline to use `tf.data` instead '
                 '(see https://www.tensorflow.org/guide/datasets for '
                 'instructions).')
      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
        raise RuntimeError(err_msg)
      else:
        logging.warn(err_msg)


def call_computation(computation,
                     experimental_exported_model_uses_all_cores=True):
  """Calls the computation, optionally round-robining among TPU cores.

  The computation uses a single core for TPU inference. If
  `experimental_exported_model_uses_all_cores` is `True`, this function will
  round-robin the computation among all TPU cores visible to the host;
  otherwise, it will use a single core.

  Args:
    computation: A Python function that takes no inputs and builds the
      computation graph. If `computation` returns m outputs, this function
      will return a list of m Tensors.
    experimental_exported_model_uses_all_cores: Whether to round-robin among
      all cores visible to the host, or to use a single core.

  Returns:
    A list of output tensors.
  """
  if experimental_exported_model_uses_all_cores:
    # Using `TPUPartitionedCall` makes it possible to target a different
    # TPU core with every `Session.run()` call. Note that the entire inference
    # graph executes on a single core, and that invocations of this graph
    # will round-robin among the cores attached to a host.
    @function.Defun(capture_resource_var_by_value=False)
    def tpu_subgraph():
      return computation()

    return tpu_functional.TPUPartitionedCall(
        args=tpu_subgraph.captured_inputs,
        device_ordinal=tpu_ops.tpu_ordinal_selector(),
        Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg],
        f=tpu_subgraph)
  else:
    return computation()


class _ModelFnWrapper(object):
  """A `model_fn` wrapper.

  This makes calling model_fn on CPU and TPU easier and more consistent, and
  performs the checks and mutations required by TPU training and evaluation.

  In addition, this wrapper manages converting the `model_fn` to a single TPU
  train and eval step.
  """

  def __init__(self, model_fn, config, params, ctx):
    self._model_fn = model_fn
    self._config = config
    self._params = params
    self._ctx = ctx

  def call_without_tpu(self, features, labels, is_export_mode):
    return self._call_model_fn(features, labels, is_export_mode=is_export_mode)

  def _add_embedding_features(self, features, hook_dummy_table_variables):
    """Add embedding features, optionally add hook to intercept gradient."""
    if self._ctx.embedding_config:
      tpu_embedding_ = self._ctx.embedding_config.tpu_embedding
      embedding_activations = tpu_embedding_.get_activations()
      if hook_dummy_table_variables:
        new_embedding_activations = (
            tpu_embedding_gradient.hook_dummy_table_variables_to_activations(
                tpu_embedding_, embedding_activations,
                self._ctx.embedding_config.dummy_table_variables))
        features.update(new_embedding_activations)
      else:
        features.update(embedding_activations)

  def convert_to_single_tpu_train_step(self, dequeue_fn):
    """Converts the user-provided `model_fn` into a single train step on TPU.

    The user-provided `model_fn` takes an input tuple (features, labels) and
    produces the EstimatorSpec with train_op and loss for train `mode`. This
    usually represents a single train computation on CPU.

    For TPU training, a train (computation) step is first wrapped in a
    tf.while_loop control flow to repeat many times and then replicated to all
    TPU shards. Besides, the input should be taken from the TPU infeed rather
    than from the input pipeline (input_fn) directly. To fit the TPU loop and
    replicate pattern, the original train computation should be reformed,
    which is the returned `train_step`.

    Args:
      dequeue_fn: The function to retrieve inputs, features and labels, from
        the TPU infeed dequeue channel.

    Returns:
      A tuple of train_fn, host_calls, and captured scaffold_fn. The train_fn
      representing the train step for TPU.
    """

    host_call = _OutfeedHostCall(self._ctx)
    captured_scaffold_fn = _CapturedObject()
    captured_training_hooks = _CapturedObject()

    def train_step(loss):
      """Training step function for use inside a while loop."""
      del loss  # unused; required in function signature.
1458 inputs = dequeue_fn() 1459 features, labels = inputs.features_and_labels() 1460 self._add_embedding_features(features, True) 1461 1462 estimator_spec = self._verify_estimator_spec( 1463 self._call_model_fn(features, labels)) 1464 loss, train_op = estimator_spec.loss, estimator_spec.train_op 1465 1466 if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access 1467 captured_scaffold_fn.capture(estimator_spec.scaffold_fn) 1468 else: 1469 captured_scaffold_fn.capture(None) 1470 1471 captured_training_hooks.capture(estimator_spec.training_hooks) 1472 1473 if self._ctx.embedding_config is None: 1474 apply_sparse_grads = [] 1475 else: 1476 tpu_embedding_ = self._ctx.embedding_config.tpu_embedding 1477 gradients = ( 1478 tpu_embedding_gradient.get_gradients_through_dummy_table_variables( 1479 tpu_embedding_) 1480 ) 1481 apply_sparse_grads = [ 1482 tpu_embedding_.generate_send_gradients_op(gradients) 1483 ] 1484 1485 # We must run train_op to update the variables prior to running the 1486 # outfeed. 1487 with ops.control_dependencies([train_op] + apply_sparse_grads): 1488 host_call_outfeed_ops = [] 1489 if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access 1490 and estimator_spec.host_call is not None): 1491 host_call.record({'host_call': estimator_spec.host_call}) 1492 host_call_outfeed_ops = host_call.create_enqueue_op() 1493 with ops.control_dependencies(host_call_outfeed_ops): 1494 return array_ops.identity(loss) 1495 1496 return (train_step, host_call, captured_scaffold_fn, 1497 captured_training_hooks) 1498 1499 def convert_to_single_tpu_eval_step(self, dequeue_fn): 1500 """Converts user provided model_fn` as a single eval step on TPU. 1501 1502 Similar to training, the user provided `model_fn` takes input tuple 1503 (features, labels) and produces the TPUEstimatorSpec with eval_metrics for 1504 eval `mode`. This usually represents a single evaluation computation on CPU. 1505 1506 For TPU evaluation, a eval (computation) step is first wrapped in a 1507 tf.while_loop control flow to repeat for many times and then replicated to 1508 all TPU shards. Besides the input and output are slightly different. Input, 1509 features and labels, should be taken from TPU infeed rather than input 1510 pipeline (input_fn) directly. Output is managed in two stages. First, the 1511 model outputs as the result of evaluation computation, usually model logits, 1512 should be transferred from TPU system to CPU. Then, all model outputs are 1513 concatenated first on CPU and sent to the metric_fn for metrics computation. 1514 To fit TPU evaluation pattern, the original eval computation should be 1515 reformed, which is the returned `eval_step`. 1516 1517 Args: 1518 dequeue_fn: The function to retrieve inputs, features and labels, from TPU 1519 infeed dequeue channel. 1520 1521 Returns: 1522 A tuple of eval_fn, host_calls, and captured scaffold_fn. The eval_fn 1523 representing the eval step for TPU. 
1524 """ 1525 host_calls = _OutfeedHostCall(self._ctx) 1526 captured_scaffold_fn = _CapturedObject() 1527 captured_eval_hooks = _CapturedObject() 1528 1529 def eval_step(total_loss): 1530 """Evaluation step function for use inside a while loop.""" 1531 inputs = dequeue_fn() 1532 features, labels = inputs.features_and_labels() 1533 self._add_embedding_features(features, False) 1534 1535 tpu_estimator_spec = self._call_model_fn(features, labels) 1536 if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access 1537 raise RuntimeError( 1538 'estimator_spec used by TPU evaluation must have type' 1539 '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) 1540 1541 loss = tpu_estimator_spec.loss 1542 captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) 1543 captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks) 1544 1545 to_record = {} 1546 if tpu_estimator_spec.eval_metrics: 1547 to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics 1548 if tpu_estimator_spec.host_call is not None: 1549 # We assume that evaluate won't update global step, so we don't wrap 1550 # this host_call. 1551 to_record['host_call'] = tpu_estimator_spec.host_call 1552 host_calls.record(to_record) 1553 1554 with ops.control_dependencies(host_calls.create_enqueue_op()): 1555 return math_ops.add(total_loss, loss) 1556 1557 return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks 1558 1559 def convert_to_single_tpu_predict_step(self, dequeue_fn): 1560 """Converts user provided model_fn` as a single predict step on TPU. 1561 1562 Args: 1563 dequeue_fn: The function to retrieve inputs, features and labels, from TPU 1564 infeed dequeue channel. 1565 1566 Returns: 1567 A tuple of predict_fn, host_calls, and captured scaffold_fn. The 1568 predict_fn representing the predict step for TPU. 1569 """ 1570 host_calls = _OutfeedHostCall(self._ctx) 1571 captured_scaffold_fn = _CapturedObject() 1572 captured_predict_hooks = _CapturedObject() 1573 1574 def predict_step(unused_scalar_stopping_signal): 1575 """Evaluation step function for use inside a while loop.""" 1576 inputs = dequeue_fn() 1577 features, labels = inputs.features_and_labels() 1578 stopping_signals = inputs.signals() 1579 1580 assert stopping_signals is not None, ( 1581 'Internal Error: `signals` is missing.') 1582 1583 tpu_estimator_spec = self._call_model_fn( 1584 features, labels, is_export_mode=False) 1585 if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access 1586 raise RuntimeError( 1587 'estimator_spec used by TPU prediction must have type' 1588 '`TPUEstimatorSpec`. 
Got {}'.format(type(tpu_estimator_spec))) 1589 1590 self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) 1591 1592 captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) 1593 captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks) 1594 to_record = {} 1595 identity_fn = lambda **kwargs: kwargs 1596 to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] 1597 to_record['signals'] = [identity_fn, stopping_signals] 1598 if tpu_estimator_spec.host_call is not None: 1599 to_record['host_call'] = tpu_estimator_spec.host_call 1600 host_calls.record(to_record) 1601 1602 with ops.control_dependencies(host_calls.create_enqueue_op()): 1603 return _StopSignals.as_scalar_stopping_signal(stopping_signals) 1604 1605 return (predict_step, host_calls, captured_scaffold_fn, 1606 captured_predict_hooks) 1607 1608 def _verify_tpu_spec_predictions(self, predictions): 1609 """Validates TPUEstimatorSpec.predictions dict.""" 1610 # TODO(xiejw): Adds validation for prediction dictionrary. 1611 # TODO(xiejw): Adds support for single tensor as predictions. 1612 if not isinstance(predictions, dict): 1613 raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') 1614 1615 for (key, tensor) in predictions.items(): 1616 if tensor.shape.dims[0].value is None: 1617 raise ValueError( 1618 'The tensor with key ({}) in TPUEstimatorSpec.predictions has ' 1619 'dynamic shape (should be static). Tensor: {}'.format(key, tensor)) 1620 return predictions 1621 1622 def _validate_model_features_and_labels(self, features, labels, 1623 is_export_mode): 1624 """Validates that the features and labels for the model function are valid. 1625 1626 A valid features/labels object is the one with: 1627 - Type: A tensor or any nested structure of tensors supported by TF nest, 1628 namely nested dictionary, tuple, namedtuple, or sequence of tensors. 1629 - Static shape if is_export_mode is False. 1630 1631 Args: 1632 features: the features that would be input to the model function. 1633 labels: the labels that would be input to the model function. 1634 is_export_mode: boolean value specifying if in export mode. 1635 1636 Raises: 1637 TypeError: If features/labels are not of the correct type. 1638 ValueError: If features/labels have dynamic shape. 1639 """ 1640 1641 def validate(obj, obj_name): 1642 """Helper validate function.""" 1643 if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode): 1644 return 1645 if isinstance(obj, ops.Tensor): 1646 if not obj.get_shape().is_fully_defined(): 1647 raise ValueError( 1648 'The {} to the model returned by input_fn must have static shape.' 1649 ' Tensor: {}'.format(obj_name, obj)) 1650 else: 1651 for tensor in data_nest.flatten(obj): 1652 if not tensor.get_shape().is_fully_defined(): 1653 raise ValueError( 1654 ('The {} to the model returned by input_fn must have static ' 1655 'shape. Tensor: {}').format(obj_name, tensor)) 1656 1657 validate(features, 'features') 1658 if labels is not None: 1659 validate(labels, 'labels') 1660 1661 def _call_model_fn(self, features, labels, is_export_mode=False): 1662 """Calls the model_fn with required parameters.""" 1663 self._validate_model_features_and_labels(features, labels, is_export_mode) 1664 model_fn_args = function_utils.fn_args(self._model_fn) 1665 kwargs = {} 1666 1667 # Makes deep copy with `config` and params` in case user mutates them. 
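    # (The copies matter because reserved entries such as
    # params['batch_size'], params['use_tpu'] and the TPUContext under
    # params['context'] are injected below; the user's original `params`
    # should stay untouched between calls.)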
1668 config = copy.deepcopy(self._config) 1669 params = copy.deepcopy(self._params) 1670 1671 if 'labels' in model_fn_args: 1672 kwargs['labels'] = labels 1673 elif labels is not None: 1674 raise ValueError( 1675 'model_fn does not take labels, but input_fn returns labels.') 1676 if 'mode' in model_fn_args: 1677 kwargs['mode'] = self._ctx.mode 1678 if 'config' in model_fn_args: 1679 kwargs['config'] = config 1680 if 'params' in model_fn_args: 1681 kwargs['params'] = params 1682 1683 if 'params' not in model_fn_args: 1684 raise ValueError('model_fn ({}) does not include params argument, ' 1685 'required by TPUEstimator to pass batch size as ' 1686 'params[\'batch_size\']'.format(self._model_fn)) 1687 1688 if is_export_mode: 1689 batch_size_for_model_fn = None 1690 else: 1691 batch_size_for_model_fn = self._ctx.batch_size_for_model_fn 1692 1693 if batch_size_for_model_fn is not None: 1694 _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) 1695 1696 running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) 1697 # In export mode, params['use_tpu'] has already been set based on mode 1698 # (i.e. True for _REWRITE_FOR_INFERENCE_MODE, False otherwise). 1699 if not is_export_mode: 1700 _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) 1701 1702 if not running_on_cpu: 1703 user_context = tpu_context.TPUContext( 1704 internal_ctx=self._ctx, call_from_input_fn=False) 1705 _add_item_to_params(params, _CTX_KEY, user_context) 1706 1707 estimator_spec = self._model_fn(features=features, **kwargs) 1708 if (running_on_cpu and 1709 isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access 1710 # The estimator_spec will be passed to `Estimator` directly, which expects 1711 # type `EstimatorSpec`. 1712 return estimator_spec.as_estimator_spec() 1713 else: 1714 return estimator_spec 1715 1716 def _verify_estimator_spec(self, estimator_spec): 1717 """Validates the estimator_spec.""" 1718 if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access 1719 return estimator_spec 1720 1721 err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' 1722 if estimator_spec.training_chief_hooks: 1723 raise ValueError( 1724 err_msg.format('training_chief_hooks') + 'If you want' + 1725 ' to pass training hooks, please pass via training_hooks.') 1726 1727 if estimator_spec.scaffold: 1728 logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. ' 1729 'Please use TPUEstimatorSpec.') 1730 return estimator_spec 1731 1732 1733 class _OutfeedHostCall(object): 1734 """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec.""" 1735 1736 def __init__(self, ctx): 1737 self._ctx = ctx 1738 self._names = [] 1739 # All of these are dictionaries of lists keyed on the name. 
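    # Illustrative example: after record({'eval_metrics': (metric_fn,
    # {'labels': labels, 'logits': logits})}), the entries stored under the
    # name 'eval_metrics' would be:
    #   _host_fns      -> metric_fn
    #   _tensor_keys   -> ['labels', 'logits']
    #   _tensors       -> [labels, logits]
    #   _tensor_dtypes / _tensor_shapes -> the dtype/shape of each tensor.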
1740 self._host_fns = {} 1741 self._tensor_keys = collections.defaultdict(list) 1742 self._tensors = collections.defaultdict(list) 1743 self._tensor_dtypes = collections.defaultdict(list) 1744 self._tensor_shapes = collections.defaultdict(list) 1745 1746 @staticmethod 1747 def validate(host_calls): 1748 """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`.""" 1749 1750 for name, host_call in host_calls.items(): 1751 if not isinstance(host_call, (tuple, list)): 1752 raise ValueError('{} should be tuple or list'.format(name)) 1753 if len(host_call) != 2: 1754 raise ValueError('{} should have two elements.'.format(name)) 1755 if not callable(host_call[0]): 1756 raise TypeError('{}[0] should be callable.'.format(name)) 1757 if not isinstance(host_call[1], (tuple, list, dict)): 1758 raise ValueError('{}[1] should be tuple or list, or dict.'.format(name)) 1759 1760 if isinstance(host_call[1], (tuple, list)): 1761 fullargspec = tf_inspect.getfullargspec(host_call[0]) 1762 fn_args = function_utils.fn_args(host_call[0]) 1763 # wrapped_hostcall_with_global_step uses varargs, so we allow that. 1764 if fullargspec.varargs is None and len(host_call[1]) != len(fn_args): 1765 raise RuntimeError( 1766 'In TPUEstimatorSpec.{}, length of tensors {} does not match ' 1767 'method args of the function, which takes {}.'.format( 1768 name, len(host_call[1]), len(fn_args))) 1769 1770 @staticmethod 1771 def create_cpu_hostcall(host_calls): 1772 """Runs on the host_call on CPU instead of TPU when use_tpu=False.""" 1773 1774 _OutfeedHostCall.validate(host_calls) 1775 ret = {} 1776 for name, host_call in host_calls.items(): 1777 host_fn, tensors = host_call 1778 if isinstance(tensors, (tuple, list)): 1779 ret[name] = host_fn(*tensors) 1780 else: 1781 # Must be dict. 1782 try: 1783 ret[name] = host_fn(**tensors) 1784 except TypeError as e: 1785 logging.warning( 1786 'Exception while calling %s: %s. It is likely the tensors ' 1787 '(%s[1]) do not match the ' 1788 'function\'s arguments', name, e, name) 1789 raise 1790 return ret 1791 1792 def record(self, host_calls): 1793 """Records the host_call structure.""" 1794 1795 for name, host_call in host_calls.items(): 1796 host_fn, tensor_list_or_dict = host_call 1797 self._names.append(name) 1798 self._host_fns[name] = host_fn 1799 1800 if isinstance(tensor_list_or_dict, dict): 1801 for (key, tensor) in six.iteritems(tensor_list_or_dict): 1802 self._tensor_keys[name].append(key) 1803 self._tensors[name].append(tensor) 1804 self._tensor_dtypes[name].append(tensor.dtype) 1805 self._tensor_shapes[name].append(tensor.shape) 1806 else: 1807 # List or tuple. 1808 self._tensor_keys[name] = None 1809 for tensor in tensor_list_or_dict: 1810 self._tensors[name].append(tensor) 1811 self._tensor_dtypes[name].append(tensor.dtype) 1812 self._tensor_shapes[name].append(tensor.shape) 1813 1814 def create_enqueue_op(self): 1815 """Create the op to enqueue the recorded host_calls. 1816 1817 Returns: 1818 A list of enqueue ops, which is empty if there are no host calls. 1819 """ 1820 if not self._names: 1821 return [] 1822 1823 tensors = [] 1824 # TODO(jhseu): Consider deduping tensors. 1825 for name in self._names: 1826 tensors.extend(self._tensors[name]) 1827 1828 with ops.device(tpu.core(0)): 1829 return [tpu_ops.outfeed_enqueue_tuple(tensors)] 1830 1831 def create_tpu_hostcall(self): 1832 """Sends the tensors through outfeed and runs the host_fn on CPU. 1833 1834 The tensors are concatenated along dimension 0 to form a global tensor 1835 across all shards. 
The concatenated function is passed to the host_fn and 1836 executed on the first host. 1837 1838 Returns: 1839 A dictionary mapping name to the return type of the host_call by that 1840 name. 1841 1842 Raises: 1843 RuntimeError: If outfeed tensor is scalar. 1844 """ 1845 if not self._names: 1846 return {} 1847 1848 ret = {} 1849 # For each i, dequeue_ops[i] is a list containing the tensors from all 1850 # shards. This list is concatenated later. 1851 dequeue_ops = [] 1852 tensor_dtypes = [] 1853 tensor_shapes = [] 1854 for name in self._names: 1855 for _ in self._tensors[name]: 1856 dequeue_ops.append([]) 1857 for dtype in self._tensor_dtypes[name]: 1858 tensor_dtypes.append(dtype) 1859 for shape in self._tensor_shapes[name]: 1860 tensor_shapes.append(shape) 1861 1862 # Outfeed ops execute on each replica's first logical core. Note: we must 1863 # constraint it such that we have at most one outfeed dequeue and enqueue 1864 # per replica. 1865 for i in xrange(self._ctx.num_replicas): 1866 host_device, ordinal_id = self._ctx.device_for_replica(i) 1867 with ops.device(host_device): 1868 outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( 1869 dtypes=tensor_dtypes, 1870 shapes=tensor_shapes, 1871 device_ordinal=ordinal_id) 1872 for j, item in enumerate(outfeed_tensors): 1873 dequeue_ops[j].append(item) 1874 1875 # Deconstruct dequeue ops. 1876 flat_dequeue_ops = [] 1877 for l in dequeue_ops: 1878 flat_dequeue_ops.extend(l) 1879 1880 dequeue_ops_by_name = {} 1881 pos = 0 1882 for name in self._names: 1883 dequeue_ops_by_name[name] = dequeue_ops[pos:pos + 1884 len(self._tensors[name])] 1885 pos += len(self._tensors[name]) 1886 1887 def _call_host_fn(fn, *args, **kw): 1888 context = CatchInvalidHostcallFunctions() 1889 context.Enter() 1890 result = fn(*args, **kw) 1891 context.Exit() 1892 context.ExitResult(result) 1893 return result 1894 1895 # It is assumed evaluation always happens on single host TPU system. So, 1896 # place all ops on tpu host if possible. 1897 # 1898 # TODO(jhseu): Evaluate whether this is right for summaries. 1899 with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)): 1900 for name in self._names: 1901 dequeue_ops = dequeue_ops_by_name[name] 1902 for i, item in enumerate(dequeue_ops): 1903 if dequeue_ops[i][0].shape.ndims == 0: 1904 raise RuntimeError( 1905 'All tensors outfed from TPU should preserve batch size ' 1906 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) 1907 # TODO(xiejw): Make the specification of the outfeed combinaton 1908 # function more explicit and well-documented. We may want to give the 1909 # user the option of concatenating along any axis. 1910 if (self._ctx.config.tpu_config.per_host_input_for_training is 1911 tpu_config.InputPipelineConfig.BROADCAST): 1912 # If the infeed is in BROADCAST mode (each core recieving the same 1913 # input), then we assume that the cores also produce identical 1914 # copies of the same output, and we simply take the output from 1915 # the first core. This mode is used by Mesh-TensorFlow. 1916 with ops.control_dependencies(dequeue_ops[i]): 1917 dequeue_ops[i] = array_ops.identity(dequeue_ops[i][0]) 1918 else: 1919 # Assume that the input has been batch-split and that axis 0 of the 1920 # output tensors represents the batch size. Concatenate along 1921 # the axis 0 to re-combine the batch. 1922 dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) 1923 1924 if self._tensor_keys[name] is not None: 1925 # The user-provided eval_metrics[1] is a dict. 
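          # Rebuilding the dict lets the host_fn be invoked with keyword
          # arguments here, mirroring the CPU path in `create_cpu_hostcall`.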
1926 dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops)) 1927 try: 1928 ret[name] = _call_host_fn(self._host_fns[name], **dequeue_ops) 1929 except TypeError as e: 1930 logging.warning( 1931 'Exception while calling %s: %s. It is likely the tensors ' 1932 '(%s[1]) do not match the ' 1933 'function\'s arguments', name, e, name) 1934 raise 1935 else: 1936 ret[name] = _call_host_fn(self._host_fns[name], *dequeue_ops) 1937 1938 # force all dequeue operations to be run if not consumed by the host calls 1939 ret['__force_dequeue'] = control_flow_ops.group(*flat_dequeue_ops) 1940 return ret 1941 1942 1943 class _OutfeedHostCallHook(session_run_hook.SessionRunHook): 1944 """Hook to run host calls when use_tpu=False.""" 1945 1946 def __init__(self, tensors): 1947 self._tensors = tensors 1948 1949 def begin(self): 1950 # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than 1951 # create a separate hook to guarantee execution order, because summaries 1952 # need to be initialized before the outfeed thread starts. 1953 # TODO(jhseu): Make a wrapper hook instead? 1954 self._init_ops = contrib_summary.summary_writer_initializer_op() 1955 # Get all the writer resources from the initializer, so we know what to 1956 # flush. 1957 self._finalize_ops = [] 1958 for op in self._init_ops: 1959 self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) 1960 1961 def after_create_session(self, session, coord): 1962 session.run(self._init_ops) 1963 1964 def before_run(self, run_context): 1965 return basic_session_run_hooks.SessionRunArgs(self._tensors) 1966 1967 def end(self, session): 1968 session.run(self._finalize_ops) 1969 1970 1971 class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): 1972 """Calculate and report global_step/sec and examples/sec during runtime.""" 1973 1974 def __init__(self, 1975 batch_size, 1976 every_n_steps=100, 1977 every_n_secs=None, 1978 output_dir=None, 1979 summary_writer=None): 1980 self._batch_size = batch_size 1981 super(ExamplesPerSecondHook, self).__init__( 1982 every_n_steps=every_n_steps, 1983 every_n_secs=every_n_secs, 1984 output_dir=output_dir, 1985 summary_writer=summary_writer) 1986 1987 def _log_and_record(self, elapsed_steps, elapsed_time, global_step): 1988 global_step_per_sec = elapsed_steps / elapsed_time 1989 examples_per_sec = self._batch_size * global_step_per_sec 1990 if self._summary_writer is not None: 1991 global_step_summary = Summary(value=[ 1992 Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec) 1993 ]) 1994 example_summary = Summary(value=[ 1995 Summary.Value(tag='examples/sec', simple_value=examples_per_sec) 1996 ]) 1997 self._summary_writer.add_summary(global_step_summary, global_step) 1998 self._summary_writer.add_summary(example_summary, global_step) 1999 logging.info('global_step/sec: %g', global_step_per_sec) 2000 logging.info('examples/sec: %g', examples_per_sec) 2001 2002 2003 class InstallSignalHandlerHook(session_run_hook.SessionRunHook): 2004 """Change SIGINT (CTRL^C) handler to force quit the process. 2005 2006 The default behavior often results in hanging processes. 2007 The original handler is restored after training/evaluation. 
2008   """
2009 
2010   def __init__(self):
2011     self._signal_fn = signal.getsignal(signal.SIGINT)
2012 
2013   def before_run(self, run_context):
2014     signal.signal(signal.SIGINT, signal.SIG_DFL)
2015 
2016   def end(self, session):
2017     signal.signal(signal.SIGINT, self._signal_fn)
2018 
2019 
2020 class TPUEstimator(estimator_lib.Estimator):
2021   """Estimator with TPU support.
2022 
2023   TPUEstimator also supports training on CPU and GPU. You don't need to define
2024   a separate `tf.estimator.Estimator`.
2025 
2026   TPUEstimator handles many of the details of running on TPU devices, such as
2027   replicating inputs and models for each core, and returning to host
2028   periodically to run hooks.
2029 
2030   TPUEstimator transforms a global batch size in params to a per-shard batch
2031   size when calling the `input_fn` and `model_fn`. Users should specify the
2032   global batch size in the constructor, and then get the batch size for each
2033   shard in `input_fn` and `model_fn` from `params['batch_size']`.
2034 
2035   - For training, `model_fn` gets per-core batch size; `input_fn` may get
2036     per-core or per-host batch size depending on `per_host_input_for_training`
2037     in `TPUConfig` (See docstring for TPUConfig for details).
2038 
2039   - For evaluation and prediction, `model_fn` gets per-core batch size and
2040     `input_fn` gets per-host batch size.
2041 
2042   Evaluation
2043   ==========
2044 
2045   `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics`
2046   for TPU evaluation. If eval_on_tpu is False, the evaluation will execute on
2047   CPU or GPU; in this case the following discussion on TPU evaluation does not
2048   apply.
2049 
2050   `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
2051   `tensors` could be a list of any nested structure of `Tensor`s (See
2052   `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns
2053   a dict from metric string name to the result of calling a metric function,
2054   namely a `(metric_tensor, update_op)` tuple.
2055 
2056   One can set `use_tpu` to `False` for testing. All training, evaluation, and
2057   prediction will be executed on CPU. `input_fn` and `model_fn` will receive
2058   `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`.
2059 
2060   Current limitations:
2061   --------------------
2062 
2063   1. TPU evaluation only works on a single host (one TPU worker) except in
2064      BROADCAST mode.
2065 
2066   2. `input_fn` for evaluation should **NOT** raise an end-of-input exception
2067      (`OutOfRangeError` or `StopIteration`). All evaluation steps and all
2068      batches should have the same size.
2069 
2070   Example (MNIST):
2071   ----------------
2072 
2073   ```
2074   # The metric Fn which runs on CPU.
2075   def metric_fn(labels, logits):
2076     predictions = tf.argmax(logits, 1)
2077     return {
2078         'accuracy': tf.metrics.accuracy(
2079             labels=labels, predictions=predictions),
2080     }
2081 
2082   # Your model Fn which runs on TPU (eval_metrics is a list in this example)
2083   def model_fn(features, labels, mode, config, params):
2084     ...
2085     logits = ...
2086 
2087     if mode == tf.estimator.ModeKeys.EVAL:
2088       return tpu_estimator.TPUEstimatorSpec(
2089           mode=mode,
2090           loss=loss,
2091           eval_metrics=(metric_fn, [labels, logits]))
2092 
2093   # or specify the eval_metrics tensors as a dict.
2094   def model_fn(features, labels, mode, config, params):
2095     ...
2096     final_layer_output = ...
2097 
2098     if mode == tf.estimator.ModeKeys.EVAL:
2099       return tpu_estimator.TPUEstimatorSpec(
2100           mode=mode,
2101           loss=loss,
2102           eval_metrics=(metric_fn, {
2103               'labels': labels,
2104               'logits': final_layer_output,
2105           }))
2106   ```
2107 
2108   Prediction
2109   ==========
2110 
2111   Prediction on TPU is an experimental feature to support large batch inference.
2112   It is not designed for latency-critical systems. In addition, due to some
2113   usability issues, for prediction with a small dataset, CPU `.predict`, i.e.,
2114   creating a new `TPUEstimator` instance with `use_tpu=False`, might be more
2115   convenient.
2116 
2117   Note: In contrast to TPU training/evaluation, the `input_fn` for prediction
2118   *should* raise an end-of-input exception (`OutOfRangeError` or
2119   `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be
2120   precise, the ops created by `input_fn` produce one batch of the data.
2121   The `predict()` API processes one batch at a time. When reaching the end of
2122   the data source, an end-of-input exception should be raised by one of these
2123   operations. The user usually does not need to do this manually. As long as the
2124   dataset is not repeated forever, the `tf.data` API will raise an end-of-input
2125   exception automatically after the last batch has been produced.
2126 
2127   Note: Estimator.predict returns a Python generator. Please consume all the
2128   data from the generator so that TPUEstimator can shut down the TPU system
2129   properly for the user.
2130 
2131   Current limitations:
2132   --------------------
2133   1. TPU prediction only works on a single host (one TPU worker).
2134 
2135   2. `input_fn` must return a `Dataset` instance rather than `features`. In
2136      fact, .train() and .evaluate() also support Dataset as a return value.
2137 
2138   Example (MNIST):
2139   ----------------
2140   ```
2141   height = 32
2142   width = 32
2143   total_examples = 100
2144 
2145   def predict_input_fn(params):
2146     batch_size = params['batch_size']
2147 
2148     images = tf.random_uniform(
2149         [total_examples, height, width, 3], minval=-1, maxval=1)
2150 
2151     dataset = tf.data.Dataset.from_tensor_slices(images)
2152     dataset = dataset.map(lambda images: {'image': images})
2153 
2154     dataset = dataset.batch(batch_size)
2155     return dataset
2156 
2157   def model_fn(features, labels, params, mode):
2158     # Generate predictions, called 'output', from features['image']
2159 
2160     if mode == tf.estimator.ModeKeys.PREDICT:
2161       return tf.contrib.tpu.TPUEstimatorSpec(
2162           mode=mode,
2163           predictions={
2164               'predictions': output,
2165               'is_padding': features['is_padding']
2166           })
2167 
2168   tpu_est = TPUEstimator(
2169       model_fn=model_fn,
2170       ...,
2171       predict_batch_size=16)
2172 
2173   # Fully consume the generator so that TPUEstimator can shut down the TPU
2174   # system.
2175   for item in tpu_est.predict(input_fn=predict_input_fn):
2176     # Filter out item if the `is_padding` is 1.
2177     # Process the 'predictions'
2178   ```
2179 
2180   Exporting
2181   =========
2182 
2183   `export_savedmodel` exports two metagraphs, one with `tag_constants.SERVING`,
2184   and another with `tag_constants.SERVING` and `tag_constants.TPU`.
2185   At serving time, these tags are used to select the metagraph to load.
2186 
2187   Before running the graph on TPU, the TPU system needs to be initialized. If
2188   the TensorFlow Serving model-server is used, this is done automatically. If
2189   not, please call `session.run(tpu.initialize_system())`.
2190 
2191   `tpu.outside_compilation` can be used to wrap TPU-incompatible ops in
2192   `model_fn`.
2193 2194 Example: 2195 ---------------- 2196 2197 ``` 2198 def model_fn(features, labels, mode, config, params): 2199 ... 2200 logits = ... 2201 export_outputs = { 2202 'logits': export_output_lib.PredictOutput( 2203 {'logits': logits}) 2204 } 2205 2206 def host_call(logits): 2207 class_ids = math_ops.argmax(logits) 2208 classes = string_ops.as_string(class_ids) 2209 export_outputs['classes'] = 2210 export_output_lib.ClassificationOutput(classes=classes) 2211 2212 tpu.outside_compilation(host_call, logits) 2213 2214 ... 2215 ``` 2216 2217 """ 2218 2219 def __init__(self, 2220 model_fn=None, 2221 model_dir=None, 2222 config=None, 2223 params=None, 2224 use_tpu=True, 2225 train_batch_size=None, 2226 eval_batch_size=None, 2227 predict_batch_size=None, 2228 batch_axis=None, 2229 eval_on_tpu=True, 2230 export_to_tpu=True, 2231 export_to_cpu=True, 2232 warm_start_from=None, 2233 experimental_exported_model_uses_all_cores=False, 2234 experimental_export_device_assignment=False, 2235 experimental_embedding_config_spec=None): 2236 """Constructs an `TPUEstimator` instance. 2237 2238 Args: 2239 model_fn: Model function as required by `Estimator` which returns 2240 EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks', 2241 and `prediction_hooks` must not capure any TPU Tensor inside the 2242 model_fn. 2243 model_dir: Directory to save model parameters, graph and etc. This can 2244 also be used to load checkpoints from the directory into a estimator to 2245 continue training a previously saved model. If `None`, the model_dir in 2246 `config` will be used if set. If both are set, they must be same. If 2247 both are `None`, a temporary directory will be used. 2248 config: An `tpu_config.RunConfig` configuration object. Cannot be `None`. 2249 params: An optional `dict` of hyper parameters that will be passed into 2250 `input_fn` and `model_fn`. Keys are names of parameters, values are 2251 basic python types. There are reserved keys for `TPUEstimator`, 2252 including 'batch_size'. 2253 use_tpu: A bool indicating whether TPU support is enabled. Currently, - 2254 TPU training and evaluation respect this bit, but eval_on_tpu can 2255 override execution of eval. See below. - Predict still happens on CPU. 2256 train_batch_size: An int representing the global training batch size. 2257 TPUEstimator transforms this global batch size to a per-shard batch 2258 size, as params['batch_size'], when calling `input_fn` and `model_fn`. 2259 Cannot be `None` if `use_tpu` is `True`. Must be divisible by total 2260 number of replicas. 2261 eval_batch_size: An int representing evaluation batch size. Must be 2262 divisible by total number of replicas. 2263 predict_batch_size: An int representing the prediction batch size. Must be 2264 divisible by total number of replicas. 2265 batch_axis: A python tuple of int values describing how each tensor 2266 produced by the Estimator `input_fn` should be split across the TPU 2267 compute shards. For example, if your input_fn produced (images, labels) 2268 where the images tensor is in `HWCN` format, your shard dimensions would 2269 be [3, 0], where 3 corresponds to the `N` dimension of your images 2270 Tensor, and 0 corresponds to the dimension along which to split the 2271 labels to match up with the corresponding images. If None is supplied, 2272 and per_host_input_for_training is True, batches will be sharded based 2273 on the major dimension. If tpu_config.per_host_input_for_training is 2274 False or `PER_HOST_V2`, batch_axis is ignored. 
2275 eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the 2276 model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. 2277 export_to_tpu: If True, `export_savedmodel()` exports a metagraph for 2278 serving on TPU. Note that unsupported export modes such as EVAL will be 2279 ignored. For those modes, only a CPU model will be exported. 2280 Currently, export_to_tpu only supports PREDICT. 2281 export_to_cpu: If True, `export_savedmodel()` exports a metagraph for 2282 serving on CPU. 2283 warm_start_from: Optional string filepath to a checkpoint or SavedModel to 2284 warm-start from, or a `tf.estimator.WarmStartSettings` object to fully 2285 configure warm-starting. If the string filepath is provided instead of 2286 a `WarmStartSettings`, then all variables are warm-started, and it is 2287 assumed that vocabularies and Tensor names are unchanged. 2288 experimental_exported_model_uses_all_cores: Whether to round-robin among 2289 all cores visible to the host which is serving the saved model, or to 2290 use a single core. This is a temporary flag to enable using all TPU 2291 cores for inference with TPUPartitionedCall(). Once outside compilation 2292 is supported in TPUPartitionedCall(), this flag will be enabled by 2293 default. 2294 experimental_export_device_assignment: Whether to include the device 2295 assignment in the exported model. Doing so is useful in case of model 2296 parallel inference but will tie the exported model to the TPU topology 2297 used to export the model. 2298 experimental_embedding_config_spec: Optional EmbeddingConfigSpec instance 2299 to support using TPU embedding. IT IS STILL WORK IN PROGRESS, SO PLEASE 2300 DO NOT USE. 2301 2302 Raises: 2303 ValueError: `params` has reserved keys already. 2304 """ 2305 if config is None or not isinstance(config, tpu_config.RunConfig): 2306 raise ValueError( 2307 '`config` must be provided with type `tpu_config.RunConfig`') 2308 2309 if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS): 2310 raise ValueError('{} are reserved keys but existed in params {}.'.format( 2311 _RESERVED_PARAMS_KEYS, params)) 2312 2313 if use_tpu: 2314 # Perform some very basic validations. More validations will be found in 2315 # _InternalTPUContext. 2316 if train_batch_size is None: 2317 raise ValueError('`train_batch_size` cannot be `None`') 2318 util_lib.check_positive_integer(train_batch_size, 'train_batch_size') 2319 2320 if (config.tpu_config.per_host_input_for_training is 2321 tpu_config.InputPipelineConfig.PER_SHARD_V1 and 2322 config.tpu_config.num_cores_per_replica): 2323 raise ValueError( 2324 'Model parallelism only supports per host input for training. ' 2325 'Please adjust TPURunconfig.per_host_input_for_training.') 2326 2327 if eval_batch_size is not None: 2328 util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size') 2329 2330 if predict_batch_size is not None: 2331 util_lib.check_positive_integer(predict_batch_size, 2332 'predict_batch_size') 2333 2334 # Verifies the model_fn signature according to Estimator framework. 2335 estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access 2336 # We cannot store config and params in this constructor as parent 2337 # constructor might change them, such as assigning a temp dir for 2338 # config.model_dir. 2339 model_function = self._augment_model_fn(model_fn, batch_axis) 2340 2341 # Overwrite log_step_count_steps to disable TensorLoggingHook and 2342 # StepCounterHook from being created in Estimator. 
TPUEstimator already 2343 # added equivalent hooks in _augment_model_fn above. 2344 self._log_every_n_steps = config.log_step_count_steps 2345 config = config.replace(log_step_count_steps=None) 2346 2347 # Passing non-None params as wrapped model_fn has it. 2348 params = params or {} 2349 super(TPUEstimator, self).__init__( 2350 model_fn=model_function, 2351 model_dir=model_dir, 2352 config=config, 2353 params=params, 2354 warm_start_from=warm_start_from) 2355 self._iterations_per_training_loop = ( 2356 self._config.tpu_config.iterations_per_loop) 2357 2358 # All properties passed to _InternalTPUContext are immutable. 2359 # pylint: disable=protected-access 2360 self._ctx = tpu_context._get_tpu_context( 2361 self._config, train_batch_size, eval_batch_size, predict_batch_size, 2362 use_tpu, eval_on_tpu, experimental_embedding_config_spec) 2363 2364 self._export_to_cpu = export_to_cpu 2365 self._export_to_tpu = export_to_tpu 2366 self._experimental_exported_model_uses_all_cores = ( 2367 experimental_exported_model_uses_all_cores) 2368 self._experimental_export_device_assignment = ( 2369 experimental_export_device_assignment) 2370 if (experimental_exported_model_uses_all_cores and 2371 experimental_export_device_assignment): 2372 raise ValueError('experimental_exported_model_uses_all_cores and ' 2373 'experimental_export_device_assignment is not supported ' 2374 'at the same time.') 2375 2376 self._is_input_fn_invoked = None 2377 self._rendezvous = {} 2378 2379 def _add_meta_graph_for_mode(self, 2380 builder, 2381 input_receiver_fn_map, 2382 checkpoint_path, 2383 save_variables=True, 2384 mode=model_fn_lib.ModeKeys.PREDICT, 2385 export_tags=None, 2386 check_variables=True): 2387 if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT: 2388 logging.warning('TPUEstimator only handles mode PREDICT for exporting ' 2389 'when `export_to_tpu` is `True`; Mode {} will be ignored ' 2390 'for TPU.'.format(mode)) 2391 2392 if not self._export_to_cpu and not self._export_to_tpu: 2393 raise ValueError('One of export_to_cpu and export_to_tpu must be true.') 2394 2395 if self._export_to_cpu: 2396 (super(TPUEstimator, self)._add_meta_graph_for_mode( 2397 builder, 2398 input_receiver_fn_map, 2399 checkpoint_path, 2400 save_variables, 2401 mode=mode, 2402 export_tags=export_tags, 2403 check_variables=check_variables)) 2404 2405 if self._export_to_tpu and mode == model_fn_lib.ModeKeys.PREDICT: 2406 input_receiver_fn_map = { 2407 _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode] 2408 } 2409 export_tags = [tag_constants.SERVING, tag_constants.TPU] 2410 mode = _REWRITE_FOR_INFERENCE_MODE 2411 2412 # See b/110052256 for why `check_variables` is `False`. 
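      # If a CPU metagraph was exported above, variables were already saved
      # and checked there, so the TPU metagraph skips both; otherwise this
      # TPU-only export must save and check the variables itself.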
2413 if not self._export_to_cpu: 2414 check_variables = save_variables = True 2415 else: 2416 check_variables = save_variables = False 2417 (super(TPUEstimator, self)._add_meta_graph_for_mode( 2418 builder, 2419 input_receiver_fn_map, 2420 checkpoint_path, 2421 save_variables=save_variables, 2422 mode=mode, 2423 export_tags=export_tags, 2424 check_variables=check_variables)) 2425 2426 def _call_model_fn(self, features, labels, mode, config): 2427 if mode == _REWRITE_FOR_INFERENCE_MODE: 2428 return self._call_model_fn_for_inference(features, labels, mode, config) 2429 else: 2430 return super(TPUEstimator, self)._call_model_fn(features, labels, mode, 2431 config) 2432 2433 def _call_model_fn_for_inference(self, features, labels, mode, config): 2434 """Wraps `_call_model_fn` for `export_savedmodel`.""" 2435 if mode != _REWRITE_FOR_INFERENCE_MODE: 2436 raise ValueError('mode must be {}; ' 2437 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) 2438 2439 computation, capture = self._build_computation_for_inference( 2440 features, labels, mode, config) 2441 tensors = call_computation( 2442 computation, 2443 experimental_exported_model_uses_all_cores=self 2444 ._experimental_exported_model_uses_all_cores) 2445 estimator_spec, export_outputs_dict, predictions_dict, none_indices = ( 2446 capture.get()) 2447 predictions_list = tensors[:len(predictions_dict)] 2448 export_outputs_list_without_none = tensors[len(predictions_dict):] 2449 2450 # Reinsert `None`s which we've taken out in 2451 # `_build_computation_for_inference()`. 2452 export_outputs_list = [] 2453 while none_indices or export_outputs_list_without_none: 2454 if none_indices and none_indices[0] == len(export_outputs_list): 2455 export_outputs_list.append(None) 2456 none_indices.pop(0) 2457 else: 2458 export_outputs_list.append(export_outputs_list_without_none.pop(0)) 2459 2460 # Reconstruct `export_outputs` with updated tensors. 2461 new_export_outputs_dict = nest.pack_sequence_as(export_outputs_dict, 2462 export_outputs_list) 2463 export_outputs = estimator_spec.export_outputs 2464 new_export_outputs = collections.OrderedDict( 2465 (k, _clone_export_output_with_tensors(export_outputs[k], v)) 2466 for k, v in six.iteritems(new_export_outputs_dict)) 2467 # Reconstruct `predictions` with updated tensors. 2468 new_predictions = nest.pack_sequence_as(predictions_dict, predictions_list) 2469 if (len(new_predictions) == 1 and 2470 _KEY_WHEN_PREDICTIONS_IS_A_TENSOR in new_predictions): 2471 new_predictions = new_predictions[_KEY_WHEN_PREDICTIONS_IS_A_TENSOR] 2472 2473 return estimator_spec._replace( 2474 export_outputs=new_export_outputs, predictions=new_predictions) 2475 2476 def _build_computation_for_inference(self, features, labels, mode, config): 2477 capture = _CapturedObject() 2478 2479 def computation(): 2480 """Computation to be passed to `TPUPartitionedCall()`.""" 2481 tpu_computation, tpu_capture = self._build_tpu_computation_for_inference( 2482 features, labels, mode, config) 2483 2484 if self._experimental_export_device_assignment: 2485 # Export the device assignment as part of the model. This is useful for 2486 # model parallel usecases where the model relies on the mapping between 2487 # logical and physical devices. 
2488 with self._ctx.with_mode(mode) as ctx: 2489 device_assignment = ctx.device_assignment 2490 else: 2491 device_assignment = None 2492 2493 if self._experimental_exported_model_uses_all_cores: 2494 tensors_on_cpu = tpu.rewrite( 2495 tpu_computation, device_assignment=device_assignment) 2496 tpu.prune_unconnected_ops_from_xla(ops.get_default_graph()) 2497 else: 2498 tensors_on_cpu = tpu.rewrite_for_inference( 2499 tpu_computation, device_assignment=device_assignment) 2500 2501 (estimator_spec, export_outputs_dict, export_outputs_list, 2502 predictions_dict) = ( 2503 tpu_capture.get()) 2504 predictions_list = tensors_on_cpu[:len(predictions_dict)] 2505 export_outputs_tpu_on_cpu_list = tensors_on_cpu[len(predictions_dict):] 2506 2507 # Reconstruct tensors used in export_outputs, with TPU tensors replaced 2508 # with their CPU counterpart returned from `rewrite_for_inference()`. 2509 # `function.Defun()` does not like `None`s in return values, so we leave 2510 # `None`s out but record their positions for later reconstruction. 2511 export_outputs_list_without_none = [] 2512 none_indices = [] 2513 for i, t in enumerate(export_outputs_list): 2514 if t is None: 2515 none_indices.append(i) 2516 else: 2517 export_outputs_list_without_none.append( 2518 export_outputs_tpu_on_cpu_list.pop(0)) 2519 2520 capture.capture((estimator_spec, export_outputs_dict, predictions_dict, 2521 none_indices)) 2522 return predictions_list + export_outputs_list_without_none 2523 2524 return computation, capture 2525 2526 def _build_tpu_computation_for_inference(self, features, labels, mode, 2527 config): 2528 capture = _CapturedObject() 2529 2530 def computation(): 2531 """Compute tpu tensors used in export_outputs. 2532 2533 Passed to rewrite_for_inference so that model_fn will be called under 2534 the rewriting contexts. Only tpu tensors are returned, but export_outputs 2535 and scaffold are captured. 2536 2537 Returns: 2538 A list of Tensors used in export_outputs and not marked for 2539 outside_compilation. 2540 """ 2541 # We should only call model fn once and it should be inside `computation` 2542 # so that building the graph will happen under `rewrite_for_inference`. 2543 estimator_spec = super(TPUEstimator, self)._call_model_fn( 2544 features, labels, mode, config) 2545 2546 # We pick the TPU tensors out from `export_output` and later return them 2547 # from `computation` for rewriting. 2548 export_outputs_dict = collections.OrderedDict( 2549 (k, _export_output_to_tensors(v)) 2550 for k, v in six.iteritems(estimator_spec.export_outputs)) 2551 export_outputs_list = nest.flatten(export_outputs_dict) 2552 export_outputs_tpu_list = [ 2553 t for t in export_outputs_list if t is not None 2554 ] 2555 2556 if isinstance(estimator_spec.predictions, dict): 2557 predictions_dict = collections.OrderedDict( 2558 (k, v) for k, v in six.iteritems(estimator_spec.predictions)) 2559 else: 2560 predictions_dict = { 2561 _KEY_WHEN_PREDICTIONS_IS_A_TENSOR: estimator_spec.predictions 2562 } 2563 predictions_list = nest.flatten(predictions_dict) 2564 2565 # We cannot return everything we want through the return values, so 2566 # capture the rest here for later use. 2567 capture.capture((estimator_spec, export_outputs_dict, export_outputs_list, 2568 predictions_dict)) 2569 return predictions_list + export_outputs_tpu_list 2570 2571 return computation, capture 2572 2573 def _create_global_step(self, graph): 2574 """Creates a global step suitable for TPUs. 2575 2576 Args: 2577 graph: The graph in which to create the global step. 
2578 2579 Returns: 2580 A global step `Tensor`. 2581 2582 Raises: 2583 ValueError: if the global step tensor is already defined. 2584 """ 2585 return _create_global_step(graph) 2586 2587 def _convert_train_steps_to_hooks(self, steps, max_steps): 2588 with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx: 2589 if ctx.is_running_on_cpu(): 2590 return super(TPUEstimator, self)._convert_train_steps_to_hooks( 2591 steps, max_steps) 2592 2593 # On TPU. 2594 if steps is None and max_steps is None: 2595 raise ValueError( 2596 'For TPU training, one of `steps` or `max_steps` must be set. ' 2597 'Cannot be both `None`.') 2598 2599 # Estimator.train has explicit positiveness check. 2600 if steps is not None: 2601 util_lib.check_positive_integer(steps, 'Train steps') 2602 if max_steps is not None: 2603 util_lib.check_positive_integer(max_steps, 'Train max_steps') 2604 2605 return [ 2606 _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps) 2607 ] 2608 2609 def _convert_eval_steps_to_hooks(self, steps): 2610 with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: 2611 if ctx.is_running_on_cpu(): 2612 return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) 2613 2614 if steps is None: 2615 raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') 2616 2617 util_lib.check_positive_integer(steps, 'Eval steps') 2618 2619 return [ 2620 evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access 2621 num_evals=steps), 2622 _SetEvalIterationsHook(steps) 2623 ] 2624 2625 def _call_input_fn(self, input_fn, mode): 2626 """Calls the input function. 2627 2628 Args: 2629 input_fn: The input function. 2630 mode: ModeKeys 2631 2632 Returns: 2633 In TPU mode, returns an input_fn to be called later in model_fn. 2634 Otherwise, calls the input_fn and returns either fatures or 2635 (features, labels). 2636 2637 Raises: 2638 ValueError: if input_fn takes invalid arguments or does not have `params`. 2639 """ 2640 input_fn_args = function_utils.fn_args(input_fn) 2641 config = self.config # a deep copy. 2642 kwargs = {} 2643 if 'params' in input_fn_args: 2644 kwargs['params'] = self.params # a deep copy. 2645 else: 2646 raise ValueError('input_fn ({}) does not include params argument, ' 2647 'required by TPUEstimator to pass batch size as ' 2648 'params["batch_size"]'.format(input_fn)) 2649 if 'config' in input_fn_args: 2650 kwargs['config'] = config 2651 2652 if 'mode' in input_fn_args: 2653 kwargs['mode'] = mode 2654 2655 # Records the fact input_fn has been invoked. 2656 self._is_input_fn_invoked = True 2657 2658 with self._ctx.with_mode(mode) as ctx: 2659 # Setting the batch size in params first. This helps user to have same 2660 # input_fn for use_tpu=True/False. 2661 batch_size_for_input_fn = ctx.batch_size_for_input_fn 2662 if batch_size_for_input_fn is not None: 2663 _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY, 2664 batch_size_for_input_fn) 2665 2666 # For export_savedmodel, input_fn is never passed to Estimator. So, 2667 # `is_export_mode` must be False. 2668 if ctx.is_running_on_cpu(is_export_mode=False): 2669 with ops.device('/device:CPU:0'): 2670 return input_fn(**kwargs) 2671 2672 # For TPU computation, input_fn should be invoked in a tf.while_loop for 2673 # performance. While constructing the tf.while_loop, the structure of 2674 # inputs returned by the `input_fn` needs to be recorded. The structure 2675 # includes whether features or labels is dict or single Tensor, dict keys, 2676 # tensor shapes, and dtypes. 
The recorded structure is used to create the 2677 # infeed dequeue ops, which must be wrapped and passed as a Fn, called 2678 # inside the TPU computation, as the TPU computation is wrapped inside a 2679 # tf.while_loop also. So, we either pass input_fn to model_fn or pass 2680 # dequeue_fn to model_fn. Here, `input_fn` is passed directly as 2681 # `features` in `model_fn` signature. 2682 def _input_fn(ctx): 2683 _add_item_to_params(kwargs['params'], _CTX_KEY, ctx) 2684 return input_fn(**kwargs) 2685 2686 return _input_fn 2687 2688 def _validate_features_in_predict_input(self, result): 2689 """Skip the validation. 2690 2691 For TPUEstimator, we do not need to check the result type. `_InputPipeline` 2692 has stronger check. Parent class's check generates confusing warning msg. 2693 2694 Args: 2695 result: `features` returned by input_fn. 2696 """ 2697 pass 2698 2699 def train(self, 2700 input_fn, 2701 hooks=None, 2702 steps=None, 2703 max_steps=None, 2704 saving_listeners=None): 2705 rendezvous = error_handling.ErrorRendezvous(num_sources=3) 2706 self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous 2707 try: 2708 return super(TPUEstimator, self).train( 2709 input_fn=input_fn, 2710 hooks=hooks, 2711 steps=steps, 2712 max_steps=max_steps, 2713 saving_listeners=saving_listeners) 2714 except Exception: # pylint: disable=broad-except 2715 rendezvous.record_error('training_loop', sys.exc_info()) 2716 finally: 2717 rendezvous.record_done('training_loop') 2718 rendezvous.raise_errors() 2719 2720 def evaluate(self, 2721 input_fn, 2722 steps=None, 2723 hooks=None, 2724 checkpoint_path=None, 2725 name=None): 2726 rendezvous = error_handling.ErrorRendezvous(num_sources=3) 2727 self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous 2728 try: 2729 return super(TPUEstimator, self).evaluate( 2730 input_fn, 2731 steps=steps, 2732 hooks=hooks, 2733 checkpoint_path=checkpoint_path, 2734 name=name) 2735 except Exception: # pylint: disable=broad-except 2736 rendezvous.record_error('evaluation_loop', sys.exc_info()) 2737 finally: 2738 rendezvous.record_done('evaluation_loop') 2739 rendezvous.raise_errors() 2740 2741 def predict(self, 2742 input_fn, 2743 predict_keys=None, 2744 hooks=None, 2745 checkpoint_path=None, 2746 yield_single_examples=True): 2747 rendezvous = error_handling.ErrorRendezvous(num_sources=3) 2748 self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous 2749 try: 2750 for result in super(TPUEstimator, self).predict( 2751 input_fn=input_fn, 2752 predict_keys=predict_keys, 2753 hooks=hooks, 2754 checkpoint_path=checkpoint_path, 2755 yield_single_examples=yield_single_examples): 2756 yield result 2757 except Exception: # pylint: disable=broad-except 2758 rendezvous.record_error('prediction_loop', sys.exc_info()) 2759 finally: 2760 rendezvous.record_done('prediction_loop') 2761 rendezvous.raise_errors() 2762 2763 rendezvous.record_done('prediction_loop') 2764 rendezvous.raise_errors() 2765 2766 def _augment_model_fn(self, model_fn, batch_axis): 2767 """Returns a new model_fn, which wraps the TPU support.""" 2768 2769 def _model_fn(features, labels, mode, config, params): 2770 """A Estimator `model_fn` for TPUEstimator.""" 2771 2772 # `input_fn` is called in `train()`, `evaluate()`, and `predict()`, 2773 # but not in `export_savedmodel()`. 2774 if self._is_input_fn_invoked: 2775 is_export_mode = False 2776 else: 2777 is_export_mode = True 2778 2779 # Clear the bit. 
2780 self._is_input_fn_invoked = None 2781 2782 if is_export_mode: 2783 if mode == _REWRITE_FOR_INFERENCE_MODE: 2784 _add_item_to_params(params, _USE_TPU_KEY, True) 2785 mode = model_fn_lib.ModeKeys.PREDICT 2786 else: 2787 _add_item_to_params(params, _USE_TPU_KEY, False) 2788 2789 with self._ctx.with_mode(mode) as ctx: 2790 model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) 2791 2792 # examples_hook is added to training_hooks for both CPU and TPU 2793 # execution. 2794 if self._log_every_n_steps is not None: 2795 examples_hook = ExamplesPerSecondHook( 2796 ctx.global_batch_size, 2797 # pylint:disable=g-long-ternary 2798 output_dir=(self.model_dir 2799 if not config or config.save_summary_steps 2800 else None), 2801 # pylint:enable=g-long-ternary 2802 every_n_steps=self._log_every_n_steps) 2803 2804 if ctx.is_running_on_cpu(is_export_mode=is_export_mode): 2805 logging.info('Running %s on CPU', mode) 2806 estimator_spec = model_fn_wrapper.call_without_tpu( 2807 features, labels, is_export_mode=is_export_mode) 2808 if self._log_every_n_steps is not None: 2809 estimator_spec = estimator_spec._replace( 2810 training_hooks=estimator_spec.training_hooks + (examples_hook,)) 2811 return estimator_spec 2812 2813 assert labels is None, '`labels` passed to `model_fn` must be `None`.' 2814 # TPUEstimator._call_input_fn passes `input_fn` as features to here. 2815 assert callable(features), '`input_fn` is not callable.' 2816 input_fn = features 2817 2818 tpu_init_ops = [] 2819 if ctx.embedding_config and mode == model_fn_lib.ModeKeys.TRAIN: 2820 dummy_table_variables, dummy_table_variables_init = ( 2821 tpu_embedding_gradient.create_dummy_table_variables( 2822 ctx.embedding_config.tpu_embedding)) 2823 ctx.embedding_config.dummy_table_variables = dummy_table_variables 2824 tpu_init_ops.append(dummy_table_variables_init) 2825 2826 input_holders = _InputPipeline(input_fn, batch_axis, ctx) 2827 enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( 2828 input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) 2829 2830 graph = ops.get_default_graph() 2831 for enqueue_op in enqueue_ops: 2832 if isinstance(enqueue_op, list): 2833 graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op) 2834 else: 2835 graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) 2836 2837 if mode == model_fn_lib.ModeKeys.TRAIN: 2838 compile_op, loss, host_call, scaffold, training_hooks = ( 2839 _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) 2840 if ctx.embedding_config: 2841 g = ops.get_default_graph() 2842 table_to_config_dict = ( 2843 ctx.embedding_config.tpu_embedding.table_to_config_dict) 2844 optimization_parameters = ( 2845 ctx.embedding_config.tpu_embedding.optimization_parameters) 2846 embedding_variable_name_by_table, slot_variable_names_by_table = ( 2847 _tpu_estimator_embedding.get_full_variable_names( 2848 g, table_to_config_dict, optimization_parameters 2849 ) 2850 ) 2851 embedding_variables_and_ops = ( 2852 ctx.embedding_config.tpu_embedding.create_variables_and_ops( 2853 embedding_variable_name_by_table, 2854 slot_variable_names_by_table 2855 )) 2856 tpu_init_ops.extend(embedding_variables_and_ops.load_ops()) 2857 2858 host_ops = host_call.create_tpu_hostcall() 2859 if host_ops is None: 2860 host_ops = [] 2861 2862 shutdown_hooks = [] 2863 shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE', 2864 'shutdown_worker') 2865 if shutdown_mode: 2866 if shutdown_mode == 'shutdown_worker': 2867 finalizer_hooks = [ 2868 session_support.ShutdownLameWorkers(timeout_ms=60 
* 1000), 2869 ] 2870 elif shutdown_mode == 'shutdown_computation': 2871 finalizer_hooks = [ 2872 session_support.RestartComputation(timeout_ms=60 * 1000), 2873 ] 2874 else: 2875 raise ValueError( 2876 'Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % shutdown_mode) 2877 2878 shutdown_hooks.append( 2879 session_support.GracefulShutdownHook( 2880 checkpoint_prefix=self.model_dir + '/model.ckpt', 2881 on_shutdown_hooks=finalizer_hooks)) 2882 2883 with ops.control_dependencies([loss]): 2884 global_step = array_ops.identity(training.get_global_step()) 2885 hooks = input_hooks + shutdown_hooks 2886 hooks.extend([ 2887 TPUInfeedOutfeedSessionHook( 2888 ctx, 2889 enqueue_ops, 2890 host_ops, 2891 tpu_compile_op=compile_op, 2892 run_infeed_loop_on_coordinator=( 2893 run_infeed_loop_on_coordinator), 2894 rendezvous=self._rendezvous[mode], 2895 master=self._config.master, 2896 session_config=self._session_config, 2897 tpu_init_ops=tpu_init_ops), 2898 InstallSignalHandlerHook() 2899 ]) 2900 if self._log_every_n_steps is not None: 2901 logging_hook_frequency = ( # Divide and round up 2902 (self._log_every_n_steps + 2903 self._config.tpu_config.iterations_per_loop - 1) // 2904 self._config.tpu_config.iterations_per_loop) 2905 hooks.append( 2906 training.LoggingTensorHook({ 2907 'loss': array_ops.identity(loss), 2908 'step': global_step, 2909 }, 2910 every_n_iter=logging_hook_frequency)) 2911 examples_hook._set_steps_per_run( # pylint: disable=protected-access 2912 self._config.tpu_config.iterations_per_loop) 2913 hooks.append(examples_hook) 2914 2915 if training_hooks: 2916 hooks.extend(training_hooks) 2917 2918 chief_hooks = [] 2919 if (self._config.save_checkpoints_secs or 2920 self._config.save_checkpoints_steps): 2921 checkpoint_hook = training.CheckpointSaverHook( 2922 self.model_dir, 2923 save_secs=self._config.save_checkpoints_secs, 2924 save_steps=self._config.save_checkpoints_steps, 2925 scaffold=scaffold) 2926 checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access 2927 self._config.tpu_config.iterations_per_loop) 2928 chief_hooks.append(checkpoint_hook) 2929 2930 summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) 2931 with ops.control_dependencies([loss]): 2932 update_ops = _sync_variables_ops(ctx) 2933 if ctx.embedding_config: 2934 update_ops.extend(embedding_variables_and_ops.retrieve_ops()) 2935 2936 # Validate the TPU training graph to catch basic errors 2937 _validate_tpu_training_graph() 2938 2939 train_op = control_flow_ops.group(*update_ops) 2940 graph.add_to_collection(_TPU_TRAIN_OP, train_op) 2941 2942 return model_fn_lib.EstimatorSpec( 2943 mode, 2944 loss=loss, 2945 training_chief_hooks=chief_hooks, 2946 training_hooks=hooks, 2947 train_op=train_op, 2948 scaffold=scaffold) 2949 2950 if mode == model_fn_lib.ModeKeys.EVAL: 2951 compile_op, total_loss, host_calls, scaffold, eval_hooks = ( 2952 _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) 2953 if ctx.embedding_config: 2954 g = ops.get_default_graph() 2955 table_to_config_dict = ( 2956 ctx.embedding_config.tpu_embedding.table_to_config_dict) 2957 embedding_variable_name_by_table, _ = ( 2958 _tpu_estimator_embedding.get_full_variable_names( 2959 g, table_to_config_dict) 2960 ) 2961 embedding_variables_and_ops = ( 2962 ctx.embedding_config.tpu_embedding.create_variables_and_ops( 2963 embedding_variable_name_by_table 2964 )) 2965 tpu_init_ops.extend(embedding_variables_and_ops.load_ops()) 2966 iterations_per_loop_var = _create_or_get_iterations_per_loop() 2967 mean_loss = math_ops.div( 2968 total_loss, 2969 
              math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype))

          with ops.control_dependencies([mean_loss]):
            # After the TPU evaluation computation is done (the mean_loss
            # tensor), read all variables back from the TPU and update the
            # eval step counter properly.
            internal_ops_to_run = _sync_variables_ops(ctx)
            internal_ops_to_run.append(
                _increase_eval_step_op(iterations_per_loop_var))

          host_call_ret = host_calls.create_tpu_hostcall()
          eval_metric_ops = {}
          eval_update_ops = []

          eval_metrics = host_call_ret.get('eval_metrics', {})
          if eval_metrics:
            # Creates a dummy metric update_op for all metrics. Estimator
            # expects all metrics in `eval_metric_ops` to have an update_op
            # and calls them one by one. The real metric update_ops are
            # invoked in a separate thread, so here we give Estimator the
            # dummy op for all metrics.
            with ops.control_dependencies(internal_ops_to_run):
              dummy_update_op = control_flow_ops.no_op()

            for k, v in eval_metrics.items():
              eval_metric_ops[k] = (v[0], dummy_update_op)
              eval_update_ops.append(v[1])
          else:
            # If no eval metrics are passed, create an identity node for the
            # loss and add `internal_ops_to_run` to its dependencies, so that
            # `internal_ops_to_run` can be executed.
            with ops.control_dependencies(internal_ops_to_run):
              mean_loss = array_ops.identity(mean_loss)

          if 'host_call' not in host_call_ret:
            host_ops = []
          else:
            host_ops = host_call_ret['host_call']
          hooks = [
              TPUInfeedOutfeedSessionHook(
                  ctx,
                  enqueue_ops,
                  eval_update_ops + host_ops,
                  tpu_compile_op=compile_op,
                  run_infeed_loop_on_coordinator=(
                      run_infeed_loop_on_coordinator),
                  rendezvous=self._rendezvous[mode],
                  master=self._config.evaluation_master,
                  session_config=self._session_config,
                  tpu_init_ops=tpu_init_ops)
          ] + input_hooks

          if eval_hooks:
            hooks.extend(eval_hooks)

          return model_fn_lib.EstimatorSpec(
              mode,
              loss=mean_loss,
              evaluation_hooks=hooks,
              eval_metric_ops=eval_metric_ops,
              scaffold=scaffold)

        # Predict
        assert mode == model_fn_lib.ModeKeys.PREDICT

        (compile_op, dummy_predict_op, host_calls,
         scaffold, prediction_hooks) = _predict_on_tpu_system(
             ctx, model_fn_wrapper, dequeue_fn)
        with ops.control_dependencies([dummy_predict_op]):
          internal_ops_to_run = _sync_variables_ops(ctx)
          with ops.control_dependencies(internal_ops_to_run):
            dummy_predict_op = control_flow_ops.no_op()

        # In train and evaluation, the main TPU program is passed to the
        # monitored training session to run. Infeed enqueue and outfeed
        # dequeue are executed in side threads. This is not the configuration
        # for prediction mode.
        #
        # For prediction, the Estimator executes the EstimatorSpec.predictions
        # directly and yields the element (via generator) to the call site.
        # So, the outfeed based prediction must be passed to MonitoredSession
        # directly. Other parts of the TPU execution are organized as follows.
        #
        # 1. All outfeed based Tensors must be grouped with predictions
        #    Tensors to form a single invocation. This avoids the issue that
        #    we might trigger multiple outfeeds incorrectly.
        #    To achieve this, `host_call` is placed in control_dependencies
        #    of `stopping_signals`, and `stopping_signals` is passed into
        #    _StoppingPredictHook, which sets the `stopping_signals` as
        #    SessionRunArgs. MonitoredSession merges all SessionRunArgs with
        #    the fetch in session.run together.
        #
        # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed
        #    Enqueue) are grouped together. They will be launched once and
        #    only once in side threads and they quit naturally according to
        #    the SAME stopping condition.
        enqueue_ops.append(dummy_predict_op)

        host_call_ret = host_calls.create_tpu_hostcall()
        if 'host_call' not in host_call_ret:
          host_ops = []
        else:
          host_ops = host_call_ret['host_call']

        predictions = host_call_ret['predictions']
        _verify_cross_hosts_transfer_size(
            predictions,
            message=(
                'The estimated size for TPUEstimatorSpec.predictions is too '
                'large.'))
        signals = host_call_ret['signals']

        with ops.control_dependencies(host_ops):
          host_ops = []  # Empty, we do not need it anymore.
          scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(
              signals)
          predictions = _PaddingSignals.slice_tensor_or_dict(
              predictions, signals)

        hooks = [
            _StoppingPredictHook(scalar_stopping_signal),
            TPUInfeedOutfeedSessionHookForPrediction(
                ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode],
                tpu_compile_op=compile_op,
                master=self._config.master,
                session_config=self._session_config),
        ] + input_hooks

        if prediction_hooks:
          hooks.extend(prediction_hooks)

        return model_fn_lib.EstimatorSpec(
            mode,
            prediction_hooks=hooks,
            predictions=predictions,
            scaffold=scaffold)

    return _model_fn


def _export_output_to_tensors(export_output):
  """Gets a list of `Tensors` used in `export_output`.

  Args:
    export_output: an `ExportOutput` object such as `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.

  Returns:
    A list of tensors used in `export_output`.

  Raises:
    ValueError: if `export_output` is not one of `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.
  """
  if isinstance(export_output, export_output_lib.ClassificationOutput):
    return [export_output.scores, export_output.classes]
  elif isinstance(export_output, export_output_lib.RegressionOutput):
    return [export_output.value]
  elif isinstance(export_output, export_output_lib.PredictOutput):
    return list(export_output.outputs.values())
  else:
    raise ValueError(
        '`export_output` must have type `ClassificationOutput`, '
        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))


def _clone_export_output_with_tensors(export_output, tensors):
  """Clones `export_output` but with new `tensors`.

  Args:
    export_output: an `ExportOutput` object such as `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.
    tensors: a list of `Tensors` used to construct a new `export_output`.

  Returns:
    A dict similar to `export_output` but with `tensors`.

  Raises:
    ValueError: if `export_output` is not one of `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.
3148 """ 3149 if isinstance(export_output, export_output_lib.ClassificationOutput): 3150 if len(tensors) != 2: 3151 raise ValueError('tensors must be of length 2; ' 3152 'got {}.'.format(len(tensors))) 3153 return export_output_lib.ClassificationOutput(*tensors) 3154 elif isinstance(export_output, export_output_lib.RegressionOutput): 3155 if len(tensors) != 1: 3156 raise ValueError('tensors must be of length 1; ' 3157 'got {}'.format(len(tensors))) 3158 return export_output_lib.RegressionOutput(*tensors) 3159 elif isinstance(export_output, export_output_lib.PredictOutput): 3160 return export_output_lib.PredictOutput( 3161 dict(zip(export_output.outputs.keys(), tensors))) 3162 else: 3163 raise ValueError( 3164 '`export_output` must be have type `ClassificationOutput`, ' 3165 '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) 3166 3167 3168 def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): 3169 """Executes `model_fn_wrapper` multiple times on all TPU shards.""" 3170 iterations_per_loop_var = _create_or_get_iterations_per_loop() 3171 3172 (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks 3173 ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn) 3174 3175 @tpu_function.on_device_training_loop 3176 def multi_tpu_eval_steps_on_single_shard(): 3177 return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step, 3178 [_ZERO_LOSS]) 3179 3180 (compile_op, loss,) = tpu.split_compile_and_shard( 3181 multi_tpu_eval_steps_on_single_shard, 3182 inputs=[], 3183 num_shards=ctx.num_replicas, 3184 outputs_from_all_shards=False, 3185 device_assignment=ctx.device_assignment) 3186 3187 loss = loss[0] 3188 scaffold = _get_scaffold(captured_scaffold_fn) 3189 return compile_op, loss, host_calls, scaffold, captured_eval_hooks.get() 3190 3191 3192 def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): 3193 """Executes `model_fn_wrapper` multiple times on all TPU shards.""" 3194 iterations_per_loop_var = _create_or_get_iterations_per_loop() 3195 3196 (single_tpu_train_step, host_call, captured_scaffold_fn, 3197 captured_training_hooks) = ( 3198 model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) 3199 3200 @tpu_function.on_device_training_loop 3201 def multi_tpu_train_steps_on_single_shard(): 3202 return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step, 3203 [_INITIAL_LOSS]) 3204 3205 (compile_op, loss,) = tpu.split_compile_and_shard( 3206 multi_tpu_train_steps_on_single_shard, 3207 inputs=[], 3208 num_shards=ctx.num_replicas, 3209 outputs_from_all_shards=False, 3210 device_assignment=ctx.device_assignment) 3211 3212 loss = loss[0] 3213 scaffold = _get_scaffold(captured_scaffold_fn) 3214 return compile_op, loss, host_call, scaffold, captured_training_hooks.get() 3215 3216 3217 def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): 3218 """Executes `model_fn_wrapper` multiple times on all TPU shards.""" 3219 (single_tpu_predict_step, host_calls, captured_scaffold_fn, 3220 captured_predict_hooks 3221 ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) 3222 3223 @tpu_function.on_device_training_loop 3224 def multi_tpu_predict_steps_on_single_shard(): 3225 3226 def cond(scalar_stopping_signal): 3227 return math_ops.logical_not( 3228 _StopSignals.should_stop(scalar_stopping_signal)) 3229 3230 inputs = [_StopSignals.NON_STOPPING_SIGNAL] 3231 outputs = training_loop.while_loop( 3232 cond, single_tpu_predict_step, inputs=inputs, name=b'loop') 3233 return outputs 3234 3235 (compile_op, 
      multi_tpu_predict_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)

  dummy_predict_op = dummy_predict_op[0]
  scaffold = _get_scaffold(captured_scaffold_fn)
  return (compile_op, dummy_predict_op, host_calls, scaffold,
          captured_predict_hooks.get())


def _wrap_computation_in_while_loop(device, op_fn):
  """Wraps the ops generated by `op_fn` in tf.while_loop."""

  def computation(i):
    with ops.control_dependencies(op_fn()):
      return i + 1

  iterations_per_loop_var = _create_or_get_iterations_per_loop()
  # By setting parallel_iterations=1, the parallel execution in while_loop is
  # basically turned off.
  with ops.device(device):
    iterations = array_ops.identity(iterations_per_loop_var)
    return control_flow_ops.while_loop(
        lambda i: i < iterations,
        computation, [constant_op.constant(0)],
        parallel_iterations=1)


def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Wraps the ops generated by `op_fn` in tf.while_loop."""

  def cond(scalar_stopping_signal):
    return math_ops.logical_not(
        _StopSignals.should_stop(scalar_stopping_signal))

  def computation(unused_scalar_stopping_signal):
    return_value = op_fn()
    execute_ops = return_value['ops']
    signals = return_value['signals']
    with ops.control_dependencies(execute_ops):
      return _StopSignals.as_scalar_stopping_signal(signals)

  # By setting parallel_iterations=1, the parallel execution in while_loop is
  # basically turned off.
  with ops.device(device):
    return control_flow_ops.while_loop(
        cond,
        computation, [_StopSignals.NON_STOPPING_SIGNAL],
        parallel_iterations=1)


def _validate_tpu_training_graph():
  """Validates the graph before running distributed training.

  Raises:
    ValueError: If the graph seems invalid for running on device.
  """
  operations = ops.get_default_graph().get_operations()

  # Check if there is at least one CrossReplicaSum operation in the graph.
  # This should be introduced by using the CrossShardOptimizer wrapper.
  cross_replica_sum_ops = [
      o for o in operations if o.type == _CROSS_REPLICA_SUM_OP
  ]
  if not cross_replica_sum_ops:
    raise ValueError(
        'CrossShardOptimizer must be used for model training on TPUs.')


class _CapturedObject(object):
  """A placeholder to capture an object.

  This is useful when we need to capture a Python object in the TensorFlow
  control flow body function and use it outside the control flow.
  """

  def __init__(self):
    self._object = None
    self._captured = False

  def capture(self, o):
    if self._captured:
      raise RuntimeError(
          'InternalError: Object can be captured only once. Please file a '
          'bug.')

    self._captured = True
    self._object = o

  def get(self):
    if not self._captured:
      raise RuntimeError(
          'InternalError: Object is not captured properly before `get`. '
          'Please file a bug.')
    return self._object


def _get_scaffold(captured_scaffold_fn):
  """Retrieves the Scaffold from `captured_scaffold_fn`."""
  with _CapturingContext(message='Inside scaffold_fn'):
    scaffold_fn = captured_scaffold_fn.get()
    if scaffold_fn:
      scaffold = scaffold_fn()
      if scaffold is None:
        raise ValueError(
            'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
    else:
      scaffold = None

  if scaffold:
    wrapped_finalize = scaffold.finalize

    def _finalize():
      with _CapturingContext('Inside Scaffold.finalize'):
        wrapped_finalize()

    scaffold.finalize = _finalize
  return scaffold


class _CapturingContext(control_flow_ops.ControlFlowContext):
  """Tracks references to Tensors defined in TPU replication."""

  def __init__(self, message):
    control_flow_ops.ControlFlowContext.__init__(self)
    self._message = message

  def to_control_flow_context_def(self, context_def, export_scope=None):
    # pylint: disable=useless-super-delegation
    # NOTE(slebedev): the method is required by `ControlFlowContext`.
    super(_CapturingContext, self).to_control_flow_context_def(
        context_def, export_scope)

  def AddOp(self, op):  # pylint: disable=invalid-name
    for c in op.inputs:
      if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr:  # pylint: disable=protected-access
        raise ValueError('{}: Op {} depends on TPU computation {}, '
                         'which is not allowed.'.format(self._message, op, c))

  def __enter__(self):
    # pylint: disable=protected-access
    self._g = ops.get_default_graph()
    self._old = self._g._get_control_flow_context()
    self._g._set_control_flow_context(self)
    # pylint: enable=protected-access

  def __exit__(self, _, __, ___):  # pylint: disable=invalid-name
    self._g._set_control_flow_context(self._old)  # pylint: disable=protected-access


class _Inputs(object):
  """A data structure representing the values returned by `input_fn`.

  This also supports the case where `input_fn` returns a `Dataset`.
  """

  def __init__(self, features=None, labels=None, dataset=None, signals=None):
    if dataset is not None and (features is not None or labels is not None or
                                signals is not None):
      raise RuntimeError('Internal Error: Either (features and labels) or '
                         'dataset should be provided, not both. Please file '
                         'a bug.')

    self._features = features
    self._labels = labels
    self._signals = signals

    self._dataset = dataset
    self._iterator = None

  @staticmethod
  def from_input_fn(return_values):
    """Returns an `_Inputs` instance according to `input_fn` return value."""
    if isinstance(return_values, dataset_ops.DatasetV2):
      dataset = return_values
      return _Inputs(dataset=dataset)

    features, labels = _Inputs._parse_inputs(return_values)
    return _Inputs(features, labels)

  @staticmethod
  def _parse_inputs(return_values):
    if isinstance(return_values, tuple):
      features, labels = return_values
    else:
      features, labels = return_values, None
    return features, labels

  @property
  def is_dataset(self):
    """Returns True if the return value from input_fn is a Dataset."""
    return self._dataset is not None

  def dataset_initializer(self):
    """Returns the dataset's initializer.

    The initializer must be run before calling `features_and_labels`.
    """
    self._iterator = dataset_ops.make_initializable_iterator(self._dataset)
    return self._iterator.initializer

  def features_and_labels(self):
    """Gets `features` and `labels`."""
    if self.is_dataset:
      if self._iterator is None:
        raise RuntimeError('Internal error: Must run dataset_initializer '
                           'before calling features_and_labels(). Please file '
                           'a bug!')
      return _Inputs._parse_inputs(self._iterator.get_next())

    return (self._features, self._labels)

  def signals(self):
    return self._signals

  @property
  def dataset(self):
    return self._dataset


class _InputsWithStoppingSignals(_Inputs):
  """Inputs with `_StopSignals` inserted into the dataset."""

  def __init__(self,
               dataset,
               batch_size,
               add_padding=False,
               num_invocations_per_step=1):

    assert dataset is not None
    user_provided_dataset = dataset.map(
        _InputsWithStoppingSignals.insert_stopping_signal(
            stop=False, batch_size=batch_size, add_padding=add_padding))
    if num_invocations_per_step == 1:
      final_batch_dataset = dataset.take(1).map(
          _InputsWithStoppingSignals.insert_stopping_signal(
              stop=True, batch_size=batch_size, add_padding=add_padding))
    else:
      # We append (2 * num_invocations_per_step - 1) batches to exhaust the
      # user_provided_dataset and stop properly.
      # For example, if num_invocations_per_step is 2, we append 3 additional
      # padding batches: b1, b2, b3.
      # If user_provided_dataset contains two batches: a1, a2
      # Step 1: [a1, a2]
      # Step 2: [b1, b2] -> STOP
      # If user_provided_dataset contains three batches: a1, a2, a3.
      # The training loops:
      # Step 1: [a1, a2]
      # Step 2: [a3, b1]
      # Step 3: [b2, b3] -> STOP.
      final_batch_dataset = dataset.take(1).map(
          _InputsWithStoppingSignals.insert_stopping_signal(
              stop=True, batch_size=batch_size, add_padding=add_padding))
      final_batch_dataset = final_batch_dataset.repeat(
          2 * num_invocations_per_step - 1)

      def _set_mask(data_dict):
        signals = data_dict['signals']
        signals['padding_mask'] = array_ops.ones_like(signals['padding_mask'])
        data_dict['signals'] = signals
        return data_dict

      # Mask out the extra batch.
      final_batch_dataset = final_batch_dataset.map(_set_mask)

    dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2)

    super(_InputsWithStoppingSignals, self).__init__(dataset=dataset)
    self._current_inputs = None

  def features_and_labels(self):
    if self._current_inputs is not None:
      raise RuntimeError(
          'Internal Error: The previous inputs have not been properly '
          'consumed. First call features_and_labels, then call signals.')

    inputs_with_signals = self._iterator.get_next()
    features = inputs_with_signals['features']
    labels = inputs_with_signals.get('labels')

    self._current_inputs = inputs_with_signals
    return features, labels

  def signals(self):
    """Returns the `Signals` from `_Inputs`."""
    if self._current_inputs is None:
      raise RuntimeError(
          'Internal Error: The current inputs have not been properly '
          'generated. First call features_and_labels, then call signals.')
    signals = self._current_inputs['signals']
    self._current_inputs = None
    return signals

  @staticmethod
  def insert_stopping_signal(stop, batch_size, add_padding=False):
    """Inserts stopping_signal into dataset via _map_fn.

    Here we change the data structure in the dataset, such that the return
    value is a dictionary with `features`, `labels`, and `signals` as three
    distinct keys. This provides a better structure, which makes it easier to
    decompose the inputs (see `features_and_labels`).

    Args:
      stop: bool, state of current stopping signals.
      batch_size: int, batch size.
      add_padding: bool, whether to pad the tensor to full batch size.

    Returns:
      A map_fn passed to the dataset.map API.
    """

    def _map_fn(*args):
      """The map fn to insert signals."""
      if len(args) == 1:
        # Unpack the single Tensor/dict argument as features. This is required
        # when the input_fn returns no labels.
        args = args[0]
      features, labels = _Inputs._parse_inputs(args)
      new_input_dict = {}

      if add_padding:
        padding_mask, features, labels = (
            _PaddingSignals.pad_features_and_labels(features, labels,
                                                    batch_size))

        new_input_dict['features'] = features
        if labels is not None:
          new_input_dict['labels'] = labels

      else:
        new_input_dict['features'] = features
        if labels is not None:
          new_input_dict['labels'] = labels
        padding_mask = None

      new_input_dict['signals'] = _StopSignals(
          stop=stop, batch_size=batch_size,
          padding_mask=padding_mask).as_dict()

      return new_input_dict

    return _map_fn


class _StopSignals(object):
  """Signals class holding all logic to handle TPU stopping condition."""

  NON_STOPPING_SIGNAL = False
  STOPPING_SIGNAL = True

  def __init__(self, stop, batch_size, padding_mask=None):
    self._stop = stop
    self._batch_size = batch_size
    self._padding_mask = padding_mask

  def as_dict(self):
    """Returns the signals as a Python dict."""
    shape = [self._batch_size, 1]
    dtype = dtypes.bool

    if self._stop:
      stopping = array_ops.ones(shape=shape, dtype=dtype)
    else:
      stopping = array_ops.zeros(shape=shape, dtype=dtype)

    signals = {'stopping': stopping}
    if self._padding_mask is not None:
      signals['padding_mask'] = self._padding_mask
    return signals

  @staticmethod
  def as_scalar_stopping_signal(signals):
    return array_ops.identity(signals['stopping'][0][0])

  @staticmethod
  def should_stop(scalar_stopping_signal):
    """Detects whether scalar_stopping_signal indicates stopping."""
    if isinstance(scalar_stopping_signal, ops.Tensor):
      # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the
      # TF way to express the boolean check that scalar_stopping_signal is
      # True.
      return math_ops.logical_and(scalar_stopping_signal,
                                  _StopSignals.STOPPING_SIGNAL)
    else:
      # For the non-Tensor case, this is used in a SessionRunHook, so we
      # cannot modify the graph anymore. Here, we use pure Python.
      return bool(scalar_stopping_signal)


class _PaddingSignals(object):
  """Signals class holding all logic to handle padding."""

  @staticmethod
  def pad_features_and_labels(features, labels, batch_size):
    """Pads out the batch dimension of features and labels."""
    real_batch_size = array_ops.shape(
        _PaddingSignals._find_any_tensor(features))[0]

    batch_size_tensor = constant_op.constant(batch_size, dtypes.int32)

    check_greater = check_ops.assert_greater_equal(
        batch_size_tensor,
        real_batch_size,
        data=(batch_size_tensor, real_batch_size),
        message='The real batch size should not be greater than batch_size.')

    with ops.control_dependencies([check_greater]):
      missing_count = batch_size_tensor - real_batch_size

    def pad_single_tensor(tensor):
      """Pads out the batch dimension of a tensor to the complete batch_size."""
      rank = len(tensor.shape)
      assert rank > 0
      padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
      padded_shape = (batch_size,) + tuple(tensor.shape[1:])
      padded_tensor = array_ops.pad(tensor, padding)
      padded_tensor.set_shape(padded_shape)
      return padded_tensor

    def nest_pad(tensor_or_dict):
      return nest.map_structure(pad_single_tensor, tensor_or_dict)

    features = nest_pad(features)
    if labels is not None:
      labels = nest_pad(labels)

    padding_mask = _PaddingSignals._padding_mask(real_batch_size, missing_count,
                                                 batch_size)

    return padding_mask, features, labels

  @staticmethod
  def slice_tensor_or_dict(tensor_or_dict, signals):
    """Slices the real Tensors according to the padding mask in signals."""

    padding_mask = signals['padding_mask']
    batch_size = array_ops.shape(padding_mask)[0]

    def verify_batch_size(tensor):
      check_batch_size = math_ops.equal(batch_size, tensor.shape[0])
      with ops.control_dependencies([check_batch_size]):
        return array_ops.identity(tensor)

    def slice_single_tensor(tensor):
      rank = len(tensor.shape)
      assert rank > 0
      real_batch_size = batch_size - math_ops.reduce_sum(padding_mask)
      return verify_batch_size(tensor)[0:real_batch_size]

    # As we split the Tensors to all TPU cores and concat them back, it is
    # important to ensure the real data is placed before padded ones, i.e.,
    # order is preserved. By that, the sliced padding mask should have all
    # 0's. If this assertion fails, the slice logic here would not hold.
    sliced_padding_mask = slice_single_tensor(padding_mask)
    assert_padding_mask = math_ops.equal(
        math_ops.reduce_sum(sliced_padding_mask), 0)

    with ops.control_dependencies([assert_padding_mask]):
      should_stop = _StopSignals.should_stop(
          _StopSignals.as_scalar_stopping_signal(signals))

    is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0)

    def slice_fn(tensor):
      # If the current batch is a full batch or is part of the stopping
      # signals, we do not need to slice, which saves time.
      return control_flow_ops.cond(
          math_ops.logical_or(should_stop, is_full_batch),
          (lambda: verify_batch_size(tensor)),
          (lambda: slice_single_tensor(tensor)))

    return nest.map_structure(slice_fn, tensor_or_dict)

  @staticmethod
  def _find_any_tensor(batch_features):
    tensors = [
        x for x in nest.flatten(batch_features) if isinstance(x, ops.Tensor)
    ]
    if not tensors:
      raise ValueError('Cannot find any Tensor in features dict.')
    return tensors[0]

  @staticmethod
  def _padding_mask(real_batch_size, missing_count, batch_size):
    padding_mask = array_ops.concat([
        array_ops.zeros((real_batch_size,), dtype=dtypes.int32),
        array_ops.ones((missing_count,), dtype=dtypes.int32)
    ],
                                    axis=0)
    padding_mask.set_shape((batch_size,))
    return padding_mask


def _verify_cross_hosts_transfer_size(tensor_dict, message):
  total_size = 0
  tensor_structure = {}
  for key, tensor in tensor_dict.items():
    shape = tensor.shape
    size = np.product(shape) * tensor.dtype.size
    tensor_structure[key] = shape
    total_size += size
  if total_size >= _ONE_GIGABYTE:
    raise ValueError(
        '{} The transfer size is larger than the protobuf limit. Please '
        'consider using Tensors with smaller shapes or reducing the batch '
        'size. Given:\n'
        '{}'.format(
            message, '\n'.join([
                ' -- Key: {}, Shape: {}'.format(k, v)
                for k, v in tensor_structure.items()
            ])))


def _add_item_to_params(params, key, value):
  """Adds a new item into `params`."""
  if hasattr(params, 'set_hparam'):
    # For HParams, we need to use the special API.
    if key in params:
      params.set_hparam(key, value)
    else:
      params.add_hparam(key, value)
  else:
    # Otherwise, params is a Python dict.
    params[key] = value


def export_estimator_savedmodel(estimator,
                                export_dir_base,
                                serving_input_receiver_fn,
                                assets_extra=None,
                                as_text=False,
                                checkpoint_path=None,
                                strip_default_attrs=False):
  """Exports an `Estimator` trained model for TPU inference.

  Args:
    estimator: `Estimator` with which the model has been trained.
    export_dir_base: A string containing a directory in which to create
      timestamped subdirectories containing exported SavedModels.
    serving_input_receiver_fn: A function that takes no arguments and returns a
      `ServingInputReceiver` or `TensorServingInputReceiver`.
    assets_extra: A dict specifying how to populate the assets.extra directory
      within the exported SavedModel, or `None` if no extra assets are needed.
    as_text: whether to write the SavedModel proto in text format.
    checkpoint_path: The checkpoint path to export. If `None` (the default),
      the most recent checkpoint found within the model directory is chosen.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs.

  Returns:
    The string path to the exported directory.
  """
  # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use
  # `estimator.config`.
  config = tpu_config.RunConfig(model_dir=estimator.model_dir)
  est = TPUEstimator(
      estimator._model_fn,  # pylint: disable=protected-access
      config=config,
      params=estimator.params,
      use_tpu=True,
      train_batch_size=2048,  # Does not matter.
      eval_batch_size=2048,  # Does not matter.
  )
  return est.export_savedmodel(export_dir_base, serving_input_receiver_fn,
                               assets_extra, as_text, checkpoint_path,
                               strip_default_attrs)
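
# Example (an illustrative sketch only, not part of this module's API): one
# way a caller might use `export_estimator_savedmodel` above. The names
# `my_model_fn` and `my_serving_input_receiver_fn` and the directory paths
# are hypothetical placeholders supplied by the caller.
#
#   estimator = tf.estimator.Estimator(
#       model_fn=my_model_fn, model_dir='/tmp/model_dir')
#   export_dir = export_estimator_savedmodel(
#       estimator,
#       export_dir_base='/tmp/export',
#       serving_input_receiver_fn=my_serving_input_receiver_fn)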