Home | History | Annotate | Download | only in eager
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Code for backpropagation using the tape utilities."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import functools
     23 import operator
     24 import threading
     25 
     26 import six
     27 
     28 from tensorflow.python import pywrap_tensorflow
     29 from tensorflow.python.eager import context
     30 from tensorflow.python.eager import execute
     31 from tensorflow.python.eager import imperative_grad
     32 from tensorflow.python.eager import tape
     33 from tensorflow.python.framework import constant_op
     34 from tensorflow.python.framework import dtypes
     35 from tensorflow.python.framework import errors
     36 from tensorflow.python.framework import ops
     37 from tensorflow.python.framework import tensor_shape
     38 from tensorflow.python.ops import array_ops
     39 from tensorflow.python.ops import gen_array_ops
     40 from tensorflow.python.ops import math_ops
     41 from tensorflow.python.ops import resource_variable_ops
     42 from tensorflow.python.util import nest
     43 from tensorflow.python.util import tf_inspect
     44 
     45 
     46 class _TensorCache(object):
     47   """Simple cache which evicts items based on length in a FIFO manner."""
     48 
     49   def __init__(self, max_items=256):
     50     self._data = collections.OrderedDict()
     51     self._max_items = max_items if max_items else 256
     52 
     53   def put(self, key, value):
     54     self._data[key] = value
     55 
     56     if len(self._data) > self._max_items:
     57       self._data.popitem(last=False)
     58 
     59   def get(self, key):
     60     return self._data.get(key, None)
     61 
     62   def flush(self):
     63     self._data = {}
     64 
     65 
     66 _op_attr_type_cache = {}
     67 
     68 
     69 def op_attr_type(op_type, attr_name):
     70   try:
     71     return _op_attr_type_cache[(op_type, attr_name)]
     72   except KeyError:
     73     with errors.raise_exception_on_not_ok_status() as status:
     74       h = context.context()._handle  # pylint: disable=protected-access
     75       attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(
     76           h, op_type, attr_name, status)
     77     _op_attr_type_cache[(op_type, attr_name)] = attr_type
     78     return attr_type
     79 
     80 
     81 def make_attr(attr_type, value):
     82   if attr_type == pywrap_tensorflow.TF_ATTR_TYPE:
     83     return dtypes.as_dtype(value)
     84   elif attr_type == [pywrap_tensorflow.TF_ATTR_TYPE]:
     85     return [dtypes.as_dtype(v) for v in value]
     86   elif attr_type == pywrap_tensorflow.TF_ATTR_SHAPE:
     87     return tensor_shape.as_shape(value).as_proto()
     88   elif attr_type == [pywrap_tensorflow.TF_ATTR_SHAPE]:
     89     return [tensor_shape.as_shape(v).as_proto() for v in value]
     90   return value
     91 
     92 
     93 class _MockOp(object):
     94   """Pretends to be a tf.Operation for the gradient functions."""
     95 
     96   def __init__(self, attrs, inputs, outputs, typ):
     97     self.attrs = attrs
     98     self.inputs = inputs
     99     self.outputs = outputs
    100     self.type = typ
    101 
    102   def get_attr(self, attr):
    103     typ = op_attr_type(self.type, attr)
    104     for i in range(0, len(self.attrs), 2):
    105       if self.attrs[i] == attr:
    106         return make_attr(typ, self.attrs[i + 1])
    107     raise KeyError(attr)
    108 
    109 
    110 def _magic_gradient_function(op_name, attr_tuple, num_inputs,
    111                              inputs, outputs, out_grads):
    112   """Calls the gradient function of the op.
    113 
    114   Args:
    115     op_name: the name of the op to be differentiated.
    116     attr_tuple: the attrs, as a tuple.
    117     num_inputs: the number of inputs to the op.
    118     inputs: inputs to the original operation.
    119     outputs: outputs to the original operation.
    120     out_grads: gradients of the operation wrt its outputs.
    121 
    122   Returns:
    123     The gradients with respect to the inputs of the function, as a list.
    124   """
    125   mock_op = _MockOp(attr_tuple, inputs, outputs, op_name)
    126   grad_fn = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
    127   if grad_fn is None:
    128     return [None] * num_inputs
    129 
    130   return grad_fn(mock_op, *out_grads)
    131 
    132 
    133 _gradient_functions = {}
    134 _gradient_functions_lock = threading.Lock()
    135 
    136 
    137 _tracing = False
    138 
    139 
    140 # TODO(apassos) replace this with a mechanism which can happen at the op
    141 # gradient function registration site, to be less error-prone
    142 # TODO(apassos) add ops other than those in nn_grad and math_grad
    143 _ops_which_dont_need_outputs = set([
    144     "Identity",
    145     "MatMul",
    146     "Conv2DBackpropInput",
    147     "Conv2DBackpropFilter",
    148     "Conv3D",
    149     "Conv3DBackpropInputV2",
    150     "AvgPool3D",
    151     "AvgPool3DGrad",
    152     "MaxPool3D",
    153     "MaxPool3DGrad",
    154     "MaxPool3DGradGrad",
    155     "BiasAdd",
    156     "BiasAddV1",
    157     "BiasAddGrad",
    158     "Relu6",
    159     "Softplus",
    160     "SoftplusGrad",
    161     "Softsign",
    162     "ReluGrad",
    163     "Conv2D",
    164     "DepthwiseConv2dNative",
    165     "Dilation2D",
    166     "AvgPool",
    167     "AvgPoolGrad",
    168     "BatchNormWithGlobalNormalization",
    169     "L2Loss",
    170     "Sum",
    171     "Prod",
    172     "SegmentSum",
    173     "SegmentMean",
    174     "SparseSegmentSum",
    175     "SparseSegmentMean",
    176     "SparseSegmentSqrtN",
    177     "SegmentMin",
    178     "SegmentMax",
    179     "UnsortedSegmentSum",
    180     "UnsortedSegmentMax",
    181     "UnsortedSegmentMin",
    182     "UnsortedSegmentProd",
    183     "Abs",
    184     "Neg",
    185     "ReciprocalGrad",
    186     "Square",
    187     "Expm1",
    188     "Log",
    189     "Log1p",
    190     "TanhGrad",
    191     "SigmoidGrad",
    192     "Sign",
    193     "Sin",
    194     "Cos",
    195     "Tan",
    196     "Add",
    197     "Sub",
    198     "Mul",
    199     "Div",
    200     "RealDiv",
    201     "Maximum",
    202     "Minimum",
    203     "SquaredDifference",
    204     "Select",
    205     "SparseMatMul",
    206     "BatchMatMul",
    207     "Complex",
    208     "Real",
    209     "Imag",
    210     "Angle",
    211     "Conj",
    212     "Cast",
    213     "Cross",
    214     "Cumsum",
    215     "Cumprod",
    216     "ReadVariableOp",
    217     "VarHandleOp",
    218     "Shape",
    219 ])
    220 
    221 _ops_which_dont_need_inputs = set([
    222     "Identity",
    223     "Softmax",
    224     "LogSoftmax",
    225     "BiasAdd",
    226     "Relu",
    227     "Elu",
    228     "Selu",
    229     "SparseSoftmaxCrossEntropyWithLogits",
    230     "Neg",
    231     "Inv",
    232     "Reciprocal",
    233     "Sqrt",
    234     "Exp",
    235     "Tanh",
    236     "Sigmoid",
    237     "Real",
    238     "Imag",
    239     "Conj",
    240     "ReadVariableOp",
    241     "VarHandleOp",
    242     "Shape",
    243 ])
    244 
    245 
    246 # TODO(agarwal): use an automatic mechanism for handling None arguments to
    247 # gradient functions.
    248 # Some gradient functions can accept None arguments for gradients. The following
    249 # maps the operation name to the indices at which the corresponding gradient
    250 # function can accept None values.
    251 # e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
    252 # during backprop. However the gradient function uses only the first of those
    253 # values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
    254 # indicates that only the gradient corresponding to index 0 is used, and the
    255 # gradient values at indices 1-4 are ignored (and hence can be None). The
    256 # backprop algorithm can then leverage this by not constructing zeros to
    257 # pass for those indices.
    258 _grad_fn_accepts_none_for_indices = {
    259     "SoftmaxCrossEntropyWithLogits": [1],
    260     "FusedBatchNorm": [1, 2, 3, 4]
    261 }
    262 
    263 
    264 def _record_gradient(op_name, inputs, attrs, results, name):
    265   """Records gradients for a TensorFlow operation.
    266 
    267   Args:
    268     op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
    269       execute.
    270     inputs: A flat list of Tensor object inputs to the operation.
    271     attrs: A tuple with alternating string attr names and attr values for this
    272       operation.
    273     results: The results of the operation (as a flat list).
    274     name: Customized name for the operation.
    275 
    276   Returns:
    277     A list of maybe-wrapped results. Either Tensors or TensorNodes.
    278 
    279   Raises:
    280     An exception on error.
    281   """
    282   if not tape.could_possibly_record():
    283     return
    284 
    285   if op_name in _ops_which_dont_need_outputs:
    286     op_outputs = None
    287   else:
    288     # TODO(apassos) this line creates a weak circular reference where the
    289     # backprop function keeps an output alive which in turn keeps the tape entry
    290     # alive which keeps the backprop function alive. Figure out how to break
    291     # this up without breaking second derivatives of ops like Exp whose
    292     # gradients depend only on the outputs.
    293     op_outputs = results
    294 
    295   if op_name in _ops_which_dont_need_inputs:
    296     op_inputs = None
    297   else:
    298     op_inputs = inputs
    299 
    300   num_inputs = len(inputs)
    301 
    302   def grad_fn(*orig_outputs):
    303     """Generated gradient function."""
    304     result = _magic_gradient_function(op_name, attrs, num_inputs,
    305                                       op_inputs, op_outputs, orig_outputs)
    306     if _tracing:
    307       print("Gradient for", (name if name else op_name), "inputs", op_inputs,
    308             "output_grads", orig_outputs, "gradients", result)
    309     return nest.flatten(result)
    310 
    311   tape.record_operation(op_name, results, inputs, grad_fn)
    312   if _tracing:
    313     print("Computed op", (name if name else op_name), "inputs", inputs,
    314           "outputs", results)
    315 
    316 
    317 execute.record_gradient = _record_gradient
    318 
    319 
    320 def implicit_val_and_grad(f):
    321   """Returns a function which differentiates f with respect to variables.
    322 
    323   The wrapped function returns the value and the gradient of f when called with
    324   the same arguments. The gradient is with respect to all TFE variables which
    325   have `variable.watch()` called on them by f.
    326 
    327   This function is useful when the exact set of variables to differentiate with
    328   is not known ahead of time.
    329 
    330   Example:
    331 
    332   ```python
    333   dense_layer = tf.layers.Dense(1)
    334   def loss(x, y):
    335     return tf.reduce_sum(tf.square(dense_layer(x) - y))
    336 
    337   # Obtain the gradient function.
    338   val_grad_fn = tfe.implicit_value_and_gradients(loss)
    339 
    340   # Invoke the gradient function with concrete values of x and y.
    341   x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    342   y = tf.constant([[10.0], [20.0]])
    343   value, grads_and_vars = val_grad_fn(x, y)
    344   print('Value of loss: %s' % value)
    345 
    346   # Apply the gradients to Variables.
    347   optimizer = tf.train.GradientDescentOptimizer(0.1)
    348   optimizer.apply_gradients(grads_and_vars)
    349   ```
    350 
    351   Args:
    352    f: function to be differentiated. If `f` returns a scalar, this scalar will
    353      be differentiated. If `f` returns a tensor or list of tensors, by default
    354      a scalar will be computed by adding all their values to produce a single
    355      scalar.
    356 
    357   Returns:
    358     A function which, when called, returns a tuple pair.
    359     Its first element is the value to which the function evaluates.
    360     Its second element is list of (gradient, variable) pairs.
    361 
    362   Raises:
    363     ValueError: if `f` returns None.
    364   """
    365   # TODO(cais): Remove calls to tf.constant() once the gradients functions
    366   # accept lists and np.ndarrays.
    367 
    368   def grad_fn(*args):
    369     """Computes the gradient of the wrapped function."""
    370     this_tape = tape.push_new_tape()
    371     try:
    372       end_node = f(*args)
    373       if end_node is None:
    374         raise ValueError("Cannot differentiate a function that returns None; "
    375                          "did you forget to return a value from {}?".format(
    376                              f.__name__))
    377     finally:
    378       tape.pop_tape(this_tape)
    379     # Sorting variables by id, which is monotonically increasing in construction
    380     # order. This ensures unique order across executions.
    381     variables = list(sorted(this_tape.watched_variables(),
    382                             key=lambda v: v.handle._id))  # pylint: disable=protected-access
    383     sources = [x.handle for x in variables]
    384 
    385     if not sources:
    386       raise ValueError("No trainable variables were accessed while the "
    387                        "function was being computed.")
    388     grad = imperative_grad.imperative_grad(_default_vspace,
    389                                            this_tape,
    390                                            nest.flatten(end_node),
    391                                            sources)
    392     return end_node, list(zip(grad, variables))
    393 
    394   return grad_fn
    395 
    396 
    397 def implicit_grad(f):
    398   """Returns a function which differentiates f with respect to variables.
    399 
    400   The wrapped function returns the gradient of f when called with the same
    401   arguments. The gradient is with respect to all TFE variables which have
    402   `variable.watch()` called on them by f.
    403 
    404   This function is useful when the exact set of variables to differentiate with
    405   is not known ahead of time.
    406 
    407   Example:
    408 
    409   ```python
    410   dense_layer = tf.layers.Dense(1)
    411   def loss(x, y):
    412     return tf.reduce_sum(tf.square(dense_layer(x) - y))
    413 
    414   # Obtain the gradient function.
    415   grad_fn = tfe.implicit_gradients(loss)
    416 
    417   # Invoke the gradient function with concrete values of x and y.
    418   x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    419   y = tf.constant([[10.0], [20.0]])
    420   grads_and_vars = grad_fn(x, y)
    421 
    422   # Apply the gradients to Variables.
    423   optimizer = tf.train.GradientDescentOptimizer(0.1)
    424   optimizer.apply_gradients(grads_and_vars)
    425   ```
    426 
    427   Args:
    428    f: function to be differentiated. If `f` returns a scalar, this scalar will
    429      be differentiated. If `f` returns a tensor or list of tensors, by default
    430      a scalar will be computed by adding all their values to produce a single
    431      scalar.
    432 
    433   Returns:
    434     A function which, when called, returns a list of (gradient, variable) pairs.
    435   """
    436   # TODO(cais): Remove calls to tf.constant() once the gradients functions
    437   # accept lists and np.ndarrays.
    438 
    439   def grad_fn(*args, **kwds):
    440     """Computes the gradient of the wrapped function."""
    441     return implicit_val_and_grad(f)(*args, **kwds)[1]
    442 
    443   return grad_fn
    444 
    445 
    446 def _get_arg_spec(f, params, param_args):
    447   """The positions of the parameters of f to be differentiated in param_args."""
    448   try:
    449     args = tf_inspect.getargspec(f).args
    450   except TypeError as e:
    451     # TypeError can happen when f is a callable object.
    452     if params is None:
    453       return range(len(param_args))
    454     elif all(isinstance(x, int) for x in params):
    455       return params
    456     raise ValueError("Either callable provided is not a function or could not "
    457                      "inspect its arguments by name: %s. Original error: %s"
    458                      % (f, e))
    459   if params is None:
    460     if not args:
    461       return range(len(param_args))
    462     return range(len(args))
    463   elif all(isinstance(x, six.string_types) for x in params):
    464     return [args.index(n) for n in params]
    465   elif all(isinstance(x, int) for x in params):
    466     return params
    467   else:
    468     raise ValueError(
    469         "params must be all strings or all integers; got %s." % params)
    470 
    471 
    472 def gradients_function(f, params=None):
    473   """Returns a function which differentiates f with respect to params.
    474 
    475   Example:
    476   ```python
    477   # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
    478   # Therefore, the 1st order derivatives are:
    479   #   df / dx = 3 * (x ^ 2) * y - y ^ 2
    480   #   df / dy = x ^ 3 - 2 * x * y
    481   # The 2nd order derivatives with respect to x is:
    482   #   d^2 f / (dx)^2 = 6 * x * y
    483   def f(x, y):
    484     return x * x * x * y - x * y * y
    485 
    486   # Obtain a function that returns 1st order gradients.
    487   grad_fn = tfe.gradients_function(f)
    488 
    489   x = 2.0
    490   y = 3.0
    491 
    492   # Invoke the 1st order gradient function.
    493   x_grad, y_grad = grad_fn(x, y)
    494   assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
    495   assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
    496 
    497   # Obtain a function that returns the 2nd order gradient with respect to x.
    498   gradgrad_fn = tfe.gradients_function(lambda x, y: grad_fn(x, y)[0])
    499 
    500   # Invoke the 2nd order gradient function.
    501   x_gradgrad = gradgrad_fn(x, y)[0]
    502   assert x_gradgrad.numpy() == 6 * 2 * 3
    503 
    504   # To obtain a callable that returns the gradient(s) of `f` with respect to a
    505   # subset of its inputs, use the `params` keyword argument with
    506   # `gradients_function()`.
    507   ygrad_fn = tfe.gradients_function(f, params=[1])
    508 
    509   (y_grad,) = ygrad_fn(x, y)
    510   assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
    511   ```
    512 
    513   Args:
    514    f: function to be differentiated. If `f` returns a scalar, this scalar will
    515      be differentiated. If `f` returns a tensor or list of tensors, by default
    516      a scalar will be computed by adding all their values to produce a single
    517      scalar. If desired, the tensors can be elementwise multiplied by the
    518      tensors passed as the `dy` keyword argument to the returned gradient
    519      function.
    520    params: list of parameter names of f or list of integers indexing the
    521      parameters with respect to which we'll differentiate. Passing None
    522      differentiates with respect to all parameters.
    523 
    524   Returns:
    525     function which, when called, returns the value of f and the gradient
    526     of f with respect to all of `params`. The function takes an extra optional
    527     keyword argument "dy". Setting it allows computation of vector jacobian
    528     products for vectors other than the vector of ones.
    529 
    530   Raises:
    531    ValueError: if the params are not all strings or all integers.
    532   """
    533 
    534   def decorated(*args, **kwds):
    535     """Computes the gradient of the decorated function."""
    536 
    537     _, grad = val_and_grad_function(f, params=params)(*args, **kwds)
    538     return grad
    539 
    540   return decorated
    541 
    542 
    543 def _ensure_unique_tensor_objects(parameter_positions, args):
    544   """Make each of the parameter_positions in args a unique ops.Tensor object.
    545 
    546   Ensure that each parameter is treated independently.
    547   For example:
    548 
    549   def f(x, y): return x * y
    550   g = gradients_function(f)
    551   one = tf.constant(1.)
    552 
    553   g(one, one) should return [1., 1.]
    554   (even though the two arguments are the same Tensor object).
    555 
    556   Args:
    557     parameter_positions: List of indices into args defining the arguments to
    558       differentiate against.
    559     args: A list of arguments to the function to be differentiated.
    560 
    561   Returns:
    562     args, possibly edited in-place.
    563   """
    564   s = set()
    565   for (i, t) in enumerate(args):
    566     if i in parameter_positions:
    567       tid = ops.tensor_id(t)
    568       if tid in s:
    569         args[i] = gen_array_ops.identity(args[i])
    570       else:
    571         s.add(tid)
    572   return args
    573 
    574 
    575 def val_and_grad_function(f, params=None):
    576   """Returns a function that computes f and its derivative w.r.t. params.
    577 
    578   Example:
    579   ```python
    580   # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
    581   # Therefore, the 1st order derivatives are:
    582   #   df / dx = 3 * (x ^ 2) * y - y ^ 2
    583   #   df / dy = x ^ 3 - 2 * x * y
    584   def f(x, y):
    585     return x * x * x * y - x * y * y
    586 
    587   # Obtain a function that returns the function value and the 1st order
    588   # gradients.
    589   val_grads_fn = tfe.value_and_gradients_function(f)
    590 
    591   x = 2.0
    592   y = 3.0
    593 
    594   # Invoke the value-and-gradients function.
    595   f_val, (x_grad, y_grad) = val_grads_fn(x, y)
    596   assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
    597   assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
    598   assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
    599 
    600   # To obtain a callable that returns the value of `f` and the gradient(s) of
    601   # `f` with respect to a subset of its inputs, use the `params` keyword
    602   # argument with `value_and_gradients_function()`.
    603   val_ygrad_fn = tfe.value_and_gradients_function(f, params=[1])
    604 
    605   f_val, (y_grad,) = val_ygrad_fn(x, y)
    606   assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
    607   assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
    608   ```
    609 
    610   Args:
    611    f: function to be differentiated. If `f` returns a scalar, this scalar will
    612      be differentiated. If `f` returns a tensor or list of tensors, by default
    613      a scalar will be computed by adding all their values to produce a single
    614      scalar. If desired, the tensors can be elementwise multiplied by the
    615      tensors passed as the `dy` keyword argument to the returned gradient
    616      function.
    617    params: list of parameter names of f or list of integers indexing the
    618      parameters with respect to which we'll differentiate. Passing `None`
    619      differentiates with respect to all parameters.
    620 
    621   Returns: function which, when called, returns the value of f and the gradient
    622    of f with respect to all of `params`. The function takes an extra optional
    623    keyword argument "dy". Setting it allows computation of vector jacobian
    624    products for vectors other than the vector of ones.
    625 
    626   Raises:
    627    ValueError: if the params are not all strings or all integers.
    628   """
    629 
    630   def decorated(*args, **kwds):
    631     """Computes the value and gradient of the decorated function."""
    632     dy = kwds.pop("dy", None)
    633     if kwds:
    634       raise ValueError("Functions to be differentiated cannot "
    635                        "receive keyword arguments.")
    636     val, vjp = make_vjp(f, params)(*args, **kwds)
    637     return val, vjp(dy=dy)
    638 
    639   return decorated
    640 
    641 
    642 def make_vjp(f, params=None):
    643   """Returns a function that computes f and is vjp w.r.t. params.
    644 
    645   The term "vjp" here is an abbreviation for vector-jacobian product.
    646 
    647   Args:
    648     f: the function to be differentiated.
    649     params: the parameters (numbers or names) to differentiate with respect to.
    650        A value of None will differentiate with respect to all parameters.
    651 
    652   Returns:
    653     A function, which when called, returns a tuple (value, vjp), where:
    654     - value is the result of calling f.
    655     - vjp is a function, which takes a vector as an argument and
    656       returns the product of that vector with the Jacobian of f.
    657       Providing no argument to vjp is equivalent to providing a
    658       vector of ones.
    659 
    660     For example,
    661     ```python
    662     def f(x):
    663       return x * x
    664 
    665     wrapped_fn = tfe.make_vjp(f)
    666     result, vjp = wrapped_fn(tf.constant(3.0))
    667     # result is 9.0
    668     vjp()  # the vjp function rturns 6.0
    669 
    670   Raises:
    671     ValueError: if `f` returns None.
    672   """
    673 
    674   def decorated(*args, **kwds):
    675     """Computes the value and gradient of the decorated function."""
    676     parameter_positions = _get_arg_spec(f, params, args)
    677     assert not kwds, "The gradient function can't take keyword arguments."
    678     this_tape = tape.push_new_tape()
    679     try:
    680       sources = []
    681       args = [
    682           ops.convert_to_tensor(args[i])
    683           if i in parameter_positions else args[i]
    684           for i in range(len(args))
    685       ]
    686       args = _ensure_unique_tensor_objects(parameter_positions, args)
    687       for i in parameter_positions:
    688         sources.append(args[i])
    689         tape.watch(args[i])
    690       result = f(*args)
    691       if result is None:
    692         raise ValueError("Cannot differentiate a function that returns None; "
    693                          "did you forget to return a value from {}?".format(
    694                              f.__name__))
    695       flat_result = nest.flatten(result)
    696       flat_result = [gen_array_ops.identity(x) for x in flat_result]
    697       result = nest.pack_sequence_as(result, flat_result)
    698     finally:
    699       tape.pop_tape(this_tape)
    700     def vjp(dy=None):
    701       if dy is not None:
    702         dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
    703       return imperative_grad.imperative_grad(
    704           _default_vspace, this_tape, nest.flatten(result), sources,
    705           output_gradients=dy)
    706     return result, vjp
    707 
    708   return decorated
    709 
    710 
    711 def _aggregate_grads(gradients):
    712   """Aggregate gradients from multiple sources.
    713 
    714   Args:
    715     gradients: A list of 'Tensor' or 'IndexedSlices' gradients.
    716 
    717   Returns:
    718     If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'.
    719     Otherwise returns an aggregated 'IndexedSlices'.
    720   """
    721   assert gradients, "No gradients to aggregate"
    722 
    723   if len(gradients) == 1:
    724     return gradients[0]
    725   if all([isinstance(g, ops.Tensor) for g in gradients]):
    726     return math_ops.add_n(gradients)
    727   else:
    728     assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
    729                 for g in gradients])
    730     indexed_slices_list = []
    731     for grad in gradients:
    732       # TODO(xpan): Support nested IndexedSlices and core IndexedSlices
    733       if isinstance(grad, ops.Tensor):
    734         indexed_slices = ops.IndexedSlices(
    735             grad,
    736             math_ops.range(grad.shape[0]),
    737             constant_op.constant(grad.shape.as_list()))
    738         indexed_slices_list.append(indexed_slices)
    739       else:
    740         indexed_slices_list.append(grad)
    741 
    742     # Dense shapes from all gradients should be the same.
    743     dense_shape = indexed_slices_list[0].dense_shape
    744     # For simplicity now, always cast to int64.
    745     indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64)
    746                                 for x in indexed_slices_list], 0)
    747     values = array_ops.concat([x.values for x in indexed_slices_list], 0)
    748     return ops.IndexedSlices(values, indices, dense_shape)
    749 
    750 
    751 def _num_elements(grad):
    752   """The number of elements in the `grad` tensor."""
    753   if isinstance(grad, ops.Tensor):
    754     return functools.reduce(operator.mul, grad._shape_tuple(), 1)  # pylint: disable=protected-access
    755   if isinstance(grad, ops.IndexedSlices):
    756     return functools.reduce(operator.mul, grad.values._shape_tuple(), 1)  # pylint: disable=protected-access
    757   raise ValueError("`grad` not a Tensor or IndexedSlices.")
    758 
    759 
    760 _zeros_cache = _TensorCache()
    761 
    762 
    763 def _fast_fill(value, shape, dtype):
    764   return array_ops.fill(shape, constant_op.constant(value, dtype=dtype))
    765 
    766 
    767 def _zeros(shape, dtype):
    768   """Wraps array_ops.zeros to cache last zero for a given shape and dtype."""
    769   device = context.context().device_name
    770   if dtype == dtypes.variant:
    771     # TODO(apassos): need to save enough information about variant tensors to do
    772     # a zeros
    773     return None
    774   cache_key = shape, dtype, device
    775   cached = _zeros_cache.get(cache_key)
    776   if cached is None:
    777     cached = _fast_fill(0, shape, dtype)
    778     _zeros_cache.put(cache_key, cached)
    779   return cached
    780 
    781 
    782 def _ones(shape, dtype):
    783   if shape == ():  # pylint: disable=g-explicit-bool-comparison
    784     return constant_op.constant(1, dtype=dtype)
    785   return _fast_fill(1, shape, dtype)
    786 
    787 
    788 _default_vspace = imperative_grad.VSpace(
    789     num_elements_fn=_num_elements,
    790     aggregate_fn=_aggregate_grads,
    791     tensor_id=ops.tensor_id,
    792     zeros=_zeros,
    793     ones=_ones)
    794 
    795 
    796 class GradientTape(object):
    797   """Records operations to use to compute gradients.
    798 
    799   Operations are recorded if:
    800     - they happen in code marked by this context manager
    801     - at least one of their inputs is being watched
    802 
    803   Outputs of recorded operations are watched. Variables are automatically
    804   watched and tensors can be manually watched by calling the watch method on the
    805   context manager.
    806 
    807   Example usage:
    808 
    809   ```python
    810   with tfe.GradientTape() as g:
    811     x = tf.constant(3.0)
    812     g.watch(x)
    813     y = x * x
    814   grad = g.gradient(y, [x])[0]
    815   assert grad.numpy() == 6.0
    816   ```
    817 
    818   It is possible to use GradientTapes to compute higher-order derivatives as
    819   follows:
    820 
    821   ```python
    822   with tfe.GradientTape() as g:
    823     x = tf.constant(3.0)
    824     g.watch(x)
    825     y = x * x
    826     with tfe.GradientTape() as gg:
    827       gg.watch(y)
    828       z = 2 * y
    829     inner_grad = gg.gradient(z, [y])[0]
    830     assert inner_grad.numpy() == 2
    831     y = y + inner_grad
    832   grad = g.gradient(y, [x])[0]
    833   assert grad.numpy() == 6.0
    834   ```
    835 
    836   By default, the resources held by a GradientTape are released as soon as
    837   GradientTape.gradient() method is called. However, if one need to compute
    838   multiple gradients over the same computation, she can create a persistent
    839   GradientTape. Persistent tapes allow multiple calls to the gradient() method
    840   and release resources when the tape object is destructed.
    841 
    842   Example usage:
    843 
    844   ```python
    845   with tfe.GradientTape(persistent=True) as g:
    846     x = tf.constant(3.0)
    847     g.watch(x)
    848     y = x * x
    849     z = y * y
    850   dz_dx = g.gradient(z, [x])[0]
    851   assert dz_dx.numpy() == 108.0   # 4*x^3 at x = 3
    852   dy_dx = g.gradient(y, [x])[0]
    853   assert dy_dx.numpy() == 6.0
    854   del g  # Drop the reference to the tape
    855   """
    856 
    857   def __init__(self, persistent=False):
    858     """Creates a new GradientTape.
    859 
    860     Args:
    861       persistent: Boolean controlling whether a persistent gradient tape
    862         is created. Must be True or False.
    863 
    864     """
    865     self._tape = None
    866     self._persistent = persistent
    867 
    868   def __enter__(self):
    869     self._tape = tape.push_new_tape(persistent=self._persistent)
    870     return self
    871 
    872   def __exit__(self, typ, value, traceback):
    873     tape.pop_tape(self._tape)
    874 
    875   def watch(self, tensor):
    876     """Ensures that `tensor` is being traced by this tape.
    877 
    878     Args:
    879       tensor: a Tensor or Variable a list of Tensors or Variables.
    880     """
    881     for t in nest.flatten(tensor):
    882       if isinstance(t, resource_variable_ops.ResourceVariable):
    883         t = t.handle
    884       tape.watch(t)
    885 
    886   def watched_variables(self):
    887     return self._tape.watched_variables()
    888 
    889   def gradient(self, target, sources, output_gradients=None):
    890     """Computes the gradient using information traced by the tape.
    891 
    892     Args:
    893       target: the tensor to be differentiated.
    894       sources: a list of Tensors or Variables, the target will be
    895        differentiated with respect to the sources.
    896       output_gradients: a list of gradients, one for each element of
    897        target. Defaults to None.
    898 
    899     Returns:
    900       a list of Tensors (or IndexedSlices, or None), one for each element in
    901       `sources`.
    902 
    903     Raises:
    904       RuntimeError: if called inside the context of the tape, or if called more
    905        than once.
    906     """
    907     if self._tape is None:
    908       raise RuntimeError("GradientTape.gradient can only be called once "
    909                          "on non-persistent tapes, and "
    910                          "only when the context manager has exited.")
    911     sources = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable)
    912                else x
    913                for x in sources]
    914     grad = imperative_grad.imperative_grad(
    915         _default_vspace, self._tape, [target], sources,
    916         output_gradients=output_gradients)
    917     if not self._persistent:
    918       self._tape = None
    919     return grad
    920