      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 # pylint: disable=unidiomatic-typecheck
     16 """Defun decorator for defining graph-mode functions."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import collections
     23 import contextlib
     24 import threading
     25 
     26 import numpy as np
     27 
     28 from tensorflow.core.framework import function_pb2
     29 from tensorflow.python import pywrap_tensorflow
     30 from tensorflow.python.eager import context
     31 from tensorflow.python.eager import execute
     32 from tensorflow.python.eager import tape
     33 from tensorflow.python.eager.graph_only_ops import graph_placeholder
     34 from tensorflow.python.framework import c_api_util
     35 from tensorflow.python.framework import constant_op
     36 from tensorflow.python.framework import dtypes as dtypes_module
     37 from tensorflow.python.framework import errors
     38 from tensorflow.python.framework import ops
     39 from tensorflow.python.ops import control_flow_ops
     40 from tensorflow.python.ops import gradients_impl
     41 from tensorflow.python.util import compat
     42 from tensorflow.python.util import nest
     43 from tensorflow.python.util import tf_decorator
     44 
     45 # Thread-local storage for tfe Tensors which are referenced while evaluating a
     46 # graph-mode function.
     47 _scoped_captures = threading.local()
     48 # _scoped_captures.tensors is either None or a map from Tensor id to a pair
     49 # of a tfe tensor and its corresponding placeholder to pass as a function
     50 # argument. The value should be None unless we're in function definition
     51 # context.
     52 _scoped_captures.tensors = None
     53 
     54 
     55 @contextlib.contextmanager
     56 def capture_tensors(captures):
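  """Temporarily sets the thread-local capture map to `captures`.

  A sketch of how this is used while tracing a function (mirroring
  `_defun_internal` below):

    with capture_tensors(captures):
      func_outputs = func(*func_inputs, **kwds)

  Any external tensor touched by `func` inside the context is recorded in
  `captures` and exchanged for a placeholder argument (see `capture_value`).
  """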
     57   old = _scoped_captures.__dict__.get("tensors", None)
     58   try:
     59     _scoped_captures.tensors = captures
     60     yield
     61   finally:
     62     _scoped_captures.tensors = old
     63 
     64 
     65 def capture_value(tensor_map, value, dtype, name):
     66   """Capture a value from outside the function, to pass in as an extra arg."""
     67   captured_value = tensor_map.get(ops.tensor_id(value), None)
     68   if captured_value is None:
     69     captured_value = graph_placeholder(
     70         dtype=dtype or value.dtype, shape=value.shape, name=name)
     71     if captured_value.dtype == dtypes_module.resource:
     72       handle_data = value._handle_data  # pylint: disable=protected-access
     73       captured_value._handle_data = handle_data  # pylint: disable=protected-access
     74       if handle_data is not None and handle_data.is_set:
     75         # Ensure that shapes and dtypes are propagated.
     76         shapes, types = zip(*[(pair.shape, pair.dtype)
     77                               for pair in handle_data.shape_and_type])
     78         ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
     79         shapes = [[d.size for d in s.dim]
     80                   if not s.unknown_rank else None for s in shapes]
     81         with errors.raise_exception_on_not_ok_status() as status:
     82           pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
     83               captured_value._op._graph._c_graph,  # pylint: disable=protected-access
     84               captured_value._as_tf_output(),  # pylint: disable=protected-access
     85               shapes,
     86               ranks,
     87               types,
     88               status)
     89 
     90     tensor_map[ops.tensor_id(value)] = (value, captured_value)
     91   else:
     92     captured_value = captured_value[1]
     93   tape.record_operation("captured_value", [captured_value], [value],
     94                         lambda x: [x])
     95   return captured_value
     96 
     97 
     98 def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
     99   """Captures a Tensor while building a graph mode function.
    100 
    101   Arguments:
    102     value: A Tensor object.
    103     dtype: The datatype of the value produced by the node in the graph.
    104     name:  str, Name of the node in the graph.
    105     as_ref: Ignored (required by register_tensor_conversion_function).
    106 
    107   Returns:
    108     A constant (the current value of the tensor) if capturing is not
    109     enabled; otherwise a placeholder which will carry the value of the
    110     tensor at runtime.
    111   """
    112   del as_ref  # Unused.
    113 
    114   if context.in_eager_mode():
    115     return value
    116 
    117   default_graph = ops.get_default_graph()
    118   if not default_graph.building_function:
    119     return value
    120 
    121   tensor_map = _scoped_captures.tensors
    122   if tensor_map is None:
    123     # Capturing is not enabled.
    124     if value.dtype == dtypes_module.resource:
    125       return value
    126     return constant_op.constant(value.numpy())
    127   if type(value) == ops.Tensor and value.graph is default_graph:
    128     # The tensor has already been converted and captured. The type check
    129     # is intentional: we are checking that value is a Tensor and not an
    130     # EagerTensor.
    131     return value
    132   return capture_value(tensor_map, value, dtype, name)
    133 
    134 
    135 class CapturingGraph(ops.Graph):
    136   """Graph used when constructing eager functions."""
    137 
    138   def __init__(self, captures):
    139     super(CapturingGraph, self).__init__()
    140     self._building_function = True
    141     self.captures = captures
    142     # Map from resource tensor name to last op (in program order) which uses
    143     # this tensor. Used to enforce that execution order matches program order
    144     # for resource tensors.
    145     self._last_op_using_resource_tensor = {}
    146 
    147   # TODO(apassos) remove once the C API is used by default.
    148   def _use_c_api_hack(self):
    149     return True
    150 
    151   def clear_resource_control_flow_state(self):
    152     self._last_op_using_resource_tensor = {}
    153 
    154   def create_op(
    155       self,
    156       op_type,
    157       inputs,
    158       dtypes,  # pylint: disable=redefined-outer-name
    159       input_types=None,
    160       name=None,
    161       attrs=None,
    162       op_def=None,
    163       compute_shapes=True,
    164       compute_device=True):
    165     # TODO(apassos) control flow probably has to be handled delicately here:
    166     # if a resource is accessed inside a control flow context, we need the
    167     # control dependency to point to something outside the context which is
    168     # guaranteed to happen after the access.
    169     #
    170     # TODO(apassos) this should do some form of alias analysis, as ops which
    171     # forward resources (such as Identity and Switch) can cause serialization
    172     # to fail.
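    # Sketch of the loop below: inputs coming from another graph are captured
    # as placeholders, and each resource input picks up a control dependency on
    # the last op (in program order) that used the same resource, so such ops
    # execute in program order.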
    173     resource_inputs = set()
    174     control_inputs = set()
    175     for i, inp in enumerate(inputs):
    176       if inp.graph is not self:
    177         inputs[i] = capture_value(self.captures, inp, inp.dtype, inp.op.name)
    178       inp = inputs[i]
    179       if inp.dtype == dtypes_module.resource:
    180         if inp.name in self._last_op_using_resource_tensor:
    181           control_inputs.add(self._last_op_using_resource_tensor[inp.name])
    182         resource_inputs.add(inp.name)
    183     with self.control_dependencies(list(control_inputs)):
    184       op = super(CapturingGraph, self).create_op(
    185           op_type, inputs, dtypes, input_types, name, attrs, op_def,
    186           compute_shapes, compute_device)
    187     for name in resource_inputs:
    188       self._last_op_using_resource_tensor[name] = op
    189     return op
    190 
    191 
    192 # TODO(apassos): it'd be really nice if we could scope this registration.
    193 # Note that we register this at a higher priority than ops.Tensor since we want
    194 # to handle subclass specific conversion before a superclass conversion.
    195 ops.register_tensor_conversion_function(
    196     ops.EagerTensor, _convert_to_graph_tensor, priority=-1)
    197 
    198 
    199 class _CapturingContext(object):
    200   """Tracks references to Tensors outside this context while it is active."""
    201 
    202   def __init__(self):
    203     # known_ops are ops which are created while this context is active
    204     self.known_ops = set()
    205 
    206     # captured_tensors are all tensors referenced by ops in this context but
    207     # not produced in it
    208     self.captured_tensors = set()
    209 
    210   def AddOp(self, op):  # pylint: disable=invalid-name
    211     if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
    212       raise ValueError("tfe.defun cannot capture variables created without "
    213                        "using tf.get_variable. Op: %s" % op)
    214     self.known_ops.add(op)
    215     for i in op.inputs:
    216       if i.op not in self.known_ops:
    217         self.captured_tensors.add(i)
    218 
    219   def __enter__(self):
    220     self._g = ops.get_default_graph()
    221     self._old = self._g._get_control_flow_context()  # pylint: disable=protected-access
    222     self._g._set_control_flow_context(self)  # pylint: disable=protected-access
    223 
    224   def __exit__(self, _, __, ___):  # pylint: disable=invalid-name
    225     self._g._set_control_flow_context(self._old)  # pylint: disable=protected-access
    226 
    227 
    228 def _forward_name(n):
    229   """The name of a generated forward defun named n."""
    230   return "__forward_%s_%s" % (n, ops.uid())
    231 
    232 
    233 def _backward_name(n):
    234   """The name of a generated backward defun named n."""
    235   return "__backward_%s_%s" % (n, ops.uid())
    236 
    237 
    238 def _inference_name(n):
    239   """The name of a forward-but-no-gradient defun named n."""
    240   return "__inference_%s_%s" % (n, ops.uid())
    241 
    242 
    243 # TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
    244 # so it doesn't have the definition-generating logic and is just a container for
    245 # an already-defined function.
    246 class _EagerDefinedFunction(object):
    247   """Function object with the interface of tf _DefinedFunction."""
    248 
    249   def __init__(self, name, graph, operations, inputs, outputs):
    250     """Initializes an eager defined function.
    251 
    252     Args:
    253       name: str, the name for the created function.
    254       graph: Graph, the graph containing the operations in the function.
    255       operations: list of Operation; the subset of operations in the graph
    256         which will be in the function.
    257       inputs: the tensors in the graph to be used as inputs to the function.
    258       outputs: the tensors in the graph which will be outputs of the function.
    259     """
    260     with errors.raise_exception_on_not_ok_status() as status:
    261       fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
    262           graph._c_graph,  # pylint: disable=protected-access
    263           compat.as_str(name),
    264           False,
    265           [o._c_op for o in operations],  # pylint: disable=protected-access
    266           [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
    267           [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
    268           [],
    269           None,
    270           compat.as_str(""),
    271           status)
    272     # TODO(apassos) avoid creating a FunctionDef (especially to grab the
    273     # signature, but also in general it's nice not to depend on it).
    274     with c_api_util.tf_buffer() as buffer_:
    275       with errors.raise_exception_on_not_ok_status() as status:
    276         pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
    277       proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    278     function_def = function_pb2.FunctionDef()
    279     function_def.ParseFromString(compat.as_bytes(proto_data))
    280     if context.in_eager_mode():
    281       _register(fn)
    282     self.definition = function_def
    283     self.name = function_def.signature.name
    284     self.signature = function_def.signature
    285     self.grad_func_name = None
    286     self.python_grad_func = None
    287     self._c_func = fn
    288     self._grad_func = None
    289 
    290 
    291 def _map_sequence_obj_to_idx(sequence):
    292   """Maps objs in the sequence from id(obj) to sequence index."""
    293   return {id(x): i for i, x in enumerate(sequence)}
    294 
    295 
    296 def _flatten(sequence):
    297   """A wrapper around `nest.flatten` that also unpacks `IndexedSlices`."""
    298   # TODO(akshayka): Support `SparseTensor` in a similar fashion.
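  # For example, [t, ops.IndexedSlices(values=v, indices=i)] flattens to
  # [t, v, i], with the dense_shape tensor appended as well when it is set.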
    299   flat_sequence = nest.flatten(sequence)
    300   outputs = []
    301   for item in flat_sequence:
    302     if isinstance(item, ops.IndexedSlices):
    303       if item.dense_shape is not None:
    304         outputs.extend([item.values, item.indices, item.dense_shape])
    305       else:
    306         outputs.extend([item.values, item.indices])
    307     else:
    308       outputs.append(item)
    309   return outputs
    310 
    311 
    312 class GraphModeFunction(object):
    313   """Callable object representing a graph-mode function.
    314 
    315   Args:
    316     name: str, the name of the created function.
    317     input_placeholders: list of placeholder values (tensors) to feed when
    318       calling the wrapped function.
    319     extra_inputs: Tensor inputs this function definition closed over which
    320       are passed as extra arguments. These need to be tracked so that
    321       gradients are computed correctly.
    322     graph: the Graph from which the operations will be pulled. Used as
    323       a context when computing gradients.
    324     operations: the subset of Operations in the graph used in the function
    325       definition.
    326     outputs: a flat list of the Tensors in the graph used as outputs of the
    327       function.
    328     func_outputs: a possibly nested Python object which will be returned by
    329       this function. The Tensors in this structure will be replaced by their
    330       corresponding values in outputs.
    331     output_shapes: List of shapes of all tensors in outputs.
    332     variables: (optional) List of variables to watch during function execution.
    333   """
    334 
    335   def __init__(self,
    336                name,
    337                input_placeholders,
    338                extra_inputs,
    339                graph,
    340                operations,
    341                outputs,
    342                func_outputs,
    343                output_shapes,
    344                variables=None):
    345     defined_function = _EagerDefinedFunction(
    346         name, graph, operations, input_placeholders, outputs)
    347     if len(input_placeholders) != len(defined_function.signature.input_arg):
    348       raise ValueError("Internal error: invalid lengths. %s %s" % (
    349           len(input_placeholders), len(defined_function.signature.input_arg)))
    350     self._input_placeholders = input_placeholders
    351     self._extra_inputs = list(extra_inputs)
    352     self._graph = graph
    353     self._backward_function = None
    354     self._func_name = name
    355     self._function_def = defined_function
    356     self._num_outputs = len(defined_function.signature.output_arg)
    357     self._ops = operations
    358     self._func_outputs = func_outputs
    359     self._returns = [func_outputs] if isinstance(
    360         func_outputs, (ops.Tensor, type(None))) else _flatten(func_outputs)
    361     self._output_shapes = output_shapes
    362     self._variables = variables if variables is not None else []
    363 
    364   @property
    365   def variables(self):
    366     return self._variables
    367 
    368   def _construct_backprop_function(self):
    369     """Constructs the backprop function object for this function."""
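    # Sketch of the construction below: take gradients of the outputs with
    # respect to the input placeholders inside a capturing context, then define
    # (a) a forward function which also returns every captured intermediate
    # tensor and (b) a backward GraphModeFunction which consumes the output
    # gradients plus those captured tensors.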
    370     with self._graph.as_default(), context.graph_mode():
    371       c = _CapturingContext()
    372       with c:
    373         filtered_outputs = [x for x in self._returns if x is not None]
    374         self._out_grad_placeholders = [
    375             graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
    376         in_gradients = gradients_impl.gradients(
    377             filtered_outputs,
    378             self._input_placeholders,
    379             grad_ys=self._out_grad_placeholders)
    380 
    381     backward_outputs = tuple(
    382         grad for grad in _flatten(in_gradients) if grad is not None)
    383     output_shapes = tuple(grad.shape for grad in backward_outputs)
    384 
    385     captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
    386     forward_name = _forward_name(self._func_name)
    387     self._forward_fdef = _EagerDefinedFunction(
    388         forward_name, self._graph, self._ops, self._input_placeholders,
    389         filtered_outputs + captures)
    390     all_inputs = self._out_grad_placeholders + captures
    391     # Exclude input ops from the body, as we do not intend to execute these
    392     # operations when the function is executed.
    393     all_ignored_ops = frozenset(x.op for x in all_inputs)
    394     # Enforce a deterministic order of operations in the generated graph. This
    395     # means rerunning the function-defining code will always define the same
    396     # function, which is useful if we serialize this etc.
    397     function_def_ops = tuple(x
    398                              for x in sorted(c.known_ops, key=lambda x: x.name)
    399                              if x not in all_ignored_ops)
    400     bname = _backward_name(self._func_name)
    401     self._backward_function = GraphModeFunction(
    402         bname, all_inputs, [], self._graph, function_def_ops,
    403         backward_outputs, in_gradients, output_shapes)
    404 
    405   def _backprop_call(self, args):
    406     """Calls the wrapped function and records the result on a tape."""
    407     all_args = args + self._extra_inputs
    408     signature = self._forward_fdef.signature
    409     ctx = context.context()
    410     if ctx.in_graph_mode():
    411       g = ops.get_default_graph()
    412       g._add_function(self._forward_fdef)  # pylint: disable=protected-access
    413       op = g.create_op(
    414           signature.name,
    415           [ops.internal_convert_to_tensor(x, ctx=ctx) for x in all_args],
    416           tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
    417           op_def=signature,
    418           name="FunctionCall",
    419           compute_shapes=False)
    420       outputs = op.outputs
    421       outputs = [outputs] if isinstance(
    422           outputs, (ops.Tensor, type(None))) else list(outputs)
    423       for i, s in enumerate(self._output_shapes):
    424         outputs[i].set_shape(s)
    425     else:
    426       outputs = execute.execute(
    427           str(signature.name),
    428           num_outputs=len(signature.output_arg),
    429           inputs=all_args,
    430           attrs=None,
    431           ctx=ctx)
    432     real_outputs = outputs[:len(self._returns)]
    433     side_outputs = outputs[len(self._returns):]
    434 
    435     def backward_function(*args):
    436       return self._backward_function(*(list(args) + side_outputs))  # pylint: disable=not-callable
    437 
    438     tape.record_operation(
    439         signature.name,
    440         real_outputs,
    441         (args + self._extra_inputs),
    442         backward_function)
    443 
    444     return self._build_call_outputs(real_outputs)
    445 
    446   @property
    447   def output_shapes(self):
    448     """The function's output shapes."""
    449     # TODO(ebrevdo): Should we only keep the output shapes associated
    450     # with len(self._returns) outputs?
    451     outputs_list = nest.flatten(self._func_outputs)
    452     j = 0
    453     for i, o in enumerate(outputs_list):
    454       if o is not None:
    455         if isinstance(o, ops.IndexedSlices):
    456           # Extract the shape of the `IndexedSlices` object's `values` field.
    457           outputs_list[i] = self._output_shapes[j]  # the `values` shape
    458           if o.dense_shape is not None:
    459             j += 3  # skip over shapes for `values`, `indices`, `dense_shape`
    460           else:
    461             j += 2  # skip over shapes for `values`, `indices`
    462         else:
    463           outputs_list[i] = self._output_shapes[j]
    464           j += 1
    465     return nest.pack_sequence_as(self._func_outputs, outputs_list)
    466 
    467   @property
    468   def output_dtypes(self):
    469     return nest.map_structure(
    470         lambda x: x.dtype if x is not None else None, self._func_outputs)
    471 
    472   @property
    473   def captured_inputs(self):
    474     return self._extra_inputs
    475 
    476   @property
    477   def name(self):
    478     """Returns the name of the function in Eager-compatible format."""
    479     return self._function_def.name.encode("utf-8")
    480 
    481   def add_to_graph(self, g):
    482     if self._function_def.name not in g._functions:  # pylint: disable=protected-access
    483       g._add_function(self._function_def)  # pylint: disable=protected-access
    484     for f in self._graph._functions.values():  # pylint: disable=protected-access
    485       if f.name not in g._functions:  # pylint: disable=protected-access
    486         g._add_function(f)  # pylint: disable=protected-access
    487 
    488   def __call__(self, *args):
    489     """Executes the wrapped function."""
    490     for v in self._variables:
    491       if v._trainable:  # pylint: disable=protected-access
    492         tape.watch_variable(v)
    493 
    494     tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
    495     if tape.should_record(tensor_inputs) or tape.should_record(
    496         self._extra_inputs):
    497       if self._backward_function is None:
    498         self._construct_backprop_function()
    499       return self._backprop_call(tensor_inputs)
    500 
    501     ctx = context.context()
    502     if ctx.in_graph_mode():
    503       g = ops.get_default_graph()
    504       self.add_to_graph(g)
    505       signature = self._function_def.definition.signature
    506       args = list(tensor_inputs) + self._extra_inputs
    507       op = g.create_op(
    508           signature.name,
    509           [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args],
    510           tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
    511           op_def=signature,
    512           name="FunctionCall",
    513           compute_shapes=False)
    514       result = op.outputs
    515       if not result:
    516         return op
    517       for i, s in enumerate(self._output_shapes):
    518         result[i].set_shape(s)
    519     else:
    520       result = execute.execute(
    521           str(self._func_name),
    522           num_outputs=self._num_outputs,
    523           inputs=tensor_inputs + self._extra_inputs,
    524           attrs=None,
    525           ctx=ctx)
    526 
    527     return self._build_call_outputs(result)
    528 
    529   def _build_call_outputs(self, result):
    530     """Maps the fdef output list to actual output structure.
    531 
    532     Args:
    533       result: Output lists defined by FunctionDef.
    534     Returns:
    535       The actual call output.
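      For example, if `self._func_outputs` is `(t0, [t1, t2])`, then `result`
      is the flat list `[r0, r1, r2]` and the call returns `(r0, [r1, r2])`.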
    536     """
    537     if self._func_outputs is None:
    538       return None
    539     # Use `nest.flatten` instead of `_flatten` in order to preserve any
    540     # IndexedSlices in `self._func_outputs`.
    541     outputs_list = nest.flatten(self._func_outputs)
    542     j = 0
    543     for i, o in enumerate(outputs_list):
    544       if o is not None:
    545         if isinstance(o, ops.IndexedSlices):
    546           # Repack Tensors for IndexedSlices.
    547           if o.dense_shape is not None:
    548             outputs_list[i] = ops.IndexedSlices(
    549                 values=result[j],
    550                 indices=result[j + 1],
    551                 dense_shape=result[j + 2])
    552             j += 3
    553           else:
    554             outputs_list[i] = ops.IndexedSlices(
    555                 values=result[j],
    556                 indices=result[j + 1])
    557             j += 2
    558         else:
    559           outputs_list[i] = result[j]
    560           j += 1
    561     ret = nest.pack_sequence_as(self._func_outputs, outputs_list)
    562     return ret
    563 
    564 
    565 def _get_defun_inputs(args):
    566   """Maps the input args to graph inputs."""
    567   ret = []
    568   flat_args = nest.flatten(args)
    569   for a in flat_args:
    570     if isinstance(a, ops.Tensor):
    571       ret.append(graph_placeholder(a.dtype, a.shape))
    572     else:
    573       ret.append(a)
    574   return nest.pack_sequence_as(args, ret)
    575 
    576 
    577 def _defun_internal(name, func, args, kwds):
    578   """Defines and returns a graph-mode version of func."""
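  # Overview of the steps below: (1) build a CapturingGraph that inherits the
  # current graph's key and collections, (2) trace `func` on placeholder inputs
  # while recording captured external tensors, (3) expose those captures as
  # extra placeholder arguments, and (4) wrap everything in a GraphModeFunction.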
    579   graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    580   with context.graph_mode():
    581     captures = {}
    582     tmp_graph = CapturingGraph(captures)
    583     # Inherit the graph key, since this is used for matching variables in
    584     # optimizers.
    585     tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
    586     # Copy the graph collections to ensure summaries and other things work. This
    587     # lets the function access (but not mutate) collections of the containing
    588     # graph, such as the global step and the summary writer collections.
    589     curr_graph = ops.get_default_graph()
    590     for collection in curr_graph.collections:
    591       tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
    592           collection)
    593     with tmp_graph.as_default():
    594       func_inputs = _get_defun_inputs(args)
    595 
    596       def convert(x):
    597         if x is None:
    598           return None
    599         return ops.convert_to_tensor_or_indexed_slices(x)
    600 
    601       with capture_tensors(captures):
    602         this_tape = tape.push_new_tape()
    603         try:
    604           func_outputs = func(*func_inputs, **kwds)
    605           func_outputs = nest.map_structure(convert, func_outputs)
    606         finally:
    607           tape.pop_tape(this_tape)
    608         variables = this_tape.watched_variables()
    609 
    610         # Returning a closed-over tensor as an output does not trigger a
    611         # call to convert_to_tensor, so we manually capture all such tensors.
    612         outputs_list = _flatten(func_outputs)
    613         func_def_outputs = [
    614             _convert_to_graph_tensor(x) for x in outputs_list if x is not None
    615         ]
    616 
    617       ids = list(sorted(captures.keys()))
    618       if ids:
    619         extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
    620       else:
    621         extra_inputs = []
    622         extra_placeholders = []
    623       output_shapes = tuple(
    624           x.shape if isinstance(x, ops.Tensor) else None
    625           for x in outputs_list)
    626 
    627   flat_inputs = [x for x in nest.flatten(func_inputs)
    628                  if isinstance(x, ops.Tensor)]
    629   all_inputs = flat_inputs + list(extra_placeholders)
    630   all_ignored_ops = frozenset(x.op for x in all_inputs)
    631   fname = _inference_name(name)
    632   operations = tuple(x for x in tmp_graph.get_operations()
    633                      if x not in all_ignored_ops)
    634   # Register any other functions defined in the graph
    635   # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    636   if context.in_eager_mode():
    637     for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    638       # TODO(ashankar): What about the gradient registry?
    639       _register(f._c_func)  # pylint: disable=protected-access
    640   return GraphModeFunction(
    641       fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
    642       func_outputs, output_shapes, variables)
    643 
    644 
    645 # Defun uses this instead of Tensor as a cache key. We use dtypes because
    646 # TensorFlow graphs are not parametric w.r.t. dtypes, and shapes for
    647 # performance reasons, as much TensorFlow code specializes on known shapes to
    648 # produce slimmer graphs.
    649 _TensorDtype = collections.namedtuple("_TensorDtype", ["dtype", "shape"])
    650 _ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"])
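# For example (illustrative), a float32 Tensor with static shape [2, 3] is
# keyed as _TensorDtype(dtype=tf.float32, shape=(2, 3)), so calls whose tensor
# arguments agree in dtype and shape reuse the same compiled function.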
    651 
    652 
    653 def _cache_key(x):
    654   """Cache key for tfe functions."""
    655   if isinstance(x, ops.Tensor):
    656     return _TensorDtype(x.dtype, x._shape_tuple())  # pylint: disable=protected-access
    657   if isinstance(x, ops.IndexedSlices):
    658     if x.dense_shape is not None:
    659       return tuple([
    660           _TensorDtype(x.values.dtype, x.values._shape_tuple()),  # pylint: disable=protected-access
    661           _TensorDtype(x.indices.dtype, x.indices._shape_tuple()),  # pylint: disable=protected-access
    662           _TensorDtype(x.dense_shape.dtype, x.dense_shape._shape_tuple())  # pylint: disable=protected-access
    663       ])
    664     else:
    665       return tuple([
    666           _TensorDtype(x.values.dtype, x.values._shape_tuple()),  # pylint: disable=protected-access
    667           _TensorDtype(x.indices.dtype, x.indices._shape_tuple())  # pylint: disable=protected-access
    668       ])
    669   if isinstance(x, np.ndarray):
    670     return ("array", x.shape, tuple(x.reshape(-1)))
    671   if isinstance(x, (list, tuple)):
    672     return tuple([_cache_key(a) for a in x])
    673   if isinstance(x, dict):
    674     return tuple(tuple([_cache_key(k), _cache_key(v)]) for k, v in x.items())
    675   return x
    676 
    677 
    678 def _register(fn):
    679   """Registers the function `fn`."""
    680   context.context().add_function(fn)
    681 
    682 
    683 # TODO(apassos): better error messages for non-hashable arguments.
    684 def named_defun(func, name):
    685   """Defines a function with a given name.
    686 
    687   See the documentation for `defun` for more information on the semantics of the
    688   function.
    689 
    690   Args:
    691     func: the function to be wrapped.
    692     name: the name given to it.
    693 
    694   Returns:
    695     the wrapped function.
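
  A minimal usage sketch (assuming two eager tensors `x` and `y`):

  ```python
  add = named_defun(lambda a, b: a + b, "add")
  result = add(x, y)  # Traces on the first call, then reuses the cached graph.
  ```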
    696   """
    697   arguments_to_functions = {}
    698 
    699   def decorated(*args, **kwds):
    700     """Decorated version of func."""
    701     # Macroexpand on non-Tensor arguments
    702     cache_key = tuple(_cache_key(x) for x in args)
    703     if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
    704       raise ValueError("Tensor keyword arguments are not supported.")
    705     cache_key = (cache_key, tuple(kwds.items()))
    706 
    707     if cache_key not in arguments_to_functions:
    708       arguments_to_functions[cache_key] = _defun_internal(
    709           name, func, args, kwds)
    710     return arguments_to_functions[cache_key](*args)
    711 
    712   return decorated
    713 
    714 
    715 def defun(func):
    716   """Decorator to compile func into graph mode.
    717 
    718   `defun` converts a function that constructs a TensorFlow graph into a function
    719   that executes the graph. TensorFlow graphs typically execute faster and with a
    720   lower memory footprint than executing each of the operations that make up the
    721   function individually, as the TensorFlow runtime can optimize the graph and
    722   execute sub-operations in parallel.
    723 
    724   func must be a Python function that constructs a TensorFlow graph,
    725   typically using functions in the tensorflow module.
    726 
    727   Arguments to func can be either Tensor objects or Python
    728   objects. Non-Tensor Python objects are treated as constants, and new function
    729   definitions are created internally based on their values.
    730 
    731   func must return a tf.Tensor (NOT an eager Tensor) or a list of tf.Tensor
    732   objects (NOT eager Tensors).
    733 
    734   Control flow constructs (e.g., `if`, `while`) are not yet compatible with
    735   `defun`.
    736 
    737   Example:
    738   ```python
    739   def f(x, y):
    740     return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
    741 
    742   @tfe.defun
    743   def g(x, y):
    744     return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
    745 
    746   x = tf.constant([[2.0, 3.0]])
    747   y = tf.constant([[3.0, -2.0]])
    748   # The plain function and defun-compiled function should return the same value.
    749   assert f(x, y).numpy() == g(x, y).numpy()
    750 
    751   # After the first invocation, the defun-compiled (graph) function runs faster
    752   # than the plain function because the defun-compiled function does not involve
    753   # Python interpreter overhead during the execution.
    754   %time print(f(x, y))
    755   %time print(g(x, y))
    756   ```
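
  Because non-Tensor Python arguments are treated as constants, each distinct
  Python value defines a separate graph function. A sketch:

  ```python
  @tfe.defun
  def scale(x, factor):  # `factor` is a plain Python number, not a Tensor.
    return x * factor

  scale(x, 2.0)  # Defines and executes one graph function.
  scale(x, 3.0)  # A different constant value, so a second function is defined.
  ```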
    757 
    758   Args:
    759     func: function to be compiled.
    760 
    761   Returns:
    762      A callable that will execute the compiled function (and return zero
    763      or more Tensor objects).
    764   """
    765   # TODO(apassos): deal with captured global state. Deal with control flow.
    766   try:
    767     name = func.__name__
    768   except AttributeError:
    769     name = "function"
    770   return tf_decorator.make_decorator(func, named_defun(func, name))
    771 
    772 
    773 def make_defun_op(func, *args, **kwds):
    774   """Compile func into graph mode, assuming func arguments are *args, **kwds.
    775 
    776   `make_defun_op` converts a function that constructs a TensorFlow graph into
    777   a function object and attaches it to the graph.  The resulting function
    778   object can be queried for its properties, and called directly with different
    779   inputs to execute.
    780 
    781   More details on use cases and limitations are available in the
    782   documentation for `defun`.
    783 
    784   Example:
    785   ```python
    786   def f(x, y):
    787     return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
    788 
    789   def g(x, y):
    790     return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
    791 
    792   z = tf.constant([[0.0, 0.0]])
    793   g_op = make_defun_op(g, z, z)
    794 
    795   assert g_op.output_shapes == tf.TensorShape([])
    796   assert g_op.output_dtypes == tf.float32
    797 
    798   x = tf.constant([[2.0, 3.0]])
    799   y = tf.constant([[3.0, -2.0]])
    800 
    801   # The plain function and defun-compiled function should return the same value.
    802   assert f(x, y).numpy() == g_op(x, y).numpy()
    803   ```
    804 
    805   Args:
    806     func: function to be compiled.
    807     *args: List arguments to pass to `func` when attaching to the graph.
    808     **kwds: Keyword arguments to pass to `func` when attaching to the graph.
    809 
    810   Returns:
    811      A wrapper object which can be queried for its output properties,
    812      and which can be called directly the way a `@defun` wrapped function
    813      can.
    814 
    815   Raises:
    816     ValueError: if any of the keyword arguments to `func` are `EagerTensor`
    817       objects (not yet supported).
    818   """
    819   name = func.__name__
    820   if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
    821     raise ValueError("Tensor keyword arguments are not supported.")
    822   return _defun_internal(name, func, args, kwds)
    823 
    824 
    825 class AutomaticControlDependencies(object):
    826   """Context manager to automatically add control dependencies.
    827 
    828   Code under this context manager will act as if a sensible set of control
    829   dependencies were present. More specifically:
    830     1. All stateful ops in the scope will execute
    831     2. Stateful ops which modify the same resource will execute in program order
    832 
    833   Note: creating variables in an automatic control dependencies context is not
    834   supported (the value of the variables will never change as they will keep
    835   getting reinitialized).
    836 
    837   NOT THREAD SAFE
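
  A minimal sketch of manual use (assuming `v` is an existing resource variable
  and ops are being added to a graph):

  ```python
  with AutomaticControlDependencies() as a:
    v.assign(1.0)
    v.assign_add(2.0)
    out = v.read_value()
    a.mark_as_return(out)
  # Running `out` now also runs both assignments, in program order.
  ```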
    838   """
    839 
    840   def __init__(self):
    841     self._returned_tensors = set()
    842 
    843   def mark_as_return(self, tensor):
    844     self._returned_tensors.add(tensor)
    845 
    846   def __enter__(self):
    847     if context.in_eager_mode():
    848       return self
    849     # This code assumes no other thread is adding ops to the graph while
    850     # we're adding ops to the graph.
    851     # TODO(apassos): Fix this by locking the graph or using a temporary
    852     # graph (but that would mess up devices and collections at least,
    853     # probably other things as well).
    854     self._graph = ops.get_default_graph()
    855     self._n_operations = len(self._graph.get_operations())
    856     return self
    857 
    858   def _process_switch(self, switch_op, ops_which_must_run,
    859                       last_op_using_resource_tensor, merge_for_resource):
    860     """Processes a switch node for a resource input.
    861 
    862     When TensorFlow creates a cond, it creates a control flow context for each
    863     branch of the cond. Each external tensor accessed by that branch is routed
    864     through a switch op, which gets created in the graph _after_ the op which
    865     uses that tensor gets created.
    866 
    867     If the resource comes from another switch op we process that one first.
    868 
    869     _process_switch creates a corresponding merge node for the switch node. This
    870     merge node is added to the outer control flow context of the switch
    871     node. We also ensure that:
    872 
    873       1. The switch node executes after the previous op which used the resource
    874          tensor
    875 
    876       2. Any op which uses a resource output of the switch node executes before
    877          the merge for the switch node.
    878 
    879       3. The next op which uses the input resource to the switch node (which
    880          might be another switch node for the other branch of the conditional)
    881          will execute after the merge node is done.
    882 
    883       4. The merge node is marked as must_run so it will run even if no
    884          subsequent operation uses the resource.
    885 
    886     Args:
    887       switch_op: the switch op to be processed
    888       ops_which_must_run: the set of ops which must run
    889       last_op_using_resource_tensor: map from resource tensor to last op using
    890         it
    891       merge_for_resource: map from resource tensor to merge which must follow
    892         all usages of it.
    893     """
    894     inp = switch_op.inputs[0]
    895     if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
    896       self._process_switch(inp.op, ops_which_must_run,
    897                            last_op_using_resource_tensor, merge_for_resource)
    898     if switch_op.outputs[0] in merge_for_resource:
    899       return
    900     new_merge = control_flow_ops.merge(switch_op.outputs,
    901                                        name="artificial_merge")
    902     new_merge[0].op._control_flow_context = (  # pylint: disable=protected-access
    903         switch_op._control_flow_context.outer_context)  # pylint: disable=protected-access
    904     # Ensures the merge always runs
    905     ops_which_must_run.add(new_merge[0].op)
    906     if inp in last_op_using_resource_tensor:
    907       # Ensures the switch executes after the previous op using the resource.
    908       switch_op._add_control_input(last_op_using_resource_tensor[inp])  # pylint: disable=protected-access
    909     # Ensure the next op outside the cond happens after the merge.
    910     last_op_using_resource_tensor[inp] = new_merge[0].op
    911     if inp in merge_for_resource:
    912       merge_for_resource[inp]._add_control_input(new_merge[0].op)  # pylint: disable=protected-access
    913     for o in switch_op.outputs:
    914       # Ensures the merge will execute after all ops inside the cond
    915       merge_for_resource[o] = new_merge[0].op
    916 
    917   def __exit__(self, unused_type, unused_value, unused_traceback):
    918     if context.in_eager_mode():
    919       return
    920 
    921     if self._graph is not ops.get_default_graph():
    922       raise RuntimeError(
    923           "Graph changed while trying to add control dependencies.")
    924 
    925     # map from resource tensor to the last op which used it
    926     last_op_using_resource_tensor = {}
    927     # set of conditional and loop exits
    928     ops_which_must_run = set()
    929     # merge which must depend on ops which use this resource
    930     merge_for_resource = {}
    931 
    932     new_operations = self._graph.get_operations()[self._n_operations:]
    933 
    934     # Ensures that uses of resource tensors get serialized properly and all
    935     # execute. This is done by keeping a map from resource tensor to the last op
    936     # in graph-construction order which used it (last_op_using_resource_tensor).
    937     #
    938     # Conditionals are written in TensorFlow such that every external tensor
    939     # accessed in the conditional goes through a switch op and every return
    940     # tensor (it's guaranteed that there will be at least one) goes through a
    941     # merge op.
    942     #
    943     # To handle conditionals, switches are handled in a special way (see
    944     # comments for _process_switch). Merge nodes created by TF's conditional
    945     # logic (as opposed to by _process_switch) are forced to run and also get a
    946     # control dependency added to them to ensure all stateful ops inside their
    947     # control flow context run.
    948     #
    949     # We also ensure that if an op is using a resource output by a switch node
    950     # (that is, a resource tensor for which there's a value in
    951     # merge_for_resource) this op will run before the merge for that resource.
    952     #
    953     # We try to add control inputs to nodes respecting their control flow
    954     # contexts to avoid dead nodes propagating everywhere and leading to
    955     # "retval[0] doesn't have value" errors. If a node gets a control dependency
    956     # on a dead node (i.e. a node from an untaken control flow branch), that node
    957     # will be marked as dead unless it's a merge node.
    958     #
    959     # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    960     # test that it works. Support while loops. Support init_scope escaping from
    961     # this.
    962     for op in new_operations:
    963       control_inputs = set()
    964       # Ensure stateful ops run
    965       if self._graph._registered_ops[op.type].is_stateful:  # pylint: disable=protected-access
    966         ops_which_must_run.add(op)
    967       # Ignore switches (they're handled separately)
    968       if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
    969         continue
    970       # Make merges trigger all other computation which must run
    971       if op.type == "Merge":
    972         for o in ops_which_must_run:
    973           op._add_control_input(o)  # pylint: disable=protected-access
    974           for inp in o.inputs:
    975             if inp in last_op_using_resource_tensor:
    976               last_op_using_resource_tensor[inp] = op
    977         ops_which_must_run = set([op])
    978         continue
    979       for inp in op.inputs:
    980         if inp.dtype == dtypes_module.resource:
    981           # Deal with switches, finally.
    982           if inp.op.type == "Switch":
    983             self._process_switch(inp.op, ops_which_must_run,
    984                                  last_op_using_resource_tensor,
    985                                  merge_for_resource)
    986           # Ensure uses of resources are serialized
    987           if inp in last_op_using_resource_tensor:
    988             if (last_op_using_resource_tensor[inp]._control_flow_context  # pylint: disable=protected-access
    989                 is op._control_flow_context):  # pylint: disable=protected-access
    990               control_inputs.add(last_op_using_resource_tensor[inp])
    991           # Ensure merges happen after the closing of a cond block
    992           if inp in merge_for_resource:
    993             merge_for_resource[inp]._add_control_input(op)  # pylint: disable=protected-access
    994           last_op_using_resource_tensor[inp] = op
    995       control_inputs = [c for c in control_inputs
    996                         if c._control_flow_context is op._control_flow_context]  # pylint: disable=protected-access
    997       op._add_control_inputs(control_inputs)  # pylint: disable=protected-access
    998 
    999     # Ensure all ops which must run do run
   1000     for r in self._returned_tensors:
   1001       r.op._add_control_inputs(  # pylint: disable=protected-access
   1002           [o for o in ops_which_must_run
   1003            if o._control_flow_context is r.op._control_flow_context])  # pylint: disable=protected-access
   1004 
   1005 
   1006 def automatic_control_dependencies(f):
   1007   """Wraps f to automatically insert control dependencies.
   1008 
   1009   The inserted dependencies ensure that:
   1010     1. All stateful ops in f run when the result of f runs
   1011     2. Updates to the same resources happen in order.
   1012 
   1013   Args:
   1014     f: the function to be wrapped.
   1015 
   1016   Returns:
   1017     The wrapped function.
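
  A minimal sketch (assuming `v` is an existing resource variable and a graph
  is being built):

  ```python
  @automatic_control_dependencies
  def bump(amount):
    v.assign_add(amount)
    return v.read_value()

  # The returned read depends on the assign_add, so evaluating it always
  # reflects the update.
  ```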
   1018   """
   1019 
   1020   def wrapper(*args, **kwds):
   1021     with AutomaticControlDependencies() as a:
   1022       result = f(*args, **kwds)
   1023       for t in nest.flatten(result):
   1024         a.mark_as_return(t)
   1025       return result
   1026 
   1027   return tf_decorator.make_decorator(f, wrapper)
   1028