Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # =============================================================================
     15 """Python front-end supports for functions.
     16 
     17 NOTE: functions are currently experimental and subject to change!
     18 """
     19 
     20 from __future__ import absolute_import
     21 from __future__ import division
     22 from __future__ import print_function
     23 
     24 import collections
     25 import hashlib
     26 
     27 from tensorflow.core.framework import attr_value_pb2
     28 from tensorflow.core.framework import function_pb2
     29 from tensorflow.python import pywrap_tensorflow as c_api
     30 from tensorflow.python.eager import context
     31 from tensorflow.python.framework import c_api_util
     32 from tensorflow.python.framework import dtypes
     33 from tensorflow.python.framework import errors
     34 from tensorflow.python.framework import graph_to_function_def
     35 from tensorflow.python.framework import ops
     36 from tensorflow.python.ops import array_ops
     37 from tensorflow.python.ops import resource_variable_ops
     38 from tensorflow.python.ops import variable_scope as vs
     39 from tensorflow.python.util import compat
     40 from tensorflow.python.util import tf_decorator
     41 from tensorflow.python.util import tf_inspect
     42 
     43 
     44 class Defun(object):
     45   """Decorator used to define TensorFlow functions.
     46 
     47   Use this decorator to make a Python function usable directly as a TensorFlow
     48   function.
     49 
     50   The decorated function must add ops to the default graph and return zero or
     51   more `Tensor` objects.  Call the decorator with named arguments, one for each
     52   argument of the function to decorate, with the expected type of the argument
     53   as value.
     54 
     55   For example if the function to decorate accepts two `tf.float32` arguments
     56   named `x` and `y`, call the decorator with:
     57 
     58       @Defun(tf.float32, tf.float32)
     59       def foo(x, y):
     60         ...
     61 
     62   When you call the decorated function it will add `call` ops to the
     63   default graph and adds the definition of the function into the
     64   default graph. Because the addition of the function into the graph
     65   is deferred, the decorator can be used anywhere in the program.
     66 
     67   Any variables created inside of the function are hoisted into the outer graph.
     68   Note that the variables are created in the variable scope that was active
     69   during the first call to the function. Subsequent function calls will refer to
     70   the same set of variables.
     71 
     72   Definitions of functions are frozen in a graph as soon as the graph is used to
     73   create a session. Therefore, nodes using the function must be created in the
     74   graph before the corresponding session is created.
     75 
     76   Example, but also see the [How To on functions](link_needed).
     77 
     78   ```python
     79   # Defining the function.
     80   @tf.Defun(tf.float32, tf.float32)
     81   def MyFunc(x, y):
     82     return x + y, x - y
     83 
     84   # Building the graph.
     85   a = tf.constant([1.0])
     86   b = tf.constant([2.0])
     87   c, d = MyFunc(a, b, name='mycall')
     88   ```
     89   """
     90 
     91   def __init__(self, *input_types, **kwargs):
     92     """Create a `Defun` decorator.
     93 
     94     Args:
     95       *input_types: A list of `tf.DType`
     96       **kwargs: Optional keyword arguments, including
     97          func_name - (optional).  A python string, the name to use to
     98            declare this `Function` in the graph.
     99 
    100          grad_func - (optional).  A function implementing the gradient
    101            of the function-to-register.  This is must be a
    102            `_DefinedFunction` object. The gradient
    103            function must satisfy the criterion defined in
    104            function.proto:GradientDef.
    105 
    106          python_grad_func - (optional).  A function implementing the
    107            gradient of the function python-side. This function must
    108            take the current op and the gradients w.r.t. its outputs,
    109            and return the gradients w.r.t. the inputs. That is it must
    110            implement the interface expected by `tf.RegisterGradient`).
    111            This will be called by tf.gradients to add the gradient ops
    112            to the graph. At most one of grad_func and python_grad_func
    113            can be specified.
    114 
    115          out_names = (optional). A list of strings, one per output
    116            tensor.
    117 
    118          shape_func - (optional). A function taking the op and returning a list
    119            of static shapes to set for the function's outputs.
    120     """
    121     self._input_types = input_types
    122     self._func_name = kwargs.pop("func_name", None)
    123     self._grad_func = kwargs.pop("grad_func", None)
    124     self._python_grad_func = kwargs.pop("python_grad_func", None)
    125     self._out_names = kwargs.pop("out_names", None)
    126     self._extra_kwargs = kwargs
    127 
    128   def __call__(self, func):
    129     # Various sanity checks on the callable func.
    130     if not callable(func):
    131       raise ValueError("func %s must be callable" % func)
    132 
    133     # Func should not use kwargs and defaults.
    134     argspec = tf_inspect.getargspec(func)
    135     if argspec.keywords or argspec.defaults:
    136       raise ValueError("Functions with argument defaults or keyword "
    137                        "arguments are not supported.")
    138 
    139     # Computes how many arguments 'func' has.
    140     min_args = len(argspec.args)
    141     max_args = min_args
    142     if argspec.varargs:
    143       max_args = 1000000
    144     argnames = argspec.args
    145     if tf_inspect.ismethod(func):
    146       # 1st argument is the "class" type.
    147       min_args -= 1
    148       argnames = argnames[1:]
    149 
    150     if self._input_types:
    151       # If Defun is given a list of types for the inputs, the number
    152       # of input types should be compatible with 'func'.
    153       num = len(self._input_types)
    154       if num < min_args or num > max_args:
    155         raise ValueError(
    156             "The function has fewer arguments than the number of specified "
    157             "input types.")
    158       return _DefinedFunction(
    159           func,
    160           argnames,
    161           self._input_types,
    162           self._func_name,
    163           self._grad_func,
    164           self._python_grad_func,
    165           out_names=self._out_names,
    166           **self._extra_kwargs)
    167 
    168     # 'func' expects no arguments and input types is an empty list.
    169     if min_args == 0 and max_args == 0:
    170       return _DefinedFunction(
    171           func, [], [],
    172           self._func_name,
    173           self._grad_func,
    174           self._python_grad_func,
    175           out_names=self._out_names,
    176           **self._extra_kwargs)
    177 
    178     # Input types are unknown. It's an overloaded function and hence
    179     # its definition needs to be deferred until it's called.
    180     return _OverloadedFunction(
    181         func,
    182         argnames,
    183         self._func_name,
    184         self._grad_func,
    185         self._python_grad_func,
    186         out_names=self._out_names,
    187         **self._extra_kwargs)
    188 
    189 
    190 class _DefinedFunction(object):
    191   """_DefinedFunction encapsulates a function definition and its properties.
    192 
    193   Attributes:
    194     name: The function name.
    195     definition: The definition of this function. A FunctionDef proto.
    196     grad_func_name: If not None, the name of this function's gradient function.
    197     python_grad_func: A python callable implementing the gradient of
    198       the function python-side.
    199   """
    200 
    201   def __init__(self,
    202                func,
    203                argnames,
    204                input_types,
    205                func_name=None,
    206                grad_func=None,
    207                python_grad_func=None,
    208                out_names=None,
    209                shape_func=None,
    210                capture_by_value=False,
    211                **kwargs):
    212     """Creates _DefinedFunction.
    213 
    214     Args:
    215       func:  A python callable which constructs a tf function body.
    216       argnames: A list of strings for function argument names.
    217       input_types: The function's argument types. Can be a tuple, list of
    218         tf data types.
    219       func_name: The function name. Defaults to None, in which derives from
    220         'func'.
    221       grad_func: This function's gradient function, if not None. Defaults
    222         to None.
    223       python_grad_func: A python callable implementing the gradient of
    224         the function python-side.
    225       out_names: An optional list of strings for the function return value
    226         names.
    227       shape_func: An optional function mapping an op to a list of static
    228         output shapes.
    229       capture_by_value: Boolean (defaults to False). If True, captured values
    230         will be copied into the function body.
    231       **kwargs: The keyword arguments. **kwargs is passed to every call
    232         site of this function.
    233 
    234     Raises:
    235       ValueError: The function definition is invalid.
    236 
    237     """
    238     self._func = func
    239     self._input_types = input_types
    240     self._func_name = func_name
    241     self._grad_func = grad_func
    242     self._python_grad_func = python_grad_func
    243     self._out_names = out_names
    244     self._shape_func = shape_func
    245     self._capture_by_value = capture_by_value
    246     self._extra_kwargs = kwargs
    247     # Constructed only when C API is disabled, lazily
    248     self._definition = None
    249     # Constructed only when C API is enabled, lazily
    250     self._c_func = None
    251     self._sub_functions = dict()  # Constructed with _definition or _c_func
    252 
    253     # Cached OpDef for this function. When C API is enabled, this is
    254     # the only part of FunctionDef that we cache in Python. When C API
    255     # is disabled the whole _definition is available and this is simply
    256     # another reference to _definition.signature
    257     self._op_def = None
    258 
    259     self._args = []
    260     assert isinstance(input_types, (list, tuple))
    261     for i in range(len(input_types)):
    262       argname = argnames[i] if i < len(argnames) else ("arg%d" % i)
    263       argtype = input_types[i]
    264       self._args.append((argname, argtype))
    265 
    266   @property
    267   def name(self):
    268     """Function name."""
    269     self._create_definition_if_needed()
    270     return self._func_name
    271 
    272   @property
    273   def definition(self):
    274     """Function definition proto."""
    275     self._create_definition_if_needed()
    276     if self._c_func:
    277       with c_api_util.tf_buffer() as buf:
    278         with errors.raise_exception_on_not_ok_status() as status:
    279           c_api.TF_FunctionToFunctionDef(self._c_func, buf, status)
    280         fdef = function_pb2.FunctionDef()
    281         proto_data = c_api.TF_GetBuffer(buf)
    282         fdef.ParseFromString(compat.as_bytes(proto_data))
    283       return fdef
    284     return self._definition
    285 
    286   @property
    287   def _signature(self):
    288     self._create_definition_if_needed()
    289     return self._op_def
    290 
    291   def set_grad_func(self, grad_func):
    292     """Specifies the gradient function of this function."""
    293     assert not self._grad_func
    294     assert isinstance(grad_func, _DefinedFunction)
    295     self._grad_func = grad_func
    296 
    297   @property
    298   def grad_func_name(self):
    299     """Its gradient function's name."""
    300     return self._grad_func.name if self._grad_func else None
    301 
    302   @property
    303   def python_grad_func(self):
    304     """Python gradient function callable."""
    305     return self._python_grad_func
    306 
    307   @property
    308   def declared_input_types(self):
    309     """Returns the list of data types of explicit declared inputs."""
    310     return self._input_types
    311 
    312   @property
    313   def captured_inputs(self):
    314     """Returns the list of implicitly captured inputs."""
    315     self._create_definition_if_needed()
    316     return self._extra_inputs
    317 
    318   def _create_definition_if_needed(self):
    319     """Creates the function definition if it's not created yet."""
    320     with context.graph_mode():
    321       self._create_definition_if_needed_impl()
    322 
    323   def _create_definition_if_needed_impl(self):
    324     """This is not what you want, see _create_definition_if_needed."""
    325     if self._definition is not None or self._c_func is not None:
    326       return
    327 
    328     # Create the func_def object.
    329     temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
    330     with temp_graph.as_default():
    331       # List of placeholders for the function_def.
    332       inputs = []
    333       for (argname, argtype) in self._args:
    334         argholder = array_ops.placeholder(argtype, name=argname)
    335         inputs.append(argholder)
    336       # Call func and gather the output tensors.
    337       with vs.variable_scope("", custom_getter=temp_graph.getvar):
    338         outputs = self._func(*inputs)
    339 
    340       # There is no way of distinguishing between a function not returning
    341       # anything and a function returning None in Python.
    342       # We need to allow the former and ideally want to forbid the latter as
    343       # it is most likely user error.
    344       # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    345       # allow users to explicitly mark the function as not returning anything.
    346       # For now, we allow a single None return and interpret it as a function
    347       # with no output.
    348       if outputs is None:
    349         outputs = []
    350       else:
    351         # If func only returned one value, make it a tuple.
    352         if not isinstance(outputs, (list, tuple)):
    353           outputs = (outputs,)
    354         if any([_ is None for _ in outputs]):
    355           raise ValueError("Function can not return None.")
    356       # Ensures each output is a Tensor.
    357       outputs = [ops.convert_to_tensor(_) for _ in outputs]
    358     self._extra_inputs = temp_graph.extra_inputs
    359     inputs.extend(temp_graph.extra_args)
    360     # pylint: disable=protected-access
    361     self._sub_functions = temp_graph._functions
    362     # pylint: enable=protected-access
    363 
    364     # Extra kwargs are treated as attrs on the function def.
    365     base_func_name = self._func_name or _get_func_name(self._func)
    366     kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
    367                                          **self._extra_kwargs)
    368 
    369     if not temp_graph._c_graph:  # pylint: disable=protected-access
    370       # Build the FunctionDef
    371       self._definition = graph_to_function_def.graph_to_function_def(
    372           temp_graph,
    373           temp_graph.get_operations(),
    374           inputs,
    375           outputs,
    376           out_names=self._out_names)
    377 
    378       for k in kwargs_attr:
    379         self._definition.attr[k].CopyFrom(kwargs_attr[k])
    380 
    381       # Hash the definition and its dependencies.
    382       self._hash_str = self._create_hash_str(
    383           self._definition.signature.input_arg,
    384           self._definition.signature.output_arg, self._definition.node_def)
    385 
    386       # Finally, we decide the function name to use.  If not specified,
    387       # make up something which is almost certainly unique (but deterministic).
    388       if not self._func_name:
    389         self._func_name = "_".join([base_func_name, self._hash_str])
    390       self._definition.signature.name = self._func_name
    391       if self._func.__doc__:
    392         self._definition.signature.description = self._func.__doc__
    393 
    394       self._op_def = self._definition.signature
    395     else:  # C API is enabled
    396       output_names = ([compat.as_bytes(x) for x in self._out_names]
    397                       if self._out_names else [])
    398       description = self._func.__doc__ or None
    399       # pylint: disable=protected-access
    400       with errors.raise_exception_on_not_ok_status() as status:
    401         self._c_func = c_api.TF_GraphToFunction_wrapper(
    402             temp_graph._c_graph,
    403             base_func_name,
    404             self._func_name is None,  # append_hash_to_fn_name
    405             None,  # opers
    406             [t._as_tf_output() for t in inputs],
    407             [t._as_tf_output() for t in outputs],
    408             output_names,
    409             None,  # opts
    410             description,
    411             status)
    412       # pylint: enable=protected-access
    413       self._set_c_attrs(kwargs_attr)
    414 
    415       # Set cached fields: _op_def and _func_name (if not already set)
    416       self._op_def = self.definition.signature
    417       if self._func_name:
    418         assert self._func_name == self._op_def.name
    419       else:
    420         self._func_name = compat.as_str(self._op_def.name)
    421 
    422   def _set_c_attrs(self, attrs):
    423     """Sets `attrs` as attributes of self._c_func.
    424 
    425     Requires that self._c_func is not None.
    426 
    427     Args:
    428       attrs: a dictionary from attribute name to attribute proto value
    429     """
    430     for name, attr_value in attrs.items():
    431       serialized = attr_value.SerializeToString()
    432       # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
    433       # It might be worth creating a convenient way to re-use the same status.
    434       with errors.raise_exception_on_not_ok_status() as status:
    435         c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
    436                                            serialized, status)
    437 
    438   def _create_hash_str(self, input_arg, output_arg, node_def):
    439     """Creates an 8-character string unique to this input.
    440 
    441     Args:
    442       input_arg: the input_arg field of an OpDef
    443                  (e.g. self._definition.signature.input_arg)
    444       output_arg: the output_arg field of an OpDef
    445                  (e.g. self._definition.signature.output_arg)
    446       node_def: the node_def field of a FunctionDef
    447                 (e.g. self._definition.node_def)
    448 
    449     Returns:
    450       The unique string for this input
    451     """
    452     hasher = hashlib.sha1()
    453 
    454     def update_num(n):
    455       hasher.update(compat.as_bytes("%x" % n))
    456 
    457     def update_str(s):
    458       update_num(len(s))
    459       hasher.update(compat.as_bytes(s))
    460 
    461     def update_strs(slist):
    462       update_num(len(slist))
    463       for s in slist:
    464         update_str(s)
    465 
    466     for adef in input_arg:
    467       update_str(adef.SerializeToString())
    468 
    469     for adef in output_arg:
    470       update_str(adef.SerializeToString())
    471 
    472     for n in sorted(node_def, key=lambda n: n.name):
    473       update_str(n.name)
    474       update_str(n.op)
    475       update_strs(n.input)
    476       update_num(len(n.attr))
    477       # NOTE: protobuf map serialization does not guarantee ordering.
    478       for k in sorted(n.attr):
    479         update_str(k)
    480         update_str(n.attr[k].SerializeToString())
    481 
    482     return hasher.hexdigest()[:8]
    483 
    484   def add_to_graph(self, g):
    485     """Adds this function into the graph g."""
    486     self._create_definition_if_needed()
    487 
    488     # Adds this function into 'g'.
    489     # pylint: disable=protected-access
    490     if context.in_graph_mode():
    491       g._add_function(self)
    492     else:
    493       context.context().add_function_def(self.definition)
    494     # pylint: enable=protected-access
    495 
    496     # Ensures related sub-routines are defined in 'g', too.
    497     for f in self._sub_functions.values():
    498       f.add_to_graph(g)
    499 
    500     # Adds its gradient function, too.
    501     if self._grad_func:
    502       self._grad_func.add_to_graph(g)
    503 
    504   def __call__(self, *args, **kwargs):
    505     self.add_to_graph(ops.get_default_graph())
    506     args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
    507     ret, op = _call(self._signature, *args, **kwargs)
    508     if self._shape_func is not None:
    509       shapes = self._shape_func(op)
    510       if len(shapes) != len(op.outputs):
    511         raise ValueError("shape_func produced %d shapes for %d outputs" %
    512                          (len(shapes), len(op.outputs)))
    513       for (t, shape) in zip(op.outputs, shapes):
    514         t.set_shape(shape)
    515     return ret
    516 
    517 
    518 class _OverloadedFunction(object):
    519   """_OverloadedFunction encapsulates an overloaded function.
    520 
    521   _OverloadedFunction maintains a mapping from input types to
    522   instantiated _DefinedFunction in self._overload.
    523 
    524   """
    525 
    526   def __init__(self,
    527                func,
    528                argnames,
    529                func_name=None,
    530                grad_func=None,
    531                python_grad_func=None,
    532                out_names=None,
    533                **kwargs):
    534     """Creates _DefinedFunction.
    535 
    536     Args:
    537       func:  A python callable which constructs a tf function body.
    538       argnames: A list of strings for function argument names.
    539       func_name: The function name. Defaults to None, in which derives from
    540         'func'.
    541       grad_func: This function's gradient function, if not None. Defaults
    542         to None.
    543       python_grad_func: A python callable implementing the gradient of
    544         the function python-side.
    545       out_names: A list of strings for the function return value names.
    546       **kwargs: The keyword arguments. **kwargs is passed to every call
    547         site of this function.
    548 
    549     Raises:
    550       ValueError: The function definition is invalid.
    551 
    552     """
    553     self._func = func
    554     self._argnames = argnames
    555     self._func_name = func_name
    556     assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
    557     self._grad_func = grad_func
    558     self._python_grad_func = python_grad_func
    559     self._out_names = out_names
    560     self._extra_kwargs = kwargs
    561     self._overload = {}
    562 
    563   def instantiate(self, input_types):
    564     """Instantiate this function given input argument types.
    565 
    566     Args:
    567       input_types: A list of data types for the inputs.
    568 
    569     Returns:
    570       _DefinedFunction for the given input types.
    571 
    572     """
    573     # Stringify the type list.
    574     key = _type_list_to_str(input_types)
    575     defined = self._overload.get(key)
    576     if not defined:
    577       # If not defined yet, define the function given the input types.
    578       name = self._func_name
    579       if name is not None:
    580         name = "_".join([name, key])
    581       defined = _DefinedFunction(
    582           self._func,
    583           self._argnames,
    584           input_types,
    585           name,
    586           None,
    587           self._python_grad_func,
    588           out_names=self._out_names,
    589           **self._extra_kwargs)
    590       _ = defined.name  # Fully instantiate the function definition.
    591       if self._grad_func:
    592         # If _grad_func is given, it is another
    593         # _OverloadedFunction. We need to instantiate it with the
    594         # right input types.
    595         output_types = [
    596             dtypes.DType(_.type)
    597             for _ in defined._signature.output_arg  # pylint: disable=protected-access
    598         ]
    599         # pylint: disable=protected-access
    600         defined._grad_func = self._grad_func.instantiate(
    601             input_types + output_types)
    602         # pylint: enable=protected-access
    603       self._overload[key] = defined
    604     return defined
    605 
    606   def __call__(self, *args, **kwargs):
    607     input_types = []
    608     args = list(args)
    609     for (i, x) in enumerate(args):
    610       x = ops.convert_to_tensor(x)
    611       if not isinstance(x, ops.Tensor):
    612         raise ValueError("Expect a Tensor but get ", x)
    613       input_types.append(x.dtype)
    614       args[i] = x
    615     return self.instantiate(input_types)(*args, **kwargs)
    616 
    617 
    618 class _FuncGraph(ops.Graph):
    619   """A helper for constructing a function.
    620 
    621   _FuncGraph overrides ops.Graph's create_op() so that we can keep
    622   track of all inputs into every op created inside the function.  If
    623   any input is from other graphs, we keep track of it in self.capture
    624   and substitute the input with a place holder.
    625 
    626   Each captured input's corresponding place holder is converted into a
    627   function argument and the caller passes in the captured tensor.
    628   """
    629 
    630   def __init__(self, capture_by_value, *args, **kwargs):
    631     super(_FuncGraph, self).__init__(*args, **kwargs)
    632     self._capture_by_value = capture_by_value
    633     self._building_function = True
    634     self._outer_graph = ops.get_default_graph()
    635     self._vscope = vs.get_variable_scope()
    636     self._old_custom_getter = self._vscope.custom_getter
    637     self._captured = {}
    638     self.extra_inputs = []
    639     self.extra_args = []
    640     self.extra_vars = []
    641 
    642   def getvar(
    643       self,
    644       getter,
    645       name,
    646       shape=None,
    647       dtype=None,
    648       initializer=None,
    649       reuse=None,
    650       trainable=True,
    651       collections=None,  # pylint: disable=redefined-outer-name
    652       use_resource=None,
    653       **kwargs):
    654     """A custom variable getter."""
    655     # Here, we switch the default graph to the outer graph and ask the
    656     # variable scope in which the function is defined to give us the
    657     # variable. The variable is stashed in extra_vars and returned to
    658     # the caller.
    659     #
    660     # We capture these variables so that the variable definition is
    661     # hoisted upward to the outer most graph.
    662     with self._outer_graph.as_default():
    663       # pylint: disable=protected-access
    664       var = self._vscope.get_variable(
    665           vs._get_default_variable_store(),
    666           name,
    667           shape=shape,
    668           dtype=dtype,
    669           initializer=initializer,
    670           reuse=reuse,
    671           trainable=trainable,
    672           collections=collections,
    673           use_resource=use_resource)
    674       self.extra_vars.append(var)
    675       if isinstance(var, resource_variable_ops.ResourceVariable):
    676         # For resource-based variables read the variable outside the function
    677         # and pass in the value. This ensures that the function is pure and
    678         # differentiable. TODO(apassos) this may have performance problems if
    679         # the function will only do embedding lookups on the variable.
    680         return var.value()
    681       return var
    682 
    683   def create_op(self, op_type, inputs, data_types, **kwargs):
    684     for i, x in enumerate(inputs):
    685       if isinstance(x, ops.EagerTensor) or x.graph is not self:
    686         # Referring to a tensor from other graph.
    687         if x in self._captured:
    688           # Captured already.
    689           inputs[i] = self._captured[x]
    690         elif self._capture_by_value:
    691           inputs[i] = self._add_tensor_and_parents(x)
    692         else:
    693           # Substitute with a placeholder.
    694           self.extra_inputs.append(x)
    695           # Hoist the new input placeholder out of any control flow context
    696           # we're currently in.
    697           with ops.control_dependencies(None):
    698             ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
    699           # pylint: disable=protected-access
    700           ph._handle_data = x._handle_data
    701           # pylint: enable=protected-access
    702           inputs[i] = ph
    703           self._captured[x] = ph
    704           self.extra_args.append(ph)
    705     return super(_FuncGraph, self).create_op(op_type, inputs, data_types,
    706                                              **kwargs)
    707 
    708   def _add_tensor_and_parents(self, tensor):
    709     op = self._add_op_and_parents(tensor.op)
    710     return op.outputs[tensor.value_index]
    711 
    712   def _add_op_and_parents(self, op):
    713     # pylint: disable=protected-access
    714     op_def = graph_to_function_def._get_op_def(op)
    715     # pylint: enable=protected-access
    716     if op_def.is_stateful:
    717       raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
    718                        "by value." % (op.name, op.type))
    719     elif op.type in ("Placeholder", "PlaceholderV2"):
    720       raise ValueError("Cannot capture a placeholder (name:%s, type:%s) "
    721                        "by value." % (op.name, op.type))
    722 
    723     captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]
    724 
    725     captured_op = self.create_op(
    726         op.type,
    727         captured_inputs, [o.dtype for o in op.outputs],
    728         name=op.name,
    729         attrs=op.node_def.attr,
    730         op_def=op_def)
    731 
    732     for t, captured_t in zip(op.outputs, captured_op.outputs):
    733       self._captured[t] = captured_t
    734 
    735     return captured_op
    736 
    737 
    738 def _call(sig, *inputs, **kwargs):
    739   """Adds a node calling a function.
    740 
    741   This adds a `call` op to the default graph that calls the function
    742   of signature `sig`, passing the tensors in `inputs` as arguments.
    743   It returns the outputs of the call, which are one or more tensors.
    744 
    745   `sig` is OpDefArg.a `_DefinedFunction` object.
    746 
    747   You can pass an optional keyword parameter `name=string` to name the
    748   added operation.
    749 
    750   You can pass an optional keyword parameter `noinline=True|False` to
    751   instruct the runtime not to inline the function body into the call
    752   site.
    753 
    754   Args:
    755     sig: OpDefArg. The signature of the function.
    756     *inputs: arguments to the function.
    757     **kwargs: Optional keyword arguments.  Can only contain 'name' or
    758         'noinline'.
    759 
    760   Returns:
    761      A 2-element tuple. First element: a Tensor if the function returns a single
    762      value; a list of Tensors if the function returns multiple value; the
    763      Operation if the function returns no values. Second element: the Operation.
    764 
    765   Raises:
    766     ValueError: if the arguments are invalid.
    767   """
    768   if len(inputs) != len(sig.input_arg):
    769     raise ValueError("Expected number of arguments: %d, received: %d" %
    770                      (len(sig.input_arg), len(inputs)))
    771   name = kwargs.pop("name", None)
    772   g = ops.get_default_graph()
    773   func_name = sig.name
    774   attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
    775   output_types = [dtypes.DType(x.type) for x in sig.output_arg]
    776   with ops.name_scope(name, func_name, inputs) as name:
    777     op = g.create_op(
    778         func_name,
    779         list(inputs),
    780         output_types,
    781         name=name,
    782         attrs=attrs,
    783         op_def=sig,
    784         compute_shapes=False)
    785   if op.outputs:
    786     if len(op.outputs) == 1:
    787       ret = op.outputs[0]
    788     else:
    789       ret = tuple(op.outputs)
    790   else:
    791     ret = op
    792   return ret, op
    793 
    794 
    795 def _from_definition(fdef, grad_func=None):
    796   """Creates a _DefinedFunction initialized from a FunctionDef proto.
    797 
    798   Args:
    799     fdef: a FunctionDef
    800     grad_func: a _DefinedFunction or None
    801 
    802   Returns:
    803     A _DefinedFunction representing fdef
    804   """
    805   # TODO(iga): This method does major surgery on _DefinedFunction.
    806   # Make it a named constructor using @classmethod of _DefinedFunction.
    807 
    808   # The Python callable is only needed to create a FunctionDef. Since we have
    809   # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do we
    810   # have access to such a callable here).
    811   func = None
    812   argnames = [arg.name for arg in fdef.signature.input_arg]
    813   input_types = tuple(
    814       dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
    815   func_name = fdef.signature.name
    816   # Note: FunctionDefs do not include python gradient functions, so if the
    817   # original _DefinedFunction included one it will not be reflected here.
    818   python_grad_func = None
    819   out_names = [arg.name for arg in fdef.signature.output_arg]
    820   result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
    821                             python_grad_func, out_names)
    822   # pylint: disable=protected-access
    823   if ops._USE_C_API:
    824     serialized = fdef.SerializeToString()
    825     with errors.raise_exception_on_not_ok_status() as status:
    826       result._c_func = c_api.TF_FunctionImportFunctionDef(serialized, status)
    827     result._extra_inputs = []
    828   else:
    829     result._definition = fdef
    830     # Captured inputs are added as regular inputs to a function when it's
    831     # serialized, i.e. any extra inputs from the original function are now
    832     # included in `result`._args
    833     result._extra_inputs = []
    834     result._hash_str = result._create_hash_str(
    835         result._definition.signature.input_arg,
    836         result._definition.signature.output_arg, result._definition.node_def)
    837   # pylint: enable=protected-access
    838 
    839   return result
    840 
    841 
    842 def _from_library(lib):
    843   """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.
    844 
    845   This method handles assigning the correct gradient functions to each
    846   function.
    847 
    848   Args:
    849     lib: a FunctionDefLibrary
    850 
    851   Returns:
    852     A list of _DefinedFunctions
    853 
    854   Raises:
    855     ValueError: `lib` is invalid
    856   """
    857   if not lib.function and not lib.gradient:
    858     return []
    859 
    860   # function name -> FunctionDef proto
    861   funcs = {fdef.signature.name: fdef for fdef in lib.function}
    862 
    863   # Validate that all references function names have function defs
    864   for g in lib.gradient:
    865     if g.function_name not in funcs:
    866       raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
    867                        (g.function_name, str(lib)))
    868     if g.gradient_func not in funcs:
    869       raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
    870                        (g.gradient_func, str(lib)))
    871 
    872   # function name -> gradient function name
    873   func_to_grad = collections.defaultdict(lambda: None)
    874   # gradient function name -> names of functions having that grad function
    875   grad_to_funcs = collections.defaultdict(list)
    876 
    877   for gdef in lib.gradient:
    878     func_to_grad[gdef.function_name] = gdef.gradient_func
    879     grad_to_funcs[gdef.gradient_func].append(gdef.function_name)
    880 
    881   # Start with functions without gradients
    882   ready = [
    883       fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
    884   ]
    885   if not ready:
    886     raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n"
    887                      + str(lib))
    888   # function name -> _DefinedFunction
    889   initialized = {}
    890 
    891   while ready:
    892     fdef = ready.pop()
    893     name = fdef.signature.name
    894 
    895     grad = initialized.get(func_to_grad[name])
    896     if func_to_grad[name]:
    897       assert grad
    898     defined_func = _from_definition(fdef, grad_func=grad)
    899     initialized[name] = defined_func
    900 
    901     ready.extend(funcs[f] for f in grad_to_funcs[name])
    902 
    903   return initialized.values()
    904 
    905 
    906 def _parse_kwargs_as_attrs(func_name, **kwargs):
    907   """Parses **kwargs into a node's attributes."""
    908   attrs = {}
    909 
    910   noinline = kwargs.pop("noinline", None)
    911   if noinline is not None:
    912     attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))
    913 
    914   compiled = kwargs.pop("compiled", None)
    915   separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
    916   if compiled is not None:
    917     attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
    918     attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
    919         b=bool(separate_compiled_gradients))
    920     # Forward _XlaScope from enclosing context (if set), otherwise create new.
    921     # pylint: disable=protected-access
    922     if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
    923       attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
    924     else:
    925       attrs["_XlaScope"] = attr_value_pb2.AttrValue(
    926           s=("function_%s" % func_name).encode())
    927     # pylint: enable=protected-access
    928 
    929   if kwargs:
    930     raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
    931   return attrs
    932 
    933 
    934 def _get_func_name(func):
    935   _, func = tf_decorator.unwrap(func)
    936   if callable(func):
    937     if tf_inspect.isfunction(func):
    938       return func.__name__
    939     elif tf_inspect.ismethod(func):
    940       return "%s.%s" % (func.__self__.__name__, func.__name__)
    941     else:  # Probably a class instance with __call__
    942       return type(func)
    943   else:
    944     raise ValueError("Argument must be callable")
    945 
    946 
    947 def get_extra_vars():
    948   """Returns the captured variables by the function.
    949 
    950   Returns:
    951     If the default graph is being used to define a function, the
    952     returned list of variables are those created inside the function
    953     body so far. Otherwise, returns an empty list.
    954   """
    955   g = ops.get_default_graph()
    956   if isinstance(g, _FuncGraph):
    957     return g.extra_vars
    958   else:
    959     return []
    960 
    961 
    962 def get_extra_inputs():
    963   """Returns the captured input tensors by the function.
    964 
    965   Returns:
    966     If the default graph is being used to define a function, the
    967     returned list of tensors are those accessed inside the function body
    968     but defined outside the function body so far. Otherwise, returns an
    969     empty list.
    970   """
    971   g = ops.get_default_graph()
    972   if isinstance(g, _FuncGraph):
    973     return g.extra_inputs
    974   else:
    975     return []
    976 
    977 
    978 def get_extra_args():
    979   """Returns the corresponding function arguments for the captured inputs.
    980 
    981   Returns:
    982     If the default graph is being used to define a function, the
    983     returned list of place holders are those used inside the function
    984     body corresponding those returned by get_extra_inputs(). Otherwise,
    985     returns an empty list.
    986   """
    987   g = ops.get_default_graph()
    988   if isinstance(g, _FuncGraph):
    989     return g.extra_args
    990   else:
    991     return []
    992 
    993 
    994 def _type_list_to_str(types):
    995   if any([_ not in _DTYPE_TO_STR for _ in types]):
    996     raise ValueError("Unsupported dtypes: %s" % types)
    997   return "".join([_DTYPE_TO_STR[_] for _ in types])
    998 
    999 
   1000 # NOTE: The list needs to be extended when more data types are added.
   1001 _DTYPE_TO_STR = {
   1002     dtypes.float16: "f16",
   1003     dtypes.float32: "f32",
   1004     dtypes.float64: "f64",
   1005     dtypes.int32: "i32",
   1006     dtypes.uint8: "i8",
   1007     dtypes.uint16: "u16",
   1008     dtypes.uint32: "u32",
   1009     dtypes.uint64: "u64",
   1010     dtypes.int16: "i16",
   1011     dtypes.int8: "i8",
   1012     dtypes.string: "s",
   1013     dtypes.complex64: "c64",
   1014     dtypes.complex128: "c128",
   1015     dtypes.int64: "i64",
   1016     dtypes.bool: "b",
   1017     dtypes.qint8: "qi8",
   1018     dtypes.quint8: "qu8",
   1019     dtypes.qint16: "qi16",
   1020     dtypes.quint16: "qu16",
   1021     dtypes.qint32: "qi32",
   1022     dtypes.bfloat16: "b16"
   1023 }
   1024