# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""xla is an experimental library that provides XLA support APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad  # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

_XLA_COMPILE_ATTR = '_xla_compile_id'
_MAX_WARNING_LINES = 5

# Operations that indicate some error in the user's graph. For example, an
# XLA computation should not have any Placeholder op.
_BLACKLISTED_OPS = set([
    'Placeholder',
])

# XLA doesn't currently support reading of intermediate tensors, thus some ops
# are not supported.
_UNSUPPORTED_OPS = set([
    'AudioSummary',
    'AudioSummaryV2',
    'HistogramSummary',
    'ImageSummary',
    'MergeSummary',
    'Print',
    'ScalarSummary',
    'TensorSummary',
    'TensorSummaryV2',
])


def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
  """Builds an operator that compiles and runs `computation` with XLA.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, `inputs` should be a list of n
      tensors.

      `computation` may return a list of operations and tensors. Tensors must
      come before operations in the returned list. The return value of
      `compile` is a list of tensors corresponding to the tensors from the
      output of `computation`.

      All `Operation`s returned from `computation` will be executed when
      evaluating any of the returned output tensors.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each
      input can be a nested structure containing values that are convertible
      to tensors. Note that passing an N-dimensional list of compatible values
      will result in an N-dimensional list of scalar tensors rather than a
      single rank-N tensor. If you need different behavior, convert parts of
      `inputs` to tensors with `tf.convert_to_tensor`.

  Returns:
    The same data structure as if `computation(*inputs)` were called directly,
    with some exceptions for correctness:
    1) None output: a NoOp that control-depends on `computation` is returned.
    2) Single-value output: a tuple containing the value is returned.
    3) Operation-only outputs: a NoOp that control-depends on `computation`
       is returned.
    TODO(b/121383831): Investigate into removing these special cases.
  """
  # pylint: disable=protected-access
  return _compile_internal(computation, inputs)
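
# A minimal usage sketch of `compile` (illustrative only; the computation and
# names below are hypothetical, and the public `tf` namespace is assumed to
# be in scope):
#
#   def computation(a, b):
#     return a + b, a * b
#
#   x = tf.constant(3.0)
#   y = tf.constant(4.0)
#   added, multiplied = compile(computation, inputs=[x, y])
#
# Per the contract above, a computation returning a single value yields a
# one-element tuple, and a computation returning only Operations (or None)
# yields a NoOp that control-depends on the computation.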

class XLACompileContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside an XLA computation cluster.

  THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NOT USE DIRECTLY.

  The primary role of `XLACompileContext` is to mark operators inside an
  xla.compile() computation with the attribute "_xla_compile_id=XYZ", where
  XYZ is a unique name.

  `ControlFlowContext` is used to perform the annotation since it integrates
  with TensorFlow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside an xla.compile() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the compiled computation.
  """

  def __init__(self, name, pivot):
    """Builds a new XLACompileContext.

    Args:
      name: a unique name for the context, used to populate the
        `_xla_compile_id` attribute.
      pivot: a pivot node. Nodes in the XLACompileContext that do not have
        any inputs will have a control dependency on the pivot node. This
        ensures that nodes are correctly included in any enclosing control
        flow contexts.
    """
    super(XLACompileContext, self).__init__()
    self._name = name
    self._name_as_bytes = compat.as_bytes(name)
    self._unsupported_ops = []
    self._pivot = pivot

  def report_unsupported_operations(self):
    if self._unsupported_ops:
      op_str = '\n'.join([
          '  %s (%s)' % (op.type, op.name)
          for op in self._unsupported_ops[:_MAX_WARNING_LINES]
      ])
      logging.warning('%d unsupported operations found: \n%s',
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning('... and %d more',
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def _RemoveExternalControlEdges(self, op):
    """Removes any external control dependency on this op."""
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op):
    """Creates op in XLACompileContext and notifies outer context recursively."""
    # pylint: disable=protected-access
    if op.type in _BLACKLISTED_OPS:
      logging.error(
          'Operation of type %s (%s) is not supported in XLA. Execution will '
          'fail if this op is used in the graph.', op.type, op.name)

    # TODO(ycao): Automatically disable summaries instead of reporting them.
    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          'Non-resource Variables are not supported inside XLA computations '
          '(operator name: %s)' % op.name)

    if _XLA_COMPILE_ATTR in op.node_def.attr:
      raise ValueError('XLA compiled computations cannot be nested (operator '
                       'name: %s)' % op.name)

    op._set_attr(
        _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))

    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may
    # cause mismatched frame errors. An example is when one of op's inputs is
    # generated in a different While control flow context.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self._pivot)
        # pylint: enable=protected-access
    else:
      for index in xrange(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x != x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Adds `val` to the current context and its outer context recursively."""
    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op):
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the XLACompileContext to
    # be None, because an XLACompileContext is never nested and the grad_state
    # outside the XLACompileContext does not affect the graph inside, so the
    # grad_state should behave as if this were the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any."""
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False
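
# Illustrative sketch of the attribute marking performed by
# XLACompileContext: every op created inside an xla.compile() block carries
# `_xla_compile_id`, which downstream XLA passes use to group the ops into a
# cluster. The computation and printed values below are hypothetical and
# graph-dependent; only standard Graph/Operation APIs are used.
#
#   (result,) = compile(lambda: tf.constant(1.0) + tf.constant(2.0))
#   for op in result.graph.get_operations():
#     if '_xla_compile_id' in op.node_def.attr:
#       print(op.name, op.get_attr('_xla_compile_id'))  # e.g. b'cluster'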

def _compile_internal(computation, inputs=None):
  """Builds graph operators that compile and symbolically execute computation.

  Args:
    computation: A Python function that builds the computation to compile and
      execute.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each
      input can be a nested structure containing values that are convertible
      to tensors. Note that passing an N-dimensional list of compatible values
      will result in an N-dimensional list of scalar tensors rather than a
      single rank-N tensor. If you need different behavior, convert parts of
      `inputs` to tensors with `tf.convert_to_tensor`.

  Returns:
    The same data structure as if `computation(*inputs)` were called directly,
    with some exceptions for correctness: 1) None output, 2) single-value
    output, 3) Operation-only outputs.

  Raises:
    ValueError: If any element in computation outputs is neither an Operation
      nor a value that can be converted to a tensor.
    ValueError: If computation outputs is non-flat and contains any
      Operations.
    TypeError: If `inputs` is not a list or tuple.
  """
  if inputs is None:
    inputs = []

  if not isinstance(inputs, collections.Sequence):
    raise TypeError('inputs must be a list or tuple')

  # Flatten inputs.
  flat_inputs = nest.flatten(inputs)
  # Converts inputs to Tensors.
  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]

  cluster_name = ops.get_default_graph().unique_name('cluster')
  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
  context = XLACompileContext(name=cluster_name, pivot=pivot)
  try:
    context.Enter()

    # Add identity ops so even unused inputs are 'consumed' by the
    # computation.
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]

    # Re-pack flat_inputs in same structure as 'inputs'.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)

    # Only resource variables work inside an XLA computation, so turn on
    # resource variables for the computation.
    vscope = variable_scope.get_variable_scope()
    saved_use_resource = vscope.use_resource
    vscope.set_use_resource(True)

    with _disable_summary_context():
      outputs = computation(*computation_inputs)

    # Restore variable scope after computation.
    vscope.set_use_resource(saved_use_resource)

    outputs_is_flat = is_flat(outputs)
    if outputs_is_flat:
      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    else:
      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()

  # When the XLA computation returns only operations and no tensors, a NoOp
  # dependent on the operations in outputs is returned. Otherwise the final
  # outputs would be empty and there would be no way to trigger the returned
  # operations.
  if not output_tensors:
    return control_flow_ops.group(control_deps, name='output_0')

  output_tensors = [
      xla_ops.xla_cluster_output(o, name='output{}'.format(i))
      for i, o in enumerate(output_tensors)
  ]

  with ops.control_dependencies(control_deps):
    # Wrap the outputs in identity operators that carry the control
    # dependencies.
    output_tensors = [
        array_ops.identity(o, name='output_%d' % i)
        for i, o in enumerate(output_tensors)
    ]

  # If `computation` returned a non-flat output structure, pack output
  # tensors back into the same structure.
  if not outputs_is_flat:
    output_tensors = nest.pack_sequence_as(
        structure=outputs, flat_sequence=output_tensors)

  return output_tensors

def is_flat(outputs):
  """Checks if outputs is a flat structure.

  The following structures and values are considered flat:
  1) None
  2) A single object
  3) A list or tuple of Tensors/Operations

  The only structures that this function understands are sequences and
  dictionaries. E.g. this means that if outputs contains a single
  user-defined Object, it is considered to be flat. Errors are raised later
  on if that Object cannot be converted to a Tensor.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    A boolean indicating whether outputs is flat.
  """
  # If outputs is a list or tuple, check if it has any nested structure. If
  # there is, then outputs is non-flat.
  if isinstance(outputs, collections.Sequence):
    for o in outputs:
      if isinstance(o, collections.Sequence) or isinstance(o, dict):
        return False

  # If outputs is a dict, it is non-flat.
  if isinstance(outputs, dict):
    return False

  # Getting here means either outputs itself is a single non-structured value
  # or it is a flat list of single non-structured values.
  return True
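
# Examples of the flatness rules above, where `t1` and `t2` stand for any
# Tensors (a sketch of expected results, not executable as-is):
#
#   is_flat(None)              # True: None is flat.
#   is_flat(t1)                # True: a single object is flat.
#   is_flat((t1, t2))          # True: a flat tuple of Tensors.
#   is_flat([t1, [t2]])        # False: nested sequence.
#   is_flat({'a': t1})         # False: dicts are non-flat.
#   is_flat([{'a': t1}])       # False: a sequence containing a dict.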

def _postprocess_flat_outputs(outputs):
  """Validates flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    Tensors and Operations extracted from outputs.
  """
  # The following code segment preserves legacy behavior. Previously we only
  # supported flat outputs, and thus for consistency it was nice to convert
  # even a single element into a tuple. But now that we support arbitrary
  # output structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()
  # If the computation only returned one value, make it a tuple.
  if not isinstance(outputs, collections.Sequence):
    outputs = (outputs,)

  # Append `no_op` here so that the return value of this function always
  # contains at least one op that can trigger the XlaLaunch node.
  outputs += (control_flow_ops.no_op(),)
  try:
    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]
  except Exception as e:
    raise ValueError(
        'XLA computation function return values must all either be Operations'
        ' or convertible to Tensors. Got error: "%s"' % str(e))

  # Separate the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  if outputs != output_tensors + output_operations:
    raise ValueError(
        'XLA computation function must return zero or more Tensor values '
        'followed by zero or more Operations.')

  new_output_tensors = []
  for t in output_tensors:
    with ops.device(t.device if t.device else ''):
      new_output_tensors.append(array_ops.identity(t))

  return new_output_tensors, output_operations


def _postprocess_non_flat_outputs(outputs):
  """Validates non-flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    Tensors extracted from outputs, and an empty list because Operations are
    not allowed in non-flat outputs.
  """
  # Convert all non-Operation outputs to Tensors.
  new_output_tensors = []
  for o in nest.flatten(outputs):
    if isinstance(o, ops.Operation):
      raise ValueError(
          'xla.compile does not support Operation as return value in non-flat '
          'output structure. You can set returned Operations as control '
          'dependencies of returned Tensors so Operations are triggered when '
          'Tensors are evaluated. Operation found: "%s"' % o.name)

    try:
      o = ops.convert_to_tensor(o)
    except Exception as e:
      raise ValueError(
          'XLA computation function return values must all either be '
          'Operations or convertible to Tensors. Got error: "%s"' % str(e))

    # Make sure even pass-through inputs/outputs are touched in the compile
    # context by creating an Identity node inside the compile context.
    with ops.device(o.device if o.device else ''):
      new_output_tensors.append(array_ops.identity(o))

  return new_output_tensors, []


@contextlib.contextmanager
def _disable_summary_context():
  """Enters a context where all summary ops are skipped.

  Summaries are not yet supported in xla.compile(), so this context manager
  skips creating summary ops as a temporary workaround.

  Yields:
    None.
  """
  original_skip_summary_func = summary_op_util.skip_summary
  summary_op_util.skip_summary = lambda: True

  try:
    yield
  finally:
    summary_op_util.skip_summary = original_skip_summary_func


class _CapturedObject(object):
  """A placeholder to capture an object."""

  def __init__(self):
    self._object = None

  def capture(self, o):
    if self._object:
      raise RuntimeError(
          'InternalError: _CapturedObject can capture only once. Please file '
          'a bug.')

    self._object = o

  def get(self):
    return self._object


def _get_scaffold(captured_scaffold_fn):
  """Retrieves the Scaffold from `captured_scaffold_fn`."""
  scaffold_fn = captured_scaffold_fn.get()

  if not scaffold_fn:
    return None

  scaffold = scaffold_fn()
  if scaffold is None:
    raise ValueError(
        'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')

  return scaffold
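
# Sketch of the scaffold_fn contract checked by `_get_scaffold`: a captured
# TPUEstimatorSpec.scaffold_fn is either absent (yielding None) or must
# return a non-None Scaffold. A hypothetical conforming scaffold_fn:
#
#   def scaffold_fn():
#     return tf.train.Scaffold(saver=tf.train.Saver(sharded=True))
#
# A scaffold_fn that returns None triggers the ValueError above.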

class _ModelFnWrapper(object):
  """_ModelFnWrapper supports executing model_fn with XLA."""

  def __init__(self, function):
    self._model_fn = function

  def __call__(self, features, labels, mode, params):

    # TPUEstimator compiles model_fn when use_tpu=True. To avoid double
    # compilation, we use params['use_tpu'] as a hint. When it is set to
    # True, model_fn is called without compilation.
    # Note that this condition isn't accurate for the case of exporting a
    # model. In that case we should ideally not compile so that the user can
    # see the detailed graph. However, we don't have enough information to
    # tell whether model_fn is being called for export mode or not.
    # TODO(ycao): Make this condition more accurate when implementing PREDICT
    # mode.
    if params.get('use_tpu'):
      return self._call_model_fn(features, labels, mode, params)

    if mode == model_fn_lib.ModeKeys.TRAIN:
      train_step, captured_scaffold_fn = self._make_train_step(
          features, labels, params)
      (loss,) = compile(train_step)
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          loss=loss,
          train_op=array_ops.identity(loss),
          scaffold=_get_scaffold(captured_scaffold_fn))
    elif mode == model_fn_lib.ModeKeys.EVAL:
      eval_step, captured_eval_metric_fn, captured_scaffold_fn = (
          self._make_eval_step(features, labels, params))
      outputs = compile(eval_step)
      loss = outputs[0]

      # Calculate eval_metric_ops if eval_metric_fn is set and captured.
      eval_metric_fn = captured_eval_metric_fn.get()
      if eval_metric_fn:
        eval_metric_fn_tensors = outputs[1:]
        eval_metric_ops = eval_metric_fn(*eval_metric_fn_tensors)
      else:
        eval_metric_ops = None

      return model_fn_lib.EstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metric_ops=eval_metric_ops,
          scaffold=_get_scaffold(captured_scaffold_fn))
    else:
      raise NotImplementedError('%s is not implemented, only TRAIN and EVAL '
                                'are supported' % mode)

  def _make_train_step(self, features, labels, params):
    """Creates a single step of training for xla.compile()."""
    captured_scaffold_fn = _CapturedObject()

    def train_step():
      """A single step of training."""
      estimator_spec = self._call_model_fn(features, labels,
                                           model_fn_lib.ModeKeys.TRAIN,
                                           params)

      try:
        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
      except AttributeError:
        captured_scaffold_fn.capture(None)

      # train_step will be run by xla.compile(). xla.compile() only supports
      # tensor output while train_op can be either an operation or a tensor.
      # Even though xla.compile() automatically adds an operation-typed
      # train_op as a control dependency of other tensor outputs, it doesn't
      # do so for a tensor-typed train_op. Thus, we need to set it explicitly
      # here.
      with ops.control_dependencies([estimator_spec.train_op]):
        return array_ops.identity(estimator_spec.loss)

    return train_step, captured_scaffold_fn

  def _make_eval_step(self, features, labels, params):
    """Creates a single step of evaluation for xla.compile()."""
    captured_eval_metric_fn = _CapturedObject()
    captured_scaffold_fn = _CapturedObject()

    def eval_step():
      """A single step of evaluation."""
      estimator_spec = self._call_model_fn(features, labels,
                                           model_fn_lib.ModeKeys.EVAL, params)

      try:
        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
      except AttributeError:
        captured_scaffold_fn.capture(None)

      eval_metric_fn = None
      eval_metric_fn_tensors = []
      try:
        if estimator_spec.eval_metrics:
          (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
      except AttributeError:
        pass

      # If a dictionary is provided, we need to convert it into a list sorted
      # according to the order of eval_metric_fn's positional arguments.
      if isinstance(eval_metric_fn_tensors, dict):
        eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
        eval_metric_fn_tensors = [
            eval_metric_fn_tensors[i] for i in eval_metric_fn_args
        ]

      captured_eval_metric_fn.capture(eval_metric_fn)

      return tuple([estimator_spec.loss] + eval_metric_fn_tensors)

    return eval_step, captured_eval_metric_fn, captured_scaffold_fn

  def _call_model_fn(self, features, labels, mode, params):
    """Calls the model_fn with required parameters."""
    model_fn_args = function_utils.fn_args(self._model_fn)
    kwargs = {}

    if 'labels' in model_fn_args:
      kwargs['labels'] = labels
    elif labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
      kwargs['mode'] = mode

    if 'params' in model_fn_args:
      kwargs['params'] = params

    return self._verify_estimator_spec(
        self._model_fn(features=features, **kwargs))
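
  # Sketch of the TPUEstimatorSpec.eval_metrics forms consumed by
  # `_make_eval_step` above (names are hypothetical). The tensors may be a
  # list matching metric_fn's positional arguments, or a dict keyed by
  # argument name:
  #
  #   def metric_fn(labels, predictions):
  #     return {'accuracy': tf.metrics.accuracy(labels, predictions)}
  #
  #   eval_metrics = (metric_fn, [labels, predictions])
  #   # or equivalently:
  #   eval_metrics = (metric_fn, {'labels': labels,
  #                               'predictions': predictions})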

  def _verify_estimator_spec(self, estimator_spec):
    """Verifies that the estimator spec contains correct data."""
    # TODO(ycao): Implement estimator spec verification for other modes.

    try:
      if estimator_spec.scaffold:
        logging.warning('EstimatorSpec.scaffold is ignored with XLA '
                        'compilation. Please use '
                        'TPUEstimatorSpec.scaffold_fn instead.')
    except AttributeError:
      pass

    try:
      if estimator_spec.eval_metric_ops:
        raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
                         'XLA compilation. Please use '
                         'TPUEstimatorSpec.eval_metrics instead.')
    except AttributeError:
      pass

    if estimator_spec.mode == model_fn_lib.ModeKeys.EVAL:
      # If estimator_spec is of type TPUEstimatorSpec and contains
      # eval_metrics, check that eval_metrics contains eval_metric_fn and
      # eval_metric_fn_tensors with matching arguments.
      try:
        eval_metrics = estimator_spec.eval_metrics
      except AttributeError:
        eval_metrics = None

      if eval_metrics:
        (eval_metric_fn, eval_metric_fn_tensors) = eval_metrics
        eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)

        if isinstance(eval_metric_fn_tensors, dict):
          missing_tensors = [
              i for i in eval_metric_fn_args if i not in eval_metric_fn_tensors
          ]
          additional_tensors = [
              i for i in eval_metric_fn_tensors if i not in eval_metric_fn_args
          ]

          if missing_tensors:
            raise ValueError('Arguments %s are needed by metric_fn (first '
                             'element of TPUEstimatorSpec.eval_metrics) but '
                             'they are not provided by evaluation tensors '
                             '(second element of TPUEstimatorSpec.eval_metrics)'
                             '.' % missing_tensors)

          if additional_tensors:
            raise ValueError('Arguments %s are provided by evaluation tensors '
                             '(second element of TPUEstimatorSpec.eval_metrics)'
                             ' but they are not needed by metric_fn (first '
                             'element of TPUEstimatorSpec.eval_metrics).' %
                             additional_tensors)

    return estimator_spec


def estimator_model_fn(target_model_fn=None):
  """estimator_model_fn decorates a model_fn to be compiled for execution.

  Currently it only works with `TPUEstimator`. If you need to use it with the
  base `Estimator`, please add `tf.enable_resource_variables()` at the
  beginning of your program.

  Example 1, decorating model_fn:
  ```
  @xla.estimator_model_fn()
  def model_fn(features, labels, mode, params):
    ...
    return EstimatorSpec(...)


  est = Estimator(model_fn=model_fn, ...)
  est.train(...)
  ```

  Example 2, decorator as function:
  ```
  def model_fn(features, labels, mode, params):
    ...
    return EstimatorSpec(...)

  est = Estimator(model_fn=xla.estimator_model_fn(model_fn), ...)
  est.train(...)
  ```

  Args:
    target_model_fn: model_fn to be decorated. This is only needed when the
      decorator is used in function-call form (example 2).

  Returns:
    Decorated target_model_fn.
  """

  def decorated(function):
    return tf_decorator.make_decorator(function, _ModelFnWrapper(function))

  return decorated(target_model_fn) if target_model_fn else decorated


def check_function_argument_count(func, input_arity, infeed_queue):
  """Validates the number of input arguments to an XLA function.

  Args:
    func: the Python function that will be called to generate the body of an
      XLA computation graph.
    input_arity: the number of explicit arguments supplied by the caller.
    infeed_queue: if not None, the infeed queue that will supply additional
      arguments to the function.

  Returns:
    None if the function can be called with the supplied number of arguments,
    or an error string if it cannot.
  """

  def format_error(complaint, quantity):
    return '%s %d argument%s' % (complaint, quantity,
                                 '' if quantity == 1 else 's')

  num_args_supplied = input_arity
  if infeed_queue is not None:
    num_args_supplied += infeed_queue.number_of_tuple_elements
  arg_spec = tf_inspect.getargspec(func)
  num_func_args = len(arg_spec.args)
  if arg_spec.defaults is None:
    num_func_defaults = 0
  else:
    num_func_defaults = len(arg_spec.defaults)
  min_func_args = num_func_args - num_func_defaults
  if num_args_supplied < min_func_args:
    # The number of supplied arguments is not enough to call the function.
    if num_func_defaults == 0 and arg_spec.varargs is None:
      return format_error('exactly', num_func_args)
    else:
      return format_error('at least', min_func_args)
  if arg_spec.varargs is None and num_args_supplied > num_func_args:
    # The number of supplied arguments is too many to call the function.
    if num_func_defaults == 0:
      return format_error('exactly', num_func_args)
    else:
      return format_error('at most', num_func_args)
  # Reaching here means either:
  # 1) There are varargs, and func can accept any number of arguments greater
  #    than the minimum.
  # 2) The number of supplied arguments falls within the acceptable argument
  #    count range of func.
  return None
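
# Example return values for `check_function_argument_count` (a sketch, with
# a hypothetical two-argument function):
#
#   def two_args(a, b):
#     ...
#
#   check_function_argument_count(two_args, 2, None)  # None: count matches.
#   check_function_argument_count(two_args, 1, None)  # 'exactly 2 arguments'
#   check_function_argument_count(two_args, 3, None)  # 'exactly 2 arguments'
#
# When an infeed queue is provided, its `number_of_tuple_elements` is added
# to the supplied-argument count before the check.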