Home | History | Annotate | Download | only in compiler
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # =============================================================================
     15 """xla is an experimental library that provides XLA support APIs."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
import collections
import contextlib
from six.moves import xrange  # pylint: disable=redefined-builtin

# The container ABCs (`Sequence`, ...) live in `collections.abc` since
# Python 3.3 and the aliases on `collections` were removed in Python 3.10.
# Python 2 only provides them on the top-level `collections` module.
try:
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
except ImportError:
  collections_abc = collections
     24 
     25 from tensorflow.compiler.jit.ops import xla_ops
     26 from tensorflow.compiler.jit.ops import xla_ops_grad  # pylint: disable=unused-import
     27 from tensorflow.core.framework import attr_value_pb2
     28 from tensorflow.python.distribute import summary_op_util
     29 from tensorflow.python.estimator import model_fn as model_fn_lib
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.ops import array_ops
     32 from tensorflow.python.ops import control_flow_ops
     33 from tensorflow.python.ops import variable_scope
     34 from tensorflow.python.platform import tf_logging as logging
     35 from tensorflow.python.util import compat
     36 from tensorflow.python.util import function_utils
     37 from tensorflow.python.util import nest
     38 from tensorflow.python.util import tf_decorator
     39 from tensorflow.python.util import tf_inspect
     40 
     41 _XLA_COMPILE_ATTR = '_xla_compile_id'
     42 _MAX_WARNING_LINES = 5
     43 
     44 # Operations that indicate some error in the users graph. For example, XLA
     45 # computation should not have any Placeholder op.
     46 _BLACKLISTED_OPS = set([
     47     'Placeholder',
     48 ])
     49 
     50 # XLA doesn't currently support reading of intermediate tensors, thus some ops
     51 # are not supported.
     52 _UNSUPPORTED_OPS = set([
     53     'AudioSummary',
     54     'AudioSummaryV2',
     55     'HistogramSummary',
     56     'ImageSummary',
     57     'MergeSummary',
     58     'Print',
     59     'ScalarSummary',
     60     'TensorSummary',
     61     'TensorSummaryV2',
     62 ])
     63 
     64 
     65 def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
     66   """Builds an operator that compiles and runs `computation` with XLA.
     67 
     68   Args:
     69     computation: A Python function that builds a computation to apply to the
     70       input. If the function takes n inputs, 'inputs' should be a list of n
     71       tensors.
     72 
     73       `computation` may return a list of operations and tensors.  Tensors must
     74       come before operations in the returned list.  The return value of
     75       `compile` is a list of tensors corresponding to the tensors from the
     76       output of `computation`.
     77 
     78       All `Operation`s returned from `computation` will be executed when
     79       evaluating any of the returned output tensors.
     80     inputs: A list of inputs or `None` (equivalent to an empty list). Each input
     81       can be a nested structure containing values that are convertible to
     82       tensors. Note that passing an N-dimension list of compatible values will
     83       result in a N-dimention list of scalar tensors rather than a single Rank-N
     84       tensors. If you need different behavior, convert part of inputs to tensors
     85       with `tf.convert_to_tensor`.
     86 
     87   Returns:
     88     Same data structure as if computation(*inputs) is called directly with some
     89     exceptions for correctness. Exceptions include:
     90       1) None output: a NoOp would be returned which control-depends on
     91          computation.
     92       2) Single value output: A tuple containing the value would be returned.
     93       3) Operation-only outputs: a NoOp would be returned which
     94          control-depends on computation.
     95       TODO(b/121383831): Investigate into removing these special cases.
     96   """
     97   # pylint: disable=protected-access
     98   return _compile_internal(computation, inputs)
     99 
    100 
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside an XLA computation cluster.

  THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NOT USE DIRECTLY.

  The primary role of `XLACompileContext` is to mark operators inside a
  xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
  a unique name.

  `ControlFlowContext` is used to perform the annotation since it integrates
  with Tensorflow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside a xla.compile() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the compiled computation.
  """

  def __init__(self, name, pivot):
    """Builds a new XLACompileContext.

    Args:
      name: a unique name for the context, used to populate the
        `_xla_compile_id` attribute.
      pivot: a pivot node. Nodes in the XLACompileContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(XLACompileContext, self).__init__()
    self._name = name
    # Bytes form of the name, stored once because it is written into the
    # attribute of every op added to this context.
    self._name_as_bytes = compat.as_bytes(name)
    # Ops whose type is listed in _UNSUPPORTED_OPS, collected by AddOp and
    # reported by report_unsupported_operations.
    self._unsupported_ops = []
    self._pivot = pivot

  def report_unsupported_operations(self):
    """Logs a warning listing the unsupported ops seen by this context.

    At most `_MAX_WARNING_LINES` ops are listed as "type (name)"; if more were
    collected, a second warning reports how many were left out. Nothing is
    logged when no unsupported op was collected.
    """
    if self._unsupported_ops:
      op_str = '\n'.join([
          '  %s (%s)' % (op.type, op.name)
          for op in self._unsupported_ops[:_MAX_WARNING_LINES]
      ])
      logging.warning('%d unsupported operations found: \n%s',
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning('... and %d more',
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op.

    A control input counts as internal when this context appears somewhere in
    the chain of control flow contexts of the op producing it (the op's own
    context or any of its outer contexts); otherwise it is external.

    Args:
      op: the operation whose control inputs are rewritten in place.

    Returns:
      A tuple `(internal_control_inputs, external_control_inputs)`. After the
      call, `op` keeps only the internal control inputs.
    """
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      # Walk outwards through the enclosing contexts looking for `self`.
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    # Drop every control input, then restore only the internal ones.
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op):
    """Create op in XLACompileContext and notifies outer context recursively.

    Args:
      op: the operation being added to this context.

    Raises:
      NotImplementedError: if any input of `op` has a reference dtype, i.e.
        the op uses a non-resource variable.
      ValueError: if `op` already carries the `_xla_compile_id` attribute,
        which is treated as a nested XLA computation.
    """
    # pylint: disable=protected-access
    # Blacklisted ops are only reported here; the op is still added.
    if op.type in _BLACKLISTED_OPS:
      logging.error(
          'Operation of type %s (%s) is not supported in XLA. Execution will '
          'fail if this op is used in the graph. ', op.type, op.name)

    # TODO(ycao): Automatically disable summaries instead of reporting them.
    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          'Non-resource Variables are not supported inside XLA computations '
          '(operator name: %s)' % op.name)

    if _XLA_COMPILE_ATTR in op.node_def.attr:
      raise ValueError('XLA compiled computations cannot be nested, (operator '
                       'name: %s)' % op.name)

    # Tag the op with the name of this cluster.
    op._set_attr(
        _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))

    # Ops inside the cluster may be neither fed nor fetched.
    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may cause
    # mismatched frame errors. An example is when one of op's inputs is
    # generated in a different While control flow context.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self._pivot)
        # pylint: enable=protected-access
    else:
      # Register every data input with this context and rewire the op to the
      # value AddValue hands back whenever it differs from the original input.
      for index in xrange(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x != x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively.

    Args:
      val: the tensor to register with this context.

    Returns:
      `val` itself, or the value recorded for it in `_external_values` (the
      result of the outer context's `AddValue`) when one exists.
    """
    # NOTE(review): `_values` and `_external_values` are not defined in this
    # class; presumably they are initialized by the base context class.
    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    # Remember the mapping so later lookups of `val` return the same result.
    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op):
    """Adds `op` to this context and forwards it to the outer context."""
    # NOTE(review): AddOp already ends by calling AddInnerOp on the outer
    # context, so the outer context appears to be notified twice here.
    # Confirm this is intended before changing it.
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    """Always `None`; see the rationale in the comment below."""
    # Define the gradient loop state associated with the XLACompileContext to
    # be None as the XLACompileContext does not get nested nor does the
    # grad_state outside the XLACompileContext affect the graph inside so the
    # grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any."""
    # Falls back to False when GetWhileContext() returns a falsy value.
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False
    280 
    281 
    282 def _compile_internal(computation, inputs=None):
    283   """Builds graph operators that compiles and symbolically executes computation.
    284 
    285   Args:
    286     computation: A Python function that builds the computation to compile and
    287       execute.
    288     inputs: A list of inputs or `None` (equivalent to an empty list). Each input
    289       can be a nested structure containing values that are convertible to
    290       tensors. Note that passing an N-dimension list of compatible values will
    291       result in a N-dimension list of scalar tensors rather than a single Rank-N
    292       tensors. If you need different behavior, convert part of inputs to tensors
    293       with `tf.convert_to_tensor`.
    294 
    295   Returns:
    296     Same data structure as if computation(*inputs) is called directly with some
    297     exceptions for correctness. Exceptions include: 1) None output 2) Single
    298     value output 3) Operation-only outputs
    299   Raises:
    300     ValueError: If any element in computation outputs is neither an operations
    301       or a value that can be converted to tensor.
    302     ValueError: If computation outputs is non-flat and contains any Operations.
    303     TypeError: If `inputs` is not a list or tuple.
    304   """
    305   if inputs is None:
    306     inputs = []
    307 
    308   if not isinstance(inputs, collections.Sequence):
    309     raise TypeError('inputs must be a list')
    310 
    311   # Flatten inputs.
    312   flat_inputs = nest.flatten(inputs)
    313   # Converts inputs to Tensors.
    314   flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
    315 
    316   cluster_name = ops.get_default_graph().unique_name('cluster')
    317   pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
    318   context = XLACompileContext(name=cluster_name, pivot=pivot)
    319   try:
    320     context.Enter()
    321 
    322     # Add identity ops so even unused inputs are 'consumed' by the
    323     # computation.
    324     flat_inputs = [
    325         array_ops.identity(x, name='input_{}'.format(i))
    326         for i, x in enumerate(flat_inputs)
    327     ]
    328 
    329     # Re-pack flat_inputs in same structure as 'inputs'.
    330     computation_inputs = nest.pack_sequence_as(
    331         structure=inputs, flat_sequence=flat_inputs)
    332 
    333     # Only resource variables work inside an XLA computation, so turn on
    334     # resource variables for the computation.
    335     vscope = variable_scope.get_variable_scope()
    336     saved_use_resource = vscope.use_resource
    337     vscope.set_use_resource(True)
    338 
    339     with _disable_summary_context():
    340       outputs = computation(*computation_inputs)
    341 
    342     # Restore variable scope after computation.
    343     vscope.set_use_resource(saved_use_resource)
    344 
    345     outputs_is_flat = is_flat(outputs)
    346     if outputs_is_flat:
    347       output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    348     else:
    349       output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)
    350 
    351     context.ExitResult(output_tensors)
    352   finally:
    353     context.report_unsupported_operations()
    354     context.Exit()
    355 
    356   # When XLA computation returns only operations and no tensors, a NoOp
    357   # dependent on the operations in outputs is returned. Otherwise final
    358   # outputs would be empty and there is no way to trigger returned
    359   # operations.
    360   if not output_tensors:
    361     return control_flow_ops.group(control_deps, name='output_0')
    362 
    363   output_tensors = [
    364       xla_ops.xla_cluster_output(o, name='output{}'.format(i))
    365       for i, o in enumerate(output_tensors)
    366   ]
    367 
    368   with ops.control_dependencies(control_deps):
    369     # Wraps the outputs in identity operators that carries control
    370     # dependencies.
    371     output_tensors = [
    372         array_ops.identity(o, name='output_%d' % i)
    373         for i, o in enumerate(output_tensors)
    374     ]
    375 
    376   # If `computation` returned non-flat output structure, pack output tensors
    377   # back into same structure.
    378   if not outputs_is_flat:
    379     output_tensors = nest.pack_sequence_as(
    380         structure=outputs, flat_sequence=output_tensors)
    381 
    382   return output_tensors
    383 
    384 
    385 def is_flat(outputs):
    386   """Checks if outputs is a flat structure.
    387 
    388     Following structures and values are considered flat:
    389     1) None
    390     2) A single object
    391     3) A list or tuple of Tensors/Operations
    392 
    393     The only structures that this function understands are sequences and
    394     dictionaries.  E.g. this means that if outputs contains a single
    395     user-defined Object, it is considered to be flat. Errors are raised later on
    396     if that Object cannot be converted to a Tensor.
    397 
    398   Args:
    399     outputs: Output from `computation` inside `xla.compile`.
    400 
    401   Returns:
    402     A boolean indicates whether outputs is flat.
    403   """
    404   # If outputs is a list or tuple, check if it has any nested structure. If
    405   # there is, then outputs is non-flat.
    406   if isinstance(outputs, collections.Sequence):
    407     for o in outputs:
    408       if isinstance(o, collections.Sequence) or isinstance(o, dict):
    409         return False
    410 
    411   # If outputs is a dict, it is non-flat.
    412   if isinstance(outputs, dict):
    413     return False
    414 
    415   # Getting here means either outputs itself is a single non-structured value
    416   # or it is a flat list of single non-structured values.
    417   return True
    418 
    419 
    420 def _postprocess_flat_outputs(outputs):
    421   """Validates flat outputs and adds back device assignments.
    422 
    423   Args:
    424     outputs: Output from `computation` inside `xla.compile`.
    425 
    426   Returns:
    427     Tensors and Operations extracted from outputs.
    428   """
    429   # Following code segment is to preserve legacy behavior. Previously we only
    430   # supported flat outputs and thus for consistency it was nice to convert even
    431   # single element into a tuple. But now that we support arbitrary output
    432   # structure, this is no longer necessary.
    433   # TODO(b/121383831): Migrate all legacy use cases and delete this special
    434   # case.
    435   # If the computation returns `None`, make it an empty tuple.
    436   if outputs is None:
    437     outputs = tuple()
    438   # If the computation only returned one value, make it a tuple.
    439   if not isinstance(outputs, collections.Sequence):
    440     outputs = (outputs,)
    441 
    442   # Append `no_op` here so that return value of this function always contains
    443   # at least one op that can trigger XlaLaunch node.
    444   outputs += (control_flow_ops.no_op(),)
    445   try:
    446     outputs = [
    447         o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
    448         for o in outputs
    449     ]
    450   except Exception as e:
    451     raise ValueError(
    452         'XLA computation function return values must all either be Operations'
    453         ' or convertible to Tensors. Got error: "%s"' % str(e))
    454 
    455   # Separates the returned Operations and Tensors.
    456   output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    457   output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
    458 
    459   if outputs != output_tensors + output_operations:
    460     raise ValueError(
    461         'XLA computation function must return zero or more Tensor values '
    462         'followed by zero or more Operations.')
    463 
    464   new_output_tensors = []
    465   for t in output_tensors:
    466     with ops.device(t.device if t.device else ''):
    467       new_output_tensors.append(array_ops.identity(t))
    468 
    469   return new_output_tensors, output_operations
    470 
    471 
    472 def _postprocess_non_flat_outputs(outputs):
    473   """Validates non-flat outputs and adds back device assignments.
    474 
    475   Args:
    476     outputs: Output from `computation` inside `xla.compile`.
    477 
    478   Returns:
    479     Tensors extracted from outputs and an empty list because Operations are not
    480     allowed in non-flat outputs..
    481   """
    482   # Convert all non-Operation outputs to Tensors.
    483   new_output_tensors = []
    484   for o in nest.flatten(outputs):
    485     if isinstance(o, ops.Operation):
    486       raise ValueError(
    487           'xla.compile does not support Operation as return value in non-flat '
    488           'output structure. You can set returned Operations as control '
    489           'dependencies of returned Tensors so Operations are triggered when '
    490           'Tensors are evaluated. Operation found: "%s"' % o.name)
    491 
    492     try:
    493       o = ops.convert_to_tensor(o)
    494     except Exception as e:
    495       raise ValueError(
    496           'XLA computation function return values must all either be '
    497           'Operations or convertible to Tensors. Got error: "%s"' % str(e))
    498 
    499     # Makes sure even pass-through inputs/outputs are touched in compile
    500     # context by creating an Identity node inside compile context.
    501     with ops.device(o.device if o.device else ''):
    502       new_output_tensors.append(array_ops.identity(o))
    503 
    504   return new_output_tensors, []
    505 
    506 
    507 @contextlib.contextmanager
    508 def _disable_summary_context():
    509   """Enters a context where all summary ops are skipped.
    510 
    511   Summaries are not yet supported in xla.compile(). So we provide this context
    512   manager that can skip creating summary ops. This is a temporary workaround due
    513   to XLA not supporting summary ops.
    514 
    515   Yields:
    516     None.
    517   """
    518   original_skip_summary_func = summary_op_util.skip_summary
    519   summary_op_util.skip_summary = lambda: True
    520 
    521   try:
    522     yield
    523   finally:
    524     summary_op_util.skip_summary = original_skip_summary_func
    525 
    526 
    527 class _CapturedObject(object):
    528   """A placeholder to capture an object."""
    529 
    530   def __init__(self):
    531     self._object = None
    532 
    533   def capture(self, o):
    534     if self._object:
    535       raise RuntimeError(
    536           'InternalError: _CapturedObject can capture only once. Please file '
    537           'bug.')
    538 
    539     self._object = o
    540 
    541   def get(self):
    542     return self._object
    543 
    544 
    545 def _get_scaffold(captured_scaffold_fn):
    546   """Retrieves the Scaffold from `captured_scaffold_fn`."""
    547   scaffold_fn = captured_scaffold_fn.get()
    548 
    549   if not scaffold_fn:
    550     return None
    551 
    552   scaffold = scaffold_fn()
    553   if scaffold is None:
    554     raise ValueError(
    555         'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
    556 
    557   return scaffold
    558 
    559 
    560 class _ModelFnWrapper(object):
    561   """_ModelFnWrapper supports executing model_fn with XLA."""
    562 
    563   def __init__(self, function):
    564     self._model_fn = function
    565 
    566   def __call__(self, features, labels, mode, params):
    567 
    568     # TPUEstimator compiles model_fn when use_tpu=True. To avoid double
    569     # compilation, we use this params['use_tpu'] as a hint. When it is set to
    570     # True, model_fn is called without compilation.
    571     # Note that this condition isn't accurate for the case of exporting a model.
    572     # In that case we should ideally not compile so that user can see detailed
    573     # graph. However, we don't have enough information to tell whether model_fn
    574     # is being called for export mode or not.
    575     # TODO(ycao): Make this condition more accurate when implementing PREDICT
    576     # mode.
    577     if params.get('use_tpu'):
    578       return self._call_model_fn(features, labels, mode, params)
    579 
    580     if mode == model_fn_lib.ModeKeys.TRAIN:
    581       train_step, captured_scaffold_fn = self._make_train_step(
    582           features, labels, params)
    583       (loss,) = compile(train_step)
    584       return model_fn_lib.EstimatorSpec(
    585           mode=mode,
    586           loss=loss,
    587           train_op=array_ops.identity(loss),
    588           scaffold=_get_scaffold(captured_scaffold_fn))
    589     elif mode == model_fn_lib.ModeKeys.EVAL:
    590       eval_step, captured_eval_metric_fn, captured_scaffold_fn = (
    591           self._make_eval_step(features, labels, params))
    592       outputs = compile(eval_step)
    593       loss = outputs[0]
    594 
    595       # Calculate eval_metric_ops if eval_metric_fn is set and captured.
    596       eval_metric_fn = captured_eval_metric_fn.get()
    597       if eval_metric_fn:
    598         eval_metric_fn_tensors = outputs[1:]
    599         eval_metric_ops = eval_metric_fn(*eval_metric_fn_tensors)
    600       else:
    601         eval_metric_ops = None
    602 
    603       return model_fn_lib.EstimatorSpec(
    604           mode=mode,
    605           loss=loss,
    606           eval_metric_ops=eval_metric_ops,
    607           scaffold=_get_scaffold(captured_scaffold_fn))
    608     else:
    609       raise NotImplementedError('%s is not implemented, only TRAIN and EVAL are'
    610                                 ' supported' % mode)
    611 
    612   def _make_train_step(self, features, labels, params):
    613     """Creates a single step of training for xla.compile()."""
    614     captured_scaffold_fn = _CapturedObject()
    615 
    616     def train_step():
    617       """A single step of training."""
    618       estimator_spec = self._call_model_fn(features, labels,
    619                                            model_fn_lib.ModeKeys.TRAIN, params)
    620 
    621       try:
    622         captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
    623       except AttributeError:
    624         captured_scaffold_fn.capture(None)
    625 
    626       # train_step will be run by xla.compile(). xla.compile() only supports
    627       # tensor output while train_op can be either an operation or a tensor.
    628       # Even though xla.compile() automatically adds operation-typed train_op as
    629       # control dependency of other tensor outputs, it doesn't do so for
    630       # tensor-typed train_op. Thus, we need to set it explicitly here.
    631       with ops.control_dependencies([estimator_spec.train_op]):
    632         return array_ops.identity(estimator_spec.loss)
    633 
    634     return train_step, captured_scaffold_fn
    635 
    636   def _make_eval_step(self, features, labels, params):
    637     """Creates a single step of evaluation for xla.compile()."""
    638     captured_eval_metric_fn = _CapturedObject()
    639     captured_scaffold_fn = _CapturedObject()
    640 
    641     def eval_step():
    642       """A single step of evaluation."""
    643       estimator_spec = self._call_model_fn(features, labels,
    644                                            model_fn_lib.ModeKeys.EVAL, params)
    645 
    646       try:
    647         captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
    648       except AttributeError:
    649         captured_scaffold_fn.capture(None)
    650 
    651       eval_metric_fn = None
    652       eval_metric_fn_tensors = []
    653       try:
    654         if estimator_spec.eval_metrics:
    655           (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
    656       except AttributeError:
    657         pass
    658 
    659       # If a dictionary is provided, we need to convert it into a list sorted
    660       # according to order of eval_metric_fn positional arguments.
    661       if isinstance(eval_metric_fn_tensors, dict):
    662         eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
    663         eval_metric_fn_tensors = [
    664             eval_metric_fn_tensors[i] for i in eval_metric_fn_args
    665         ]
    666 
    667       captured_eval_metric_fn.capture(eval_metric_fn)
    668 
    669       return tuple([estimator_spec.loss] + eval_metric_fn_tensors)
    670 
    671     return eval_step, captured_eval_metric_fn, captured_scaffold_fn
    672 
    673   def _call_model_fn(self, features, labels, mode, params):
    674     """Calls the model_fn with required parameters."""
    675     model_fn_args = function_utils.fn_args(self._model_fn)
    676     kwargs = {}
    677 
    678     if 'labels' in model_fn_args:
    679       kwargs['labels'] = labels
    680     elif labels is not None:
    681       raise ValueError(
    682           'model_fn does not take labels, but input_fn returns labels.')
    683     if 'mode' in model_fn_args:
    684       kwargs['mode'] = mode
    685 
    686     if 'params' in model_fn_args:
    687       kwargs['params'] = params
    688 
    689     return self._verify_estimator_spec(
    690         self._model_fn(features=features, **kwargs))
    691 
    692   def _verify_estimator_spec(self, estimator_spec):
    693     """Verifies estimator spec contains correct data."""
    694     # TODO(ycao): Implement estimator spec verification for other modes.
    695 
    696     try:
    697       if estimator_spec.scaffold:
    698         logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
    699                         '. Please use TPUEstimatorSpec.scaffold_fn instead.')
    700     except AttributeError:
    701       pass
    702 
    703     try:
    704       if estimator_spec.eval_metric_ops:
    705         raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
    706                          'XLA compilation. Please use '
    707                          'TPUEstimatorSpec.eval_metrics instead.')
    708     except AttributeError:
    709       pass
    710 
    711     if estimator_spec.mode == model_fn_lib.ModeKeys.EVAL:
    712       # If estimator_spec is of type TPUEstimatorSpec and contains eval_metrics,
    713       # check that eval_metrics contains eval_metric_fn and
    714       # eval_metric_fn_tensors with matching arguments.
    715       try:
    716         eval_metrics = estimator_spec.eval_metrics
    717       except AttributeError:
    718         eval_metrics = None
    719 
    720       if eval_metrics:
    721         (eval_metric_fn, eval_metric_fn_tensors) = eval_metrics
    722         eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
    723 
    724         if isinstance(eval_metric_fn_tensors, dict):
    725           missing_tensors = [
    726               i for i in eval_metric_fn_args if i not in eval_metric_fn_tensors
    727           ]
    728           additional_tensors = [
    729               i for i in eval_metric_fn_tensors if i not in eval_metric_fn_args
    730           ]
    731 
    732           if missing_tensors:
    733             raise ValueError('Arguments %s are needed by metric_fn (first '
    734                              'element of TPUEstimatorSpec.eval_metrics) but '
    735                              'they are not provided by evaluation tensors '
    736                              '(second element of TPUEstimatorSpec.eval_metrics)'
    737                              '.' % missing_tensors)
    738 
    739           if additional_tensors:
    740             raise ValueError('Arguments %s are provided by evaluation tensors '
    741                              '(second element of TPUEstimatorSpec.eval_metrics)'
    742                              ' but they are not needed by metric_fn (first '
    743                              'element of TPUEstimatorSpec.eval_metrics).' %
    744                              additional_tensors)
    745 
    746     return estimator_spec
    747 
    748 
    749 def estimator_model_fn(target_model_fn=None):
    750   """estimator_model_fn decorates a model_fn to be compiled for execution.
    751 
    752   Currently it only works with `TPUEstimator`. If you need to use it with base
    753   `Estimator`, please add `tf.enable_resource_variables()` at the beginning of
    754   your program.
    755 
    756   Example 1, decorating model_fn:
    757   ```
    758   @xla.estimator_model_fn()
    759   def model_fn(features, labels, mode, params):
    760     ...
    761     return EstimatorSpec(...)
    762 
    763 
    764   est = Estimator(model_fn=model_fn, ...)
    765   est.train(...)
    766 
    767   ```
    768 
    769   Example 2, decorator as function:
    770   ```
    771   def model_fn(features, labels, mode, params):
    772     ...
    773     return EstimatorSpec(...)
    774 
    775   est = Estimator(model_fn=xla.estimator_model_fn(model_fn), ...)
    776   est.train(...)
    777   ```
    778 
    779   Args:
    780     target_model_fn: model_fn to be decorated. This is only needed when
    781       decorator is used in function call form (example 2).
    782 
    783   Returns:
    784     Decorated target_model_fn.
    785   """
    786 
    787   def decorated(function):
    788     return tf_decorator.make_decorator(function, _ModelFnWrapper(function))
    789 
    790   return decorated(target_model_fn) if target_model_fn else decorated
    791 
    792 
    793 def check_function_argument_count(func, input_arity, infeed_queue):
    794   """Validate the number of input arguments to an XLA function.
    795 
    796   Args:
    797     func: the Python function that will be called to generate the body of an XLA
    798       computation graph.
    799     input_arity: the number of explicit arguments supplied by the caller.
    800     infeed_queue: if not None, the infeed queue that will supply
    801       additional arguments to the function.
    802 
    803   Returns:
    804     None if function can be called with the supplied number of
    805       arguments, or an error string if it cannot.
    806   """
    807   def format_error(complaint, quantity):
    808     return '%s %d argument%s' % (complaint, quantity, ''
    809                                  if quantity == 1 else 's')
    810 
    811   num_args_supplied = input_arity
    812   if infeed_queue is not None:
    813     num_args_supplied += infeed_queue.number_of_tuple_elements
    814   arg_spec = tf_inspect.getargspec(func)
    815   num_func_args = len(arg_spec.args)
    816   if arg_spec.defaults is None:
    817     num_func_defaults = 0
    818   else:
    819     num_func_defaults = len(arg_spec.defaults)
    820   min_func_args = num_func_args - num_func_defaults
    821   if num_args_supplied < min_func_args:
    822     # The required number of arguments is not enough to call the function.
    823     if num_func_defaults == 0 and arg_spec.varargs is None:
    824       return format_error('exactly', num_func_args)
    825     else:
    826       return format_error('at least', min_func_args)
    827   if arg_spec.varargs is None and num_args_supplied > num_func_args:
    828     # The required number of arguments is too many to call the function.
    829     if num_func_defaults == 0:
    830       return format_error('exactly', num_func_args)
    831     else:
    832       return format_error('at most', num_func_args)
    833   # Reaching here means either
    834   # 1) There are varargs, func can accept any number of arguments greater than
    835   # the minimum.
    836   # 2) Number of supplied arguments falls in range of acceptable argument count
    837   # of func.
    838   return None
    839