Home | History | Annotate | Download | only in compiler
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # =============================================================================
     15 """xla is an experimental library that provides XLA support APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
from six.moves import xrange  # pylint: disable=redefined-builtin

# The abstract container types (e.g. `Sequence`) live in `collections.abc` on
# Python 3; the aliases on the top-level `collections` module were removed in
# Python 3.10.
try:
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
except ImportError:  # Python 2
  collections_abc = collections

from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad  # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
     41 _XLA_COMPILE_ATTR = '_xla_compile_id'
     42 _MAX_WARNING_LINES = 5
     44 # Operations that indicate some error in the users graph. For example, XLA
     45 # computation should not have any Placeholder op.
     46 _BLACKLISTED_OPS = set([
     47     'Placeholder',
     48 ])
     50 # XLA doesn't currently support reading of intermediate tensors, thus some ops
     51 # are not supported.
     52 _UNSUPPORTED_OPS = set([
     53     'AudioSummary',
     54     'AudioSummaryV2',
     55     'HistogramSummary',
     56     'ImageSummary',
     57     'MergeSummary',
     58     'Print',
     59     'ScalarSummary',
     60     'TensorSummary',
     61     'TensorSummaryV2',
     62 ])
     65 def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
     66   """Builds an operator that compiles and runs `computation` with XLA.
     68   Args:
     69     computation: A Python function that builds a computation to apply to the
     70       input. If the function takes n inputs, 'inputs' should be a list of n
     71       tensors.
     73       `computation` may return a list of operations and tensors.  Tensors must
     74       come before operations in the returned list.  The return value of
     75       `compile` is a list of tensors corresponding to the tensors from the
     76       output of `computation`.
     78       All `Operation`s returned from `computation` will be executed when
     79       evaluating any of the returned output tensors.
     80     inputs: A list of inputs or `None` (equivalent to an empty list). Each input
     81       can be a nested structure containing values that are convertible to
     82       tensors. Note that passing an N-dimension list of compatible values will
     83       result in a N-dimention list of scalar tensors rather than a single Rank-N
     84       tensors. If you need different behavior, convert part of inputs to tensors
     85       with `tf.convert_to_tensor`.
     87   Returns:
     88     Same data structure as if computation(*inputs) is called directly with some
     89     exceptions for correctness. Exceptions include:
     90       1) None output: a NoOp would be returned which control-depends on
     91          computation.
     92       2) Single value output: A tuple containing the value would be returned.
     93       3) Operation-only outputs: a NoOp would be returned which
     94          control-depends on computation.
     95       TODO(b/121383831): Investigate into removing these special cases.
     96   """
     97   # pylint: disable=protected-access
     98   return _compile_internal(computation, inputs)
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside an XLA computation cluster.

  The primary role of `XLACompileContext` is to mark operators inside a
  xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
  a unique name.

  `ControlFlowContext` is used to perform the annotation since it integrates
  with Tensorflow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside a xla.compile() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the compiled computation.

  Besides tagging ops, `AddOp` also:
    * logs an error for op types in `_BLACKLISTED_OPS`;
    * records op types in `_UNSUPPORTED_OPS` for
      `report_unsupported_operations()`;
    * rejects ref-typed inputs and nested compilation;
    * marks ops as not feedable/fetchable;
    * rewires control edges that come from outside this context.
  """

  def __init__(self, name, pivot):
    """Builds a new XLACompileContext.

    Args:
      name: a unique name for the context, used to populate the
        `_xla_compile_id` attribute.
      pivot: a pivot node. Nodes in the XLACompileContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(XLACompileContext, self).__init__()
    self._name = name
    # The attribute value is stored in an AttrValue proto's `s` field as bytes.
    self._name_as_bytes = compat.as_bytes(name)
    # Ops whose type is in _UNSUPPORTED_OPS, collected by AddOp and reported by
    # report_unsupported_operations().
    self._unsupported_ops = []
    self._pivot = pivot

  def report_unsupported_operations(self):
    """Logs a warning listing the unsupported ops added to this context.

    At most `_MAX_WARNING_LINES` ops are listed by type and name. If there are
    more, a second warning reports how many were omitted. Nothing is logged
    when no unsupported op was added.
    """
    if self._unsupported_ops:
      op_str = '\n'.join([
          '  %s (%s)' % (op.type, op.name)
          for op in self._unsupported_ops[:_MAX_WARNING_LINES]
      ])
      logging.warning('%d unsupported operations found: \n%s',
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning('... and %d more',
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op.

    A control input is internal when this context appears in the chain of
    `_outer_context`s that starts at the control input's own control flow
    context. Otherwise it is external. All control inputs of `op` are removed
    and only the internal ones are re-added.

    Args:
      op: the `Operation` being added to this context.

    Returns:
      A pair `(internal_control_inputs, external_control_inputs)`, each a list
      of ops that were control inputs of `op`.
    """
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      # Walk outwards from x's context looking for this context.
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op):
    """Create op in XLACompileContext and notifies outer context recursively.

    Validates `op` and tags it with `_xla_compile_id`. Rewires its data and
    control inputs relative to this context, and records its outputs as values
    of this and every enclosing context. Finally it forwards `op` to the outer
    context, if any.

    Args:
      op: the `Operation` being created inside this context.

    Raises:
      NotImplementedError: If any input of `op` has a ref dtype (i.e. comes
        from a non-resource variable).
      ValueError: If `op` already carries the `_xla_compile_id` attribute,
        i.e. XLA computations would be nested.
    """
    # pylint: disable=protected-access
    if op.type in _BLACKLISTED_OPS:
      logging.error(
          'Operation of type %s (%s) is not supported in XLA. Execution will '
          'fail if this op is used in the graph. ', op.type, op.name)

    # TODO(ycao): Automatically disable summaries instead of reporting them.
    # Unsupported ops are only collected here; they are logged by
    # report_unsupported_operations().
    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          'Non-resource Variables are not supported inside XLA computations '
          '(operator name: %s)' % op.name)

    # An op that already has the attribute was tagged by another compile
    # context.
    if _XLA_COMPILE_ATTR in op.node_def.attr:
      raise ValueError('XLA compiled computations cannot be nested, (operator '
                       'name: %s)' % op.name)

    op._set_attr(
        _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))

    # Ops inside the cluster are marked as not feedable and not fetchable.
    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may cause
    # mismatched frame errors. An example is when one of op's inputs is
    # generated in a different While control flow context.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self._pivot)
        # pylint: enable=protected-access
    else:
      # Route every data input through AddValue, which may return a different
      # tensor (obtained from the outer context) to use in place of the input.
      for index in xrange(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x != x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively.

    Args:
      val: a tensor consumed by an op in this context.

    Returns:
      The tensor to use in place of `val` inside this context. This is the
      value returned by the outer context's `AddValue` (remembered in
      `_external_values`), or `val` itself when there is no outer context or
      no mapping was recorded.
    """
    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op):
    """Adds an op created in an inner context to this and enclosing contexts."""
    # NOTE(review): AddOp already ends by calling
    # self._outer_context.AddInnerOp(op), so the outer context appears to be
    # notified twice for the same op here. Confirm whether this is intended.
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the XLACompileContext to
    # be None as the XLACompileContext does not get nested nor does the
    # grad_state outside the XLACompileContext affect the graph inside so the
    # grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any.

    Returns:
      The enclosing while context's `back_prop`, or False when there is no
      enclosing while context.
    """
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False
    282 def _compile_internal(computation, inputs=None):
    283   """Builds graph operators that compiles and symbolically executes computation.
    285   Args:
    286     computation: A Python function that builds the computation to compile and
    287       execute.
    288     inputs: A list of inputs or `None` (equivalent to an empty list). Each input
    289       can be a nested structure containing values that are convertible to
    290       tensors. Note that passing an N-dimension list of compatible values will
    291       result in a N-dimension list of scalar tensors rather than a single Rank-N
    292       tensors. If you need different behavior, convert part of inputs to tensors
    293       with `tf.convert_to_tensor`.
    295   Returns:
    296     Same data structure as if computation(*inputs) is called directly with some
    297     exceptions for correctness. Exceptions include: 1) None output 2) Single
    298     value output 3) Operation-only outputs
    299   Raises:
    300     ValueError: If any element in computation outputs is neither an operations
    301       or a value that can be converted to tensor.
    302     ValueError: If computation outputs is non-flat and contains any Operations.
    303     TypeError: If `inputs` is not a list or tuple.
    304   """
    305   if inputs is None:
    306     inputs = []
    308   if not isinstance(inputs, collections.Sequence):
    309     raise TypeError('inputs must be a list')
    311   # Flatten inputs.
    312   flat_inputs = nest.flatten(inputs)
    313   # Converts inputs to Tensors.
    314   flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
    316   cluster_name = ops.get_default_graph().unique_name('cluster')
    317   pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
    318   context = XLACompileContext(name=cluster_name, pivot=pivot)
    319   try:
    320     context.Enter()
    322     # Add identity ops so even unused inputs are 'consumed' by the
    323     # computation.
    324     flat_inputs = [
    325         array_ops.identity(x, name='input_{}'.format(i))
    326         for i, x in enumerate(flat_inputs)
    327     ]
    329     # Re-pack flat_inputs in same structure as 'inputs'.
    330     computation_inputs = nest.pack_sequence_as(
    331         structure=inputs, flat_sequence=flat_inputs)
    333     # Only resource variables work inside an XLA computation, so turn on
    334     # resource variables for the computation.
    335     vscope = variable_scope.get_variable_scope()
    336     saved_use_resource = vscope.use_resource
    337     vscope.set_use_resource(True)
    339     with _disable_summary_context():
    340       outputs = computation(*computation_inputs)
    342     # Restore variable scope after computation.
    343     vscope.set_use_resource(saved_use_resource)
    345     outputs_is_flat = is_flat(outputs)
    346     if outputs_is_flat:
    347       output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    348     else:
    349       output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)
    351     context.ExitResult(output_tensors)
    352   finally:
    353     context.report_unsupported_operations()
    354     context.Exit()
    356   # When XLA computation returns only operations and no tensors, a NoOp
    357   # dependent on the operations in outputs is returned. Otherwise final
    358   # outputs would be empty and there is no way to trigger returned
    359   # operations.
    360   if not output_tensors:
    361     return control_flow_ops.group(control_deps, name='output_0')
    363   output_tensors = [
    364       xla_ops.xla_cluster_output(o, name='output{}'.format(i))
    365       for i, o in enumerate(output_tensors)
    366   ]
    368   with ops.control_dependencies(control_deps):
    369     # Wraps the outputs in identity operators that carries control
    370     # dependencies.
    371     output_tensors = [
    372         array_ops.identity(o, name='output_%d' % i)
    373         for i, o in enumerate(output_tensors)
    374     ]
    376   # If `computation` returned non-flat output structure, pack output tensors
    377   # back into same structure.
    378   if not outputs_is_flat:
    379     output_tensors = nest.pack_sequence_as(
    380         structure=outputs, flat_sequence=output_tensors)
    382   return output_tensors
    385 def is_flat(outputs):
    386   """Checks if outputs is a flat structure.
    388     Following structures and values are considered flat:
    389     1) None
    390     2) A single object
    391     3) A list or tuple of Tensors/Operations
    393     The only structures that this function understands are sequences and
    394     dictionaries.  E.g. this means that if outputs contains a single
    395     user-defined Object, it is considered to be flat. Errors are raised later on
    396     if that Object cannot be converted to a Tensor.
    398   Args:
    399     outputs: Output from `computation` inside `xla.compile`.
    401   Returns:
    402     A boolean indicates whether outputs is flat.
    403   """
    404   # If outputs is a list or tuple, check if it has any nested structure. If
    405   # there is, then outputs is non-flat.
    406   if isinstance(outputs, collections.Sequence):
    407     for o in outputs:
    408       if isinstance(o, collections.Sequence) or isinstance(o, dict):
    409         return False
    411   # If outputs is a dict, it is non-flat.
    412   if isinstance(outputs, dict):
    413     return False
    415   # Getting here means either outputs itself is a single non-structured value
    416   # or it is a flat list of single non-structured values.
    417   return True
    420 def _postprocess_flat_outputs(outputs):
    421   """Validates flat outputs and adds back device assignments.
    423   Args:
    424     outputs: Output from `computation` inside `xla.compile`.
    426   Returns:
    427     Tensors and Operations extracted from outputs.
    428   """
    429   # Following code segment is to preserve legacy behavior. Previously we only
    430   # supported flat outputs and thus for consistency it was nice to convert even
    431   # single element into a tuple. But now that we support arbitrary output
    432   # structure, this is no longer necessary.
    433   # TODO(b/121383831): Migrate all legacy use cases and delete this special
    434   # case.
    435   # If the computation returns `None`, make it an empty tuple.
    436   if outputs is None:
    437     outputs = tuple()
    438   # If the computation only returned one value, make it a tuple.
    439   if not isinstance(outputs, collections.Sequence):
    440     outputs = (outputs,)
    442   # Append `no_op` here so that return value of this function always contains
    443   # at least one op that can trigger XlaLaunch node.
    444   outputs += (control_flow_ops.no_op(),)
    445   try:
    446     outputs = [
    447         o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
    448         for o in outputs
    449     ]
    450   except Exception as e:
    451     raise ValueError(
    452         'XLA computation function return values must all either be Operations'
    453         ' or convertible to Tensors. Got error: "%s"' % str(e))
    455   # Separates the returned Operations and Tensors.
    456   output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    457   output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
    459   if outputs != output_tensors + output_operations:
    460     raise ValueError(
    461         'XLA computation function must return zero or more Tensor values '
    462         'followed by zero or more Operations.')
    464   new_output_tensors = []
    465   for t in output_tensors:
    466     with ops.device(t.device if t.device else ''):
    467       new_output_tensors.append(array_ops.identity(t))
    469   return new_output_tensors, output_operations
    472 def _postprocess_non_flat_outputs(outputs):
    473   """Validates non-flat outputs and adds back device assignments.
    475   Args:
    476     outputs: Output from `computation` inside `xla.compile`.
    478   Returns:
    479     Tensors extracted from outputs and an empty list because Operations are not
    480     allowed in non-flat outputs..
    481   """
    482   # Convert all non-Operation outputs to Tensors.
    483   new_output_tensors = []
    484   for o in nest.flatten(outputs):
    485     if isinstance(o, ops.Operation):
    486       raise ValueError(
    487           'xla.compile does not support Operation as return value in non-flat '
    488           'output structure. You can set returned Operations as control '
    489           'dependencies of returned Tensors so Operations are triggered when '
    490           'Tensors are evaluated. Operation found: "%s"' % o.name)
    492     try:
    493       o = ops.convert_to_tensor(o)
    494     except Exception as e:
    495       raise ValueError(
    496           'XLA computation function return values must all either be '
    497           'Operations or convertible to Tensors. Got error: "%s"' % str(e))
    499     # Makes sure even pass-through inputs/outputs are touched in compile
    500     # context by creating an Identity node inside compile context.
    501     with ops.device(o.device if o.device else ''):
    502       new_output_tensors.append(array_ops.identity(o))
    504   return new_output_tensors, []
    507 @contextlib.contextmanager
    508 def _disable_summary_context():
    509   """Enters a context where all summary ops are skipped.
    511   Summaries are not yet supported in xla.compile(). So we provide this context
    512   manager that can skip creating summary ops. This is a temporary workaround due
    513   to XLA not supporting summary ops.
    515   Yields:
    516     None.
    517   """
    518   original_skip_summary_func = summary_op_util.skip_summary
    519   summary_op_util.skip_summary = lambda: True
    521   try:
    522     yield
    523   finally:
    524     summary_op_util.skip_summary = original_skip_summary_func
    527 class _CapturedObject(object):
    528   """A placeholder to capture an object."""
    530   def __init__(self):
    531     self._object = None
    533   def capture(self, o):
    534     if self._object:
    535       raise RuntimeError(
    536           'InternalError: _CapturedObject can capture only once. Please file '
    537           'bug.')
    539     self._object = o
    541   def get(self):
    542     return self._object
    545 def _get_scaffold(captured_scaffold_fn):
    546   """Retrieves the Scaffold from `captured_scaffold_fn`."""
    547   scaffold_fn = captured_scaffold_fn.get()
    549   if not scaffold_fn:
    550     return None
    552   scaffold = scaffold_fn()
    553   if scaffold is None:
    554     raise ValueError(
    555         'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
    557   return scaffold
    560 class _ModelFnWrapper(object):
    561   """_ModelFnWrapper supports executing model_fn with XLA."""
    563   def __init__(self, function):
    564     self._model_fn = function
    566   def __call__(self, features, labels, mode, params):
    568     # TPUEstimator compiles model_fn when use_tpu=True. To avoid double
    569     # compilation, we use this params['use_tpu'] as a hint. When it is set to
    570     # True, model_fn is called without compilation.
    571     # Note that this condition isn't accurate for the case of exporting a model.
    572     # In that case we should ideally not compile so that user can see detailed
    573     # graph. However, we don't have enough information to tell whether model_fn
    574     # is being called for export mode or not.
    575     # TODO(ycao): Make this condition more accurate when implementing PREDICT
    576     # mode.
    577     if params.get('use_tpu'):
    578       return self._call_model_fn(features, labels, mode, params)
    580     if mode == model_fn_lib.ModeKeys.TRAIN:
    581       train_step, captured_scaffold_fn = self._make_train_step(
    582           features, labels, params)
    583       (loss,) = compile(train_step)
    584       return model_fn_lib.EstimatorSpec(
    585           mode=mode,
    586           loss=loss,
    587           train_op=array_ops.identity(loss),
    588           scaffold=_get_scaffold(captured_scaffold_fn))
    589     elif mode == model_fn_lib.ModeKeys.EVAL:
    590       eval_step, captured_eval_metric_fn, captured_scaffold_fn = (
    591           self._make_eval_step(features, labels, params))
    592       outputs = compile(eval_step)
    593       loss = outputs[0]
    595       # Calculate eval_metric_ops if eval_metric_fn is set and captured.
    596       eval_metric_fn = captured_eval_metric_fn.get()
    597       if eval_metric_fn:
    598         eval_metric_fn_tensors = outputs[1:]
    599         eval_metric_ops = eval_metric_fn(*eval_metric_fn_tensors)
    600       else:
    601         eval_metric_ops = None
    603       return model_fn_lib.EstimatorSpec(
    604           mode=mode,
    605           loss=loss,
    606           eval_metric_ops=eval_metric_ops,
    607           scaffold=_get_scaffold(captured_scaffold_fn))
    608     else:
    609       raise NotImplementedError('%s is not implemented, only TRAIN and EVAL are'
    610                                 ' supported' % mode)
    612   def _make_train_step(self, features, labels, params):
    613     """Creates a single step of training for xla.compile()."""
    614     captured_scaffold_fn = _CapturedObject()
    616     def train_step():
    617       """A single step of training."""
    618       estimator_spec = self._call_model_fn(features, labels,
    619                                            model_fn_lib.ModeKeys.TRAIN, params)
    621       try:
    622         captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
    623       except AttributeError:
    624         captured_scaffold_fn.capture(None)
    626       # train_step will be run by xla.compile(). xla.compile() only supports
    627       # tensor output while train_op can be either an operation or a tensor.
    628       # Even though xla.compile() automatically adds operation-typed train_op as
    629       # control dependency of other tensor outputs, it doesn't do so for
    630       # tensor-typed train_op. Thus, we need to set it explicitly here.
    631       with ops.control_dependencies([estimator_spec.train_op]):
    632         return array_ops.identity(estimator_spec.loss)
    634     return train_step, captured_scaffold_fn
    636   def _make_eval_step(self, features, labels, params):
    637     """Creates a single step of evaluation for xla.compile()."""
    638     captured_eval_metric_fn = _CapturedObject()
    639     captured_scaffold_fn = _CapturedObject()
    641     def eval_step():
    642       """A single step of evaluation."""
    643       estimator_spec = self._call_model_fn(features, labels,
    644                                            model_fn_lib.ModeKeys.EVAL, params)
    646       try:
    647         captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
    648       except AttributeError:
    649         captured_scaffold_fn.capture(None)
    651       eval_metric_fn = None
    652       eval_metric_fn_tensors = []
    653       try:
    654         if estimator_spec.eval_metrics:
    655           (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
    656       except AttributeError:
    657         pass
    659       # If a dictionary is provided, we need to convert it into a list sorted
    660       # according to order of eval_metric_fn positional arguments.
    661       if isinstance(eval_metric_fn_tensors, dict):
    662         eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
    663         eval_metric_fn_tensors = [
    664             eval_metric_fn_tensors[i] for i in eval_metric_fn_args
    665         ]
    667       captured_eval_metric_fn.capture(eval_metric_fn)
    669       return tuple([estimator_spec.loss] + eval_metric_fn_tensors)
    671     return eval_step, captured_eval_metric_fn, captured_scaffold_fn
    673   def _call_model_fn(self, features, labels, mode, params):
    674     """Calls the model_fn with required parameters."""
    675     model_fn_args = function_utils.fn_args(self._model_fn)
    676     kwargs = {}
    678     if 'labels' in model_fn_args:
    679       kwargs['labels'] = labels
    680     elif labels is not None:
    681       raise ValueError(
    682           'model_fn does not take labels, but input_fn returns labels.')
    683     if 'mode' in model_fn_args:
    684       kwargs['mode'] = mode
    686     if 'params' in model_fn_args:
    687       kwargs['params'] = params
    689     return self._verify_estimator_spec(
    690         self._model_fn(features=features, **kwargs))
    692   def _verify_estimator_spec(self, estimator_spec):
    693     """Verifies estimator spec contains correct data."""
    694     # TODO(ycao): Implement estimator spec verification for other modes.
    696     try:
    697       if estimator_spec.scaffold:
    698         logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
    699                         '. Please use TPUEstimatorSpec.scaffold_fn instead.')
    700     except AttributeError:
    701       pass
    703     try:
    704       if estimator_spec.eval_metric_ops:
    705         raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
    706                          'XLA compilation. Please use '
    707                          'TPUEstimatorSpec.eval_metrics instead.')
    708     except AttributeError:
    709       pass
    711     if estimator_spec.mode == model_fn_lib.ModeKeys.EVAL:
    712       # If estimator_spec is of type TPUEstimatorSpec and contains eval_metrics,
    713       # check that eval_metrics contains eval_metric_fn and
    714       # eval_metric_fn_tensors with matching arguments.
    715       try:
    716         eval_metrics = estimator_spec.eval_metrics
    717       except AttributeError:
    718         eval_metrics = None
    720       if eval_metrics:
    721         (eval_metric_fn, eval_metric_fn_tensors) = eval_metrics
    722         eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
    724         if isinstance(eval_metric_fn_tensors, dict):
    725           missing_tensors = [
    726               i for i in eval_metric_fn_args if i not in eval_metric_fn_tensors
    727           ]
    728           additional_tensors = [
    729               i for i in eval_metric_fn_tensors if i not in eval_metric_fn_args
    730           ]
    732           if missing_tensors:
    733             raise ValueError('Arguments %s are needed by metric_fn (first '
    734                              'element of TPUEstimatorSpec.eval_metrics) but '
    735                              'they are not provided by evaluation tensors '
    736                              '(second element of TPUEstimatorSpec.eval_metrics)'
    737                              '.' % missing_tensors)
    739           if additional_tensors:
    740             raise ValueError('Arguments %s are provided by evaluation tensors '
    741                              '(second element of TPUEstimatorSpec.eval_metrics)'
    742                              ' but they are not needed by metric_fn (first '
    743                              'element of TPUEstimatorSpec.eval_metrics).' %
    744                              additional_tensors)
    746     return estimator_spec
    749 def estimator_model_fn(target_model_fn=None):
    750   """estimator_model_fn decorates a model_fn to be compiled for execution.
    752   Currently it only works with `TPUEstimator`. If you need to use it with base
    753   `Estimator`, please add `tf.enable_resource_variables()` at the beginning of
    754   your program.
    756   Example 1, decorating model_fn:
    757   ```
    758   @xla.estimator_model_fn()
    759   def model_fn(features, labels, mode, params):
    760     ...
    761     return EstimatorSpec(...)
    764   est = Estimator(model_fn=model_fn, ...)
    765   est.train(...)
    767   ```
    769   Example 2, decorator as function:
    770   ```
    771   def model_fn(features, labels, mode, params):
    772     ...
    773     return EstimatorSpec(...)
    775   est = Estimator(model_fn=xla.estimator_model_fn(model_fn), ...)
    776   est.train(...)
    777   ```
    779   Args:
    780     target_model_fn: model_fn to be decorated. This is only needed when
    781       decorator is used in function call form (example 2).
    783   Returns:
    784     Decorated target_model_fn.
    785   """
    787   def decorated(function):
    788     return tf_decorator.make_decorator(function, _ModelFnWrapper(function))
    790   return decorated(target_model_fn) if target_model_fn else decorated
    793 def check_function_argument_count(func, input_arity, infeed_queue):
    794   """Validate the number of input arguments to an XLA function.
    796   Args:
    797     func: the Python function that will be called to generate the body of an XLA
    798       computation graph.
    799     input_arity: the number of explicit arguments supplied by the caller.
    800     infeed_queue: if not None, the infeed queue that will supply
    801       additional arguments to the function.
    803   Returns:
    804     None if function can be called with the supplied number of
    805       arguments, or an error string if it cannot.
    806   """
    807   def format_error(complaint, quantity):
    808     return '%s %d argument%s' % (complaint, quantity, ''
    809                                  if quantity == 1 else 's')
    811   num_args_supplied = input_arity
    812   if infeed_queue is not None:
    813     num_args_supplied += infeed_queue.number_of_tuple_elements
    814   arg_spec = tf_inspect.getargspec(func)
    815   num_func_args = len(arg_spec.args)
    816   if arg_spec.defaults is None:
    817     num_func_defaults = 0
    818   else:
    819     num_func_defaults = len(arg_spec.defaults)
    820   min_func_args = num_func_args - num_func_defaults
    821   if num_args_supplied < min_func_args:
    822     # The required number of arguments is not enough to call the function.
    823     if num_func_defaults == 0 and arg_spec.varargs is None:
    824       return format_error('exactly', num_func_args)
    825     else:
    826       return format_error('at least', min_func_args)
    827   if arg_spec.varargs is None and num_args_supplied > num_func_args:
    828     # The required number of arguments is too many to call the function.
    829     if num_func_defaults == 0:
    830       return format_error('exactly', num_func_args)
    831     else:
    832       return format_error('at most', num_func_args)
    833   # Reaching here means either
    834   # 1) There are varargs, func can accept any number of arguments greater than
    835   # the minimum.
    836   # 2) Number of supplied arguments falls in range of acceptable argument count
    837   # of func.
    838   return None