      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Gradients for operators defined in math_ops.py."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.eager import context
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.framework import tensor_util
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import gen_array_ops
     29 from tensorflow.python.ops import gen_math_ops
     30 from tensorflow.python.ops import math_ops
     31 
     32 
     33 def _safe_shape_div(x, y):
     34   """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
     35   return x // math_ops.maximum(y, 1)
     36 
     37 
     38 @ops.RegisterGradient("Sum")
     39 def _SumGrad(op, grad):
     40   """Gradient for Sum."""
     41   # Fast path for when reducing to a scalar and ndims is known: adds only
     42   # Reshape and Tile ops (and possibly a Shape).
     43   input_0_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
     44   if input_0_shape is not None:
     45     axes = tensor_util.constant_value(op.inputs[1])
     46     if axes is not None:
     47       rank = len(input_0_shape)
     48       if np.array_equal(axes, np.arange(rank)):  # Reduce all dims.
     49         grad = array_ops.reshape(grad, [1] * rank)
     50         # If shape is not fully defined (but rank is), we use Shape.
     51         if None not in input_0_shape:
     52           input_shape = input_0_shape
     53         else:
     54           input_shape = array_ops.shape(op.inputs[0])
     55         return [array_ops.tile(grad, input_shape), None]
     56 
     57   input_shape = array_ops.shape(op.inputs[0])
     58   # TODO(apassos) remove this once device placement for eager ops makes more
     59   # sense.
     60   with ops.colocate_with(input_shape):
     61     output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
     62     tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
     63   grad = array_ops.reshape(grad, output_shape_kept_dims)
     64   return [array_ops.tile(grad, tile_scaling), None]
     65 
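        # A minimal worked sketch of the Sum gradient above, assuming the public
        # `tf` API in TF 1.x graph mode (names below are illustrative only):
        #
        #   x = tf.constant([[1., 2., 3.], [4., 5., 6.]])   # shape (2, 3)
        #   dx, = tf.gradients(tf.reduce_sum(x), x)
        #
        # The scalar upstream gradient (1.0) is reshaped to [1, 1] and tiled to
        # the input shape, so `dx` evaluates to [[1., 1., 1.], [1., 1., 1.]].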
     66 
     67 def _MinOrMaxGrad(op, grad):
     68   """Gradient for Min or Max. Amazingly it's precisely the same code."""
     69   input_shape = array_ops.shape(op.inputs[0])
     70   output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
     71   y = op.outputs[0]
     72   y = array_ops.reshape(y, output_shape_kept_dims)
     73   grad = array_ops.reshape(grad, output_shape_kept_dims)
     74 
     75   # Compute the number of selected (maximum or minimum) elements in each
     76   # reduction dimension. If there are multiple minimum or maximum elements
     77   # then the gradient will be divided between them.
     78   indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype)
     79   num_selected = array_ops.reshape(
     80       math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims)
     81 
     82   return [math_ops.div(indicators, num_selected) * grad, None]
     83 
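        # A worked sketch of how ties are handled above (illustrative, assuming
        # the public `tf` API): for x = [3., 1., 3.] there are two equal maxima,
        # so indicators = [1., 0., 1.], num_selected = 2, and an upstream
        # gradient of 1.0 is split evenly:
        #
        #   x = tf.constant([3., 1., 3.])
        #   dx, = tf.gradients(tf.reduce_max(x), x)   # -> [0.5, 0., 0.5]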
     84 
     85 @ops.RegisterGradient("Max")
     86 def _MaxGrad(op, grad):
     87   """Gradient for Max."""
     88   return _MinOrMaxGrad(op, grad)
     89 
     90 
     91 @ops.RegisterGradient("Min")
     92 def _MinGrad(op, grad):
     93   return _MinOrMaxGrad(op, grad)
     94 
     95 
     96 @ops.RegisterGradient("Mean")
     97 def _MeanGrad(op, grad):
     98   """Gradient for Mean."""
     99   sum_grad = _SumGrad(op, grad)[0]
    100   input_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
    101   output_shape = op.outputs[0]._shape_tuple()  # pylint: disable=protected-access
    102   if (input_shape is not None and output_shape is not None and
    103       None not in input_shape and None not in output_shape):
    104     input_size = np.prod(input_shape)
    105     output_size = np.prod(output_shape)
    106     factor = input_size // max(output_size, 1)
    107     factor = constant_op.constant(factor, dtype=sum_grad.dtype)
    108   else:
    109     input_shape = array_ops.shape(op.inputs[0])
    110     output_shape = array_ops.shape(op.outputs[0])
    111     factor = _safe_shape_div(
    112         math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape))
    113   return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None
    114 
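        # A worked sketch of the Mean gradient above (illustrative, assuming the
        # public `tf` API): for a fully defined (2, 3) input reduced to a scalar,
        # factor = input_size // output_size = 6, so the tiled Sum gradient of
        # ones is divided by 6:
        #
        #   x = tf.constant([[1., 2., 3.], [4., 5., 6.]])
        #   dx, = tf.gradients(tf.reduce_mean(x), x)   # -> all entries 1/6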
    115 
    116 @ops.RegisterGradient("Prod")
    117 def _ProdGrad(op, grad):
    118   """Gradient for Prod."""
    119   # The gradient can be expressed by dividing the product by each entry of the
    120   # input tensor, but this approach can't deal with zeros in the input.
    121   # Here, we avoid this problem by composing the output as a product of two
    122   # cumprod operations.
    123 
    124   input_shape = array_ops.shape(op.inputs[0])
    125   # Reshape reduction indices for the case where the parameter is a scalar
    126   reduction_indices = array_ops.reshape(op.inputs[1], [-1])
    127 
    128   # Expand grad to full input shape
    129   output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
    130   tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
    131   grad = array_ops.reshape(grad, output_shape_kept_dims)
    132   grad = array_ops.tile(grad, tile_scaling)
    133 
    134   # Pack all reduced dimensions into a single one, so we can perform the
    135   # cumprod ops. If the reduction dims list is empty, it defaults to float32,
    136   # so we cast it to int32 here.  We put all the shape-related ops on CPU to
    137   # avoid copying back and forth, and because listdiff is CPU only.
    138   with ops.device("/cpu:0"):
    139     rank = array_ops.rank(op.inputs[0])
    140     reduction_indices = (reduction_indices + rank) % rank
    141     reduced = math_ops.cast(reduction_indices, dtypes.int32)
    142     idx = math_ops.range(0, rank)
    143     other, _ = array_ops.setdiff1d(idx, reduced)
    144     perm = array_ops.concat([reduced, other], 0)
    145     reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
    146     other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
    147   permuted = array_ops.transpose(op.inputs[0], perm)
    148   permuted_shape = array_ops.shape(permuted)
    149   reshaped = array_ops.reshape(permuted, (reduced_num, other_num))
    150 
    151   # Calculate product, leaving out the current entry
    152   left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
    153   right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
    154   y = array_ops.reshape(left * right, permuted_shape)
    155 
    156   # Invert the transpose and reshape operations.
    157   # Make sure to set the statically known shape information through a reshape.
    158   out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
    159   return array_ops.reshape(out, input_shape), None
    160 
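        # A worked sketch of the exclusive-cumprod trick above, for a single
        # reduced axis with entries [2., 0., 4.] (illustrative numbers):
        #   left  = exclusive cumprod           -> [1., 2., 0.]
        #   right = exclusive reversed cumprod  -> [0., 4., 1.]
        #   left * right                        -> [0., 8., 0.]
        # i.e. the product of all *other* entries at each position, so a single
        # zero in the input still yields a finite, correct gradient.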
    161 
    162 @ops.RegisterGradient("SegmentSum")
    163 def _SegmentSumGrad(op, grad):
    164   """Gradient for SegmentSum."""
    165   return array_ops.gather(grad, op.inputs[1]), None
    166 
    167 
    168 @ops.RegisterGradient("SegmentMean")
    169 def _SegmentMeanGrad(op, grad):
    170   """Gradient for SegmentMean."""
    171   input_rank = array_ops.rank(op.inputs[0])
    172   ones_shape = array_ops.concat([
    173       array_ops.shape(op.inputs[1]),
    174       array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1)
    175   ], 0)
    176   ones = array_ops.fill(ones_shape, constant_op.constant(1, dtype=grad.dtype))
    177   scaled_grad = math_ops.div(grad, math_ops.segment_sum(ones, op.inputs[1]))
    178   return array_ops.gather(scaled_grad, op.inputs[1]), None
    179 
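        # A worked sketch of the SegmentMean gradient above (illustrative): with
        # segment_ids = [0, 0, 1], the segment sums of the ones tensor are
        # [2, 1], so an upstream gradient [g0, g1] back-propagates as
        # [g0 / 2, g0 / 2, g1 / 1] after gathering by segment id.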
    180 
    181 @ops.RegisterGradient("SparseSegmentSum")
    182 def _SparseSegmentSumGrad(op, grad):
    183   """Gradient for SparseSegmentSum."""
    184   input_rows = array_ops.shape(op.inputs[0])[0]
    185   return (math_ops.unsorted_segment_sum(
    186       array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows), None,
    187           None)
    188 
    189 
    190 @ops.RegisterGradient("SparseSegmentSumWithNumSegments")
    191 def _SparseSegmentSumWithNumSegmentsGrad(op, grad):
    192   """Gradient for SparseSegmentSumWithNumSegments."""
    193   input_rows = array_ops.shape(op.inputs[0])[0]
    194   return (math_ops.unsorted_segment_sum(
    195       array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows), None,
    196           None, None)
    197 
    198 
    199 @ops.RegisterGradient("SparseSegmentMean")
    200 def _SparseSegmentMeanGrad(op, grad):
    201   """Gradient for SparseSegmentMean."""
    202   dim0 = array_ops.shape(op.inputs[0])[0]
    203   return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
    204                                             dim0), None, None)
    205 
    206 
    207 @ops.RegisterGradient("SparseSegmentMeanWithNumSegments")
    208 def _SparseSegmentMeanWithNumSegmentsGrad(op, grad):
    209   """Gradient for SparseSegmentMeanWithNumSegments."""
    210   dim0 = array_ops.shape(op.inputs[0])[0]
    211   return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
    212                                             dim0), None, None, None)
    213 
    214 
    215 @ops.RegisterGradient("SparseSegmentSqrtN")
    216 def _SparseSegmentSqrtNGrad(op, grad):
    217   """Gradient for SparseSegmentSqrtN."""
    218   dim0 = array_ops.shape(op.inputs[0])[0]
    219   return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
    220                                               dim0), None, None)
    221 
    222 
    223 @ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
    224 def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad):
    225   """Gradient for SparseSegmentSqrtNWithNumSegments."""
    226   dim0 = array_ops.shape(op.inputs[0])[0]
    227   return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
    228                                               dim0), None, None, None)
    229 
    230 
    231 def _SegmentMinOrMaxGrad(op, grad):
    232   """ Gradient for SegmentMin and SegmentMax. """
    233   zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype)
    234   # Get the number of selected (minimum or maximum) elements in each segment.
    235   gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
    236   is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
    237   num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
    238                                       op.inputs[1])
    239   # Compute the gradient for each segment. The gradient for the ith segment is
    240   # divided evenly among the selected elements in that segment.
    241   weighted_grads = math_ops.div(grad, num_selected)
    242   gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])
    243   return array_ops.where(is_selected, gathered_grads, zeros), None
    244 
    245 
    246 @ops.RegisterGradient("SegmentMin")
    247 def _SegmentMinGrad(op, grad):
    248   """Gradient for SegmentMin."""
    249   return _SegmentMinOrMaxGrad(op, grad)
    250 
    251 
    252 @ops.RegisterGradient("SegmentMax")
    253 def _SegmentMaxGrad(op, grad):
    254   """Gradient for SegmentMax."""
    255   return _SegmentMinOrMaxGrad(op, grad)
    256 
    257 
    258 def _GatherDropNegatives(params, ids, zero_clipped_indices=None,
    259                          is_positive=None):
    260   """ Helper function for unsorted segment ops. Gathers params for
    261       positive segment ids and gathers 0 for inputs with negative segment id.
    262       Also returns the clipped indices and a boolean mask with the same shape
    263       as ids where a positive id is masked as true. With this, the latter two
    264       can be passed as arguments to this function to reuse them.
    265   """
    266   if zero_clipped_indices is None:
    267     zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids))
    268   gathered = array_ops.gather(params, zero_clipped_indices)
    269   if is_positive is None:
    270     is_positive = math_ops.greater_equal(ids, 0)
    271     # tf.where(condition, x, y) requires condition to have the same shape as x
    272     # and y.
    273     # TODO(philjd): remove this if tf.where supports broadcasting (#9284)
    274     for _ in range(gathered.shape.ndims - is_positive.shape.ndims):
    275       is_positive = array_ops.expand_dims(is_positive, -1)
    276     is_positive = (is_positive &
    277                    array_ops.ones_like(gathered, dtype=dtypes.bool))
    278   # replace gathered params of negative indices with 0
    279   zero_slice = array_ops.zeros_like(gathered)
    280   return (array_ops.where(is_positive, gathered, zero_slice),
    281           zero_clipped_indices, is_positive)
    282 
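        # A worked sketch of _GatherDropNegatives (illustrative): with
        # params = [[1., 2.], [3., 4.], [5., 6.]] and ids = [1, -1, 2], the
        # clipped indices are [1, 0, 2], the gather yields rows [p1, p0, p2],
        # and the is_positive mask zeroes the row that came from the negative
        # id, giving [[3., 4.], [0., 0.], [5., 6.]].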
    283 
    284 def _UnsortedSegmentMinOrMaxGrad(op, grad):
    285   """ Gradient for UnsortedSegmentMin and UnsortedSegmentMax. """
    286   # Get the number of selected (minimum or maximum) elements in each segment.
    287   gathered_outputs, zero_clipped_indices, is_positive = \
    288       _GatherDropNegatives(op.outputs[0], op.inputs[1])
    289   is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
    290   is_selected = math_ops.logical_and(is_selected, is_positive)
    291   num_selected = math_ops.unsorted_segment_sum(
    292       math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2])
    293   # Compute the gradient for each segment. The gradient for the ith segment is
    294   # divided evenly among the selected elements in that segment.
    295   weighted_grads = math_ops.div(grad, num_selected)
    296   gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
    297                                               zero_clipped_indices,
    298                                               is_positive)
    299   zeros = array_ops.zeros_like(gathered_grads)
    300   return array_ops.where(is_selected, gathered_grads, zeros), None, None
    301 
    302 
    303 @ops.RegisterGradient("UnsortedSegmentSum")
    304 def _UnsortedSegmentSumGrad(op, grad):
    305   """Gradient for UnsortedSegmentSum."""
    306   return _GatherDropNegatives(grad, op.inputs[1])[0], None, None
    307 
    308 
    309 @ops.RegisterGradient("UnsortedSegmentMax")
    310 def _UnsortedSegmentMaxGrad(op, grad):
    311   """ Gradient for UnsortedSegmentMax. """
    312   return _UnsortedSegmentMinOrMaxGrad(op, grad)
    313 
    314 
    315 @ops.RegisterGradient("UnsortedSegmentMin")
    316 def _UnsortedSegmentMinGrad(op, grad):
    317   """ Gradient for UnsortedSegmentMin. """
    318   return _UnsortedSegmentMinOrMaxGrad(op, grad)
    319 
    320 
    321 @ops.RegisterGradient("UnsortedSegmentProd")
    322 def _UnsortedSegmentProdGrad(op, grad):
    323   """ Gradient for UnsortedSegmentProd.
    324   The gradient can be expressed for each segment by dividing the segment's
    325   product by each element of the segment input tensor, but this approach can't
    326   deal with zeros in the input.
    327   Unlike reduce_prod we can't use cumsum here as individual segments may have
    328   a different number of elements. Therefore we consider three cases:
    329   1) A segment input contains no zeros and we can safely divide by the input
    330      tensor.
    331   2) A segment contains exactly one zero. Then the gradient of each input of
    332      the segment is zero except for the 0-input, there the gradient is
    333      the product of the remaining segment entries.
    334   3) A segment contains at least two zeros. The gradient is zero for all
    335      segment inputs.
    336   """
    337   # Note that unsorted_segment_sum will filter out the negative indices,
    338   # so we don't need to do a logical_and with is_positive here
    339   is_zero = math_ops.equal(op.inputs[0], 0)
    340   num_zeros = gen_math_ops.unsorted_segment_sum(
    341       math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2])
    342   # handle case 3 and set the gradient to 0 for segments with more than one
    343   # 0 as input
    344   grad = array_ops.where(math_ops.greater(num_zeros, 1),
    345                          array_ops.zeros_like(grad), grad)
    346   # replace all zeros with ones and compute the unsorted_segment_prod
    347   non_zero_data = array_ops.where(is_zero, array_ops.ones_like(op.inputs[0]),
    348                                   op.inputs[0])
    349   non_zero_prod = gen_math_ops.unsorted_segment_prod(
    350       non_zero_data, op.inputs[1], op.inputs[2])
    351   # clip the indices for gather to be positive
    352   zero_clipped_indices = math_ops.maximum(op.inputs[1],
    353                                           array_ops.zeros_like(op.inputs[1]))
    354   gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices)
    355   gathered_non_zero_prod = array_ops.gather(non_zero_prod,
    356                                             zero_clipped_indices)
    357   prod_divided_by_el = gathered_prod / op.inputs[0]  # May contain nan/inf.
    358   # Now fetch the individual results for segments containing 0 and those that
    359   # don't. is_zero will also fetch results for entries with negative index
    360   # but the following gather_drop_negatives sets the corresponding entry in
    361   # grad to 0 for these
    362   partial_derivative = array_ops.where(is_zero, gathered_non_zero_prod,
    363                                        prod_divided_by_el)
    364   gathered_grad = _GatherDropNegatives(grad, op.inputs[1],
    365                                        zero_clipped_indices)[0]
    366   return gathered_grad * partial_derivative, None, None
    367 
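        # A worked sketch of the three cases above for a single segment, with an
        # upstream gradient of 1 (illustrative numbers):
        #   no zeros:  x = [2., 3., 5.] -> grads [15., 10., 6.]  (product / entry)
        #   one zero:  x = [2., 0., 5.] -> grads [0., 10., 0.]   (product of the
        #              remaining entries appears only at the zero position)
        #   two zeros: x = [0., 3., 0.] -> grads [0., 0., 0.]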
    368 
    369 @ops.RegisterGradient("Abs")
    370 def _AbsGrad(op, grad):
    371   x = op.inputs[0]
    372   return grad * math_ops.sign(x)
    373 
    374 
    375 @ops.RegisterGradient("Neg")
    376 def _NegGrad(_, grad):
    377   """Returns -grad."""
    378   return -grad
    379 
    380 
    381 @ops.RegisterGradient("Inv")
    382 def _InvGrad(op, grad):
    383   """Returns -grad * (1 / x^2)."""
    384   y = op.outputs[0]  # y = 1 / x
    385   # pylint: disable=protected-access
    386   return gen_math_ops._reciprocal_grad(y, grad)
    387 
    388 
    389 @ops.RegisterGradient("Reciprocal")
    390 def _ReciprocalGrad(op, grad):
    391   """Returns -grad * (1 / x^2)."""
    392   y = op.outputs[0]  # y = 1 / x
    393   # pylint: disable=protected-access
    394   return gen_math_ops._reciprocal_grad(y, grad)
    395 
    396 
    397 @ops.RegisterGradient("InvGrad")
    398 def _InvGradGrad(op, grad):
    399   b = op.inputs[1]
    400   # op.outputs[0]: y = -b * conj(a)^2
    401   with ops.control_dependencies([grad]):
    402     ca = math_ops.conj(op.inputs[0])
    403     cg = math_ops.conj(grad)
    404     # pylint: disable=protected-access
    405     return cg * -2.0 * b * ca, gen_math_ops._reciprocal_grad(ca, grad)
    406 
    407 
    408 @ops.RegisterGradient("ReciprocalGrad")
    409 def _ReciprocalGradGrad(op, grad):
    410   b = op.inputs[1]
    411   # op.outputs[0]: y = -b * conj(a)^2
    412   with ops.control_dependencies([grad]):
    413     ca = math_ops.conj(op.inputs[0])
    414     cg = math_ops.conj(grad)
    415     # pylint: disable=protected-access
    416     return cg * -2.0 * b * ca, gen_math_ops._reciprocal_grad(ca, grad)
    417 
    418 
    419 @ops.RegisterGradient("Square")
    420 def _SquareGrad(op, grad):
    421   x = op.inputs[0]
    422   # Added control dependencies to prevent 2*x from being computed too early.
    423   with ops.control_dependencies([grad]):
    424     x = math_ops.conj(x)
    425     return math_ops.multiply(grad, math_ops.multiply(x, 2.0))
    426 
    427 
    428 @ops.RegisterGradient("Sqrt")
    429 def _SqrtGrad(op, grad):
    430   y = op.outputs[0]  # y = x^(1/2)
    431   # pylint: disable=protected-access
    432   return gen_math_ops._sqrt_grad(y, grad)
    433   # pylint: enable=protected-access
    434 
    435 
    436 @ops.RegisterGradient("SqrtGrad")
    437 def _SqrtGradGrad(op, grad):
    438   a = op.inputs[0]
    439   y = op.outputs[0]  # y = 0.5 * b / conj(a)
    440   with ops.control_dependencies([grad]):
    441     ga = grad / a
    442     return -math_ops.conj(ga) * y, 0.5 * ga
    443 
    444 
    445 @ops.RegisterGradient("Rsqrt")
    446 def _RsqrtGrad(op, grad):
    447   """Returns -0.5 * grad * conj(y)^3."""
    448   y = op.outputs[0]  # y = x^(-1/2)
    449   # pylint: disable=protected-access
    450   return gen_math_ops._rsqrt_grad(y, grad)
    451   # pylint: enable=protected-access
    452 
    453 
    454 @ops.RegisterGradient("RsqrtGrad")
    455 def _RsqrtGradGrad(op, grad):
    456   """Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3."""
    457   a = op.inputs[0]  # a = x^{-1/2}
    458   b = op.inputs[1]  # backprop gradient for a
    459   with ops.control_dependencies([grad]):
    460     ca = math_ops.conj(a)
    461     cg = math_ops.conj(grad)
    462     grad_a = -1.5 * cg * b * math_ops.square(ca)
    463     # pylint: disable=protected-access
    464     grad_b = gen_math_ops._rsqrt_grad(ca, grad)
    465     return grad_a, grad_b
    466 
    467 
    468 @ops.RegisterGradient("Exp")
    469 def _ExpGrad(op, grad):
    470   """Returns grad * exp(x)."""
    471   y = op.outputs[0]  # y = e^x
    472   with ops.control_dependencies([grad]):
    473     y = math_ops.conj(y)
    474     return grad * y
    475 
    476 
    477 @ops.RegisterGradient("Expm1")
    478 def _Expm1Grad(op, grad):
    479   """Returns grad * exp(x)."""
    480   x = op.inputs[0]
    481   with ops.control_dependencies([grad]):
    482     x = math_ops.conj(x)
    483     y = math_ops.exp(x)
    484     return grad * y
    485 
    486 
    487 @ops.RegisterGradient("Log")
    488 def _LogGrad(op, grad):
    489   """Returns grad * (1/x)."""
    490   x = op.inputs[0]
    491   with ops.control_dependencies([grad]):
    492     x = math_ops.conj(x)
    493     return grad * math_ops.reciprocal(x)
    494 
    495 
    496 @ops.RegisterGradient("Log1p")
    497 def _Log1pGrad(op, grad):
    498   """Returns grad * (1/(1 + x))."""
    499   x = op.inputs[0]
    500   with ops.control_dependencies([grad]):
    501     x = math_ops.conj(x)
    502     return grad * math_ops.reciprocal(1 + x)
    503 
    504 
    505 @ops.RegisterGradient("Sinh")
    506 def _SinhGrad(op, grad):
    507   """Returns grad * cosh(x)."""
    508   x = op.inputs[0]
    509   with ops.control_dependencies([grad]):
    510     x = math_ops.conj(x)
    511     return grad * math_ops.cosh(x)
    512 
    513 
    514 @ops.RegisterGradient("Cosh")
    515 def _CoshGrad(op, grad):
    516   """Returns grad * sinh(x)."""
    517   x = op.inputs[0]
    518   with ops.control_dependencies([grad]):
    519     x = math_ops.conj(x)
    520     return grad * math_ops.sinh(x)
    521 
    522 
    523 @ops.RegisterGradient("Tanh")
    524 def _TanhGrad(op, grad):
    525   """Returns grad * (1 - tanh(x) * tanh(x))."""
    526   y = op.outputs[0]  # y = tanh(x)
    527   with ops.control_dependencies([grad]):
    528     y = math_ops.conj(y)
    529     # pylint: disable=protected-access
    530     return gen_math_ops._tanh_grad(y, grad)
    531 
    532 
    533 @ops.RegisterGradient("Asinh")
    534 def _AsinhGrad(op, grad):
    535   """Returns grad * 1/cosh(y)."""
    536   y = op.outputs[0]
    537   with ops.control_dependencies([grad]):
    538     y = math_ops.conj(y)
    539     return grad / math_ops.cosh(y)
    540 
    541 
    542 @ops.RegisterGradient("Acosh")
    543 def _AcoshGrad(op, grad):
    544   """Returns grad * 1/sinh(y)."""
    545   y = op.outputs[0]
    546   with ops.control_dependencies([grad]):
    547     y = math_ops.conj(y)
    548     return grad / math_ops.sinh(y)
    549 
    550 
    551 @ops.RegisterGradient("Atanh")
    552 def _AtanhGrad(op, grad):
    553   """Returns grad * 1/ (1 - x^2)."""
    554   x = op.inputs[0]
    555   with ops.control_dependencies([grad]):
    556     x = math_ops.conj(x)
    557     x2 = math_ops.square(x)
    558     one = constant_op.constant(1, dtype=grad.dtype)
    559     inv = math_ops.reciprocal(math_ops.subtract(one, x2))
    560     return grad * inv
    561 
    562 
    563 @ops.RegisterGradient("TanhGrad")
    564 def _TanhGradGrad(op, grad):
    565   with ops.control_dependencies([grad]):
    566     a = math_ops.conj(op.inputs[0])
    567     b = math_ops.conj(op.inputs[1])
    568     # pylint: disable=protected-access
    569     return grad * -2.0 * b * a, gen_math_ops._tanh_grad(a, grad)
    570 
    571 
    572 @ops.RegisterGradient("Erf")
    573 def _ErfGrad(op, grad):
    574   """Returns grad * 2/sqrt(pi) * exp(-x**2)."""
    575   x = op.inputs[0]
    576   two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
    577   with ops.control_dependencies([grad]):
    578     x = math_ops.conj(x)
    579     return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x))
    580 
    581 
    582 @ops.RegisterGradient("Erfc")
    583 def _ErfcGrad(op, grad):
    584   """Returns -grad * 2/sqrt(pi) * exp(-x**2)."""
    585   x = op.inputs[0]
    586   minus_two_over_root_pi = constant_op.constant(
    587       -2 / np.sqrt(np.pi), dtype=grad.dtype)
    588   with ops.control_dependencies([grad]):
    589     x = math_ops.conj(x)
    590     return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x))
    591 
    592 
    593 @ops.RegisterGradient("Lgamma")
    594 def _LgammaGrad(op, grad):
    595   """Returns grad * digamma(x)."""
    596   x = op.inputs[0]
    597   with ops.control_dependencies([grad]):
    598     x = math_ops.conj(x)
    599     return grad * math_ops.digamma(x)
    600 
    601 
    602 @ops.RegisterGradient("Digamma")
    603 def _DigammaGrad(op, grad):
    604   """Compute gradient of the digamma function with respect to its argument."""
    605   x = op.inputs[0]
    606   with ops.control_dependencies([grad]):
    607     x = math_ops.conj(x)
    608     return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
    609 
    610 
    611 @ops.RegisterGradient("Igamma")
    612 def _IgammaGrad(op, grad):
    613   """Returns gradient of igamma(a, x) with respect to x."""
    614   # TODO(ebrevdo): Perhaps add the derivative w.r.t. a
    615   a = op.inputs[0]
    616   x = op.inputs[1]
    617   sa = array_ops.shape(a)
    618   sx = array_ops.shape(x)
    619   # pylint: disable=protected-access
    620   unused_ra, rx = gen_array_ops._broadcast_gradient_args(sa, sx)
    621   # pylint: enable=protected-access
    622 
    623   # Perform operations in log space before summing, because Gamma(a)
    624   # and Gamma'(a) can grow large.
    625   partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a))
    626   # TODO(b/36815900): Mark None return values as NotImplemented
    627   return (None, array_ops.reshape(
    628       math_ops.reduce_sum(partial_x * grad, rx), sx))
    629 
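        # The partial derivative used above follows from the definition of the
        # regularized lower incomplete gamma function:
        #   igamma(a, x) = (1 / Gamma(a)) * integral_0^x t^(a-1) * e^(-t) dt
        #   d/dx igamma(a, x) = x^(a-1) * e^(-x) / Gamma(a)
        #                     = exp(-x + (a - 1) * log(x) - lgamma(a))
        # which is why the computation is carried out in log space before
        # exponentiating.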
    630 
    631 @ops.RegisterGradient("Igammac")
    632 def _IgammacGrad(op, grad):
    633   """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. x."""
    634   _, igamma_grad_x = _IgammaGrad(op, grad)
    635   return None, -igamma_grad_x
    636 
    637 
    638 @ops.RegisterGradient("Betainc")
    639 def _BetaincGrad(op, grad):
    640   """Returns gradient of betainc(a, b, x) with respect to x."""
    641   # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b
    642   a, b, x = op.inputs
    643 
    644   # two cases: x is a scalar and a/b are same-shaped tensors, or vice
    645   # versa; so it's sufficient to check against shape(a).
    646   sa = array_ops.shape(a)
    647   sx = array_ops.shape(x)
    648   # pylint: disable=protected-access
    649   _, rx = gen_array_ops._broadcast_gradient_args(sa, sx)
    650   # pylint: enable=protected-access
    651 
    652   # Perform operations in log space before summing, because terms
    653   # can grow large.
    654   log_beta = (
    655       gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) -
    656       gen_math_ops.lgamma(a + b))
    657   partial_x = math_ops.exp((b - 1) * math_ops.log(1 - x) +
    658                            (a - 1) * math_ops.log(x) - log_beta)
    659 
    660   # TODO(b/36815900): Mark None return values as NotImplemented
    661   return (
    662       None,  # da
    663       None,  # db
    664       array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
    665 
    666 
    667 @ops.RegisterGradient("Zeta")
    668 def _ZetaGrad(op, grad):
    669   """Returns gradient of zeta(x, q) with respect to x and q."""
    670   # TODO(tillahoffmann): Add derivative with respect to x
    671   x = op.inputs[0]
    672   q = op.inputs[1]
    673   # Broadcast gradients
    674   sx = array_ops.shape(x)
    675   sq = array_ops.shape(q)
    676   # pylint: disable=protected-access
    677   unused_rx, rq = gen_array_ops._broadcast_gradient_args(sx, sq)
    678   # pylint: enable=protected-access
    679   # Evaluate gradient
    680   with ops.control_dependencies([grad]):
    681     x = math_ops.conj(x)
    682     q = math_ops.conj(q)
    683     partial_q = -x * math_ops.zeta(x + 1, q)
    684     # TODO(b/36815900): Mark None return values as NotImplemented
    685     return (None,
    686             array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq))
    687 
    688 
    689 @ops.RegisterGradient("Polygamma")
    690 def _PolygammaGrad(op, grad):
    691   """Returns gradient of psi(n, x) with respect to n and x."""
    692   # TODO(tillahoffmann): Add derivative with respect to n
    693   n = op.inputs[0]
    694   x = op.inputs[1]
    695   # Broadcast gradients
    696   sn = array_ops.shape(n)
    697   sx = array_ops.shape(x)
    698   # pylint: disable=protected-access
    699   unused_rn, rx = gen_array_ops._broadcast_gradient_args(sn, sx)
    700   # pylint: enable=protected-access
    701   # Evaluate gradient
    702   with ops.control_dependencies([grad]):
    703     n = math_ops.conj(n)
    704     x = math_ops.conj(x)
    705     partial_x = math_ops.polygamma(n + 1, x)
    706     # TODO(b/36815900): Mark None return values as NotImplemented
    707     return (None,
    708             array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
    709 
    710 
    711 @ops.RegisterGradient("Sigmoid")
    712 def _SigmoidGrad(op, grad):
    713   """Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
    714   y = op.outputs[0]  # y = sigmoid(x)
    715   with ops.control_dependencies([grad]):
    716     y = math_ops.conj(y)
    717     # pylint: disable=protected-access
    718     return gen_math_ops._sigmoid_grad(y, grad)
    719 
    720 
    721 @ops.RegisterGradient("SigmoidGrad")
    722 def _SigmoidGradGrad(op, grad):
    723   with ops.control_dependencies([grad]):
    724     a = math_ops.conj(op.inputs[0])
    725     b = math_ops.conj(op.inputs[1])
    726     gb = grad * b
    727     # pylint: disable=protected-access
    728     return gb - 2.0 * gb * a, gen_math_ops._sigmoid_grad(a, grad)
    729 
    730 
    731 @ops.RegisterGradient("Sign")
    732 def _SignGrad(op, _):
    733   """Returns 0."""
    734   x = op.inputs[0]
    735   return array_ops.zeros(array_ops.shape(x), dtype=x.dtype)
    736 
    737 
    738 @ops.RegisterGradient("Sin")
    739 def _SinGrad(op, grad):
    740   """Returns grad * cos(x)."""
    741   x = op.inputs[0]
    742   with ops.control_dependencies([grad]):
    743     x = math_ops.conj(x)
    744     return grad * math_ops.cos(x)
    745 
    746 
    747 @ops.RegisterGradient("Cos")
    748 def _CosGrad(op, grad):
    749   """Returns grad * -sin(x)."""
    750   x = op.inputs[0]
    751   with ops.control_dependencies([grad]):
    752     x = math_ops.conj(x)
    753     return -grad * math_ops.sin(x)
    754 
    755 
    756 @ops.RegisterGradient("Tan")
    757 def _TanGrad(op, grad):
    758   """Returns grad * 1/sec^2(x)."""
    759   x = op.inputs[0]
    760   with ops.control_dependencies([grad]):
    761     x = math_ops.conj(x)
    762     secx = math_ops.reciprocal(math_ops.cos(x))
    763     secx2 = math_ops.square(secx)
    764     return grad * secx2
    765 
    766 
    767 @ops.RegisterGradient("Asin")
    768 def _AsinGrad(op, grad):
    769   """Returns grad * 1/sqrt(1-x^2)."""
    770   x = op.inputs[0]
    771   with ops.control_dependencies([grad]):
    772     x = math_ops.conj(x)
    773     x2 = math_ops.square(x)
    774     one = constant_op.constant(1, dtype=grad.dtype)
    775     den = math_ops.sqrt(math_ops.subtract(one, x2))
    776     inv = math_ops.reciprocal(den)
    777     return grad * inv
    778 
    779 
    780 @ops.RegisterGradient("Acos")
    781 def _AcosGrad(op, grad):
    782   """Returns grad * -1/sqrt(1-x^2)."""
    783   x = op.inputs[0]
    784   with ops.control_dependencies([grad]):
    785     x = math_ops.conj(x)
    786     x2 = math_ops.square(x)
    787     one = constant_op.constant(1, dtype=grad.dtype)
    788     den = math_ops.sqrt(math_ops.subtract(one, x2))
    789     inv = math_ops.reciprocal(den)
    790     return -grad * inv
    791 
    792 
    793 @ops.RegisterGradient("Atan")
    794 def _AtanGrad(op, grad):
    795   """Returns grad * 1/ (1 + x^2)."""
    796   x = op.inputs[0]
    797   with ops.control_dependencies([grad]):
    798     x = math_ops.conj(x)
    799     x2 = math_ops.square(x)
    800     one = constant_op.constant(1, dtype=grad.dtype)
    801     inv = math_ops.reciprocal(math_ops.add(one, x2))
    802     return grad * inv
    803 
    804 
    805 @ops.RegisterGradient("Atan2")
    806 def _Atan2Grad(op, grad):
    807   """Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2)."""
    808   y = op.inputs[0]
    809   x = op.inputs[1]
    810   with ops.control_dependencies([grad]):
    811     grad_inv = grad / (math_ops.square(x) + math_ops.square(y))
    812     return x * grad_inv, -y * grad_inv
    813 
    814 
    815 @ops.RegisterGradient("AddN")
    816 def _AddNGrad(op, grad):
    817   """Copies the gradient to all inputs."""
    818   # Not broadcasting.
    819   return [grad] * len(op.inputs)
    820 
    821 
    822 def _ShapesFullySpecifiedAndEqual(x, y, grad):
    823   # pylint: disable=protected-access
    824   x_shape = x._shape_tuple()
    825   y_shape = y._shape_tuple()
    826   grad_shape = grad._shape_tuple()
    827   # pylint: enable=protected-access
    828   return (x_shape == y_shape and x_shape == grad_shape and
    829           x_shape is not None and None not in x_shape)
    830 
    831 
    832 @ops.RegisterGradient("Add")
    833 def _AddGrad(op, grad):
    834   """Gradient for Add."""
    835   x = op.inputs[0]
    836   y = op.inputs[1]
    837   if (isinstance(grad, ops.Tensor) and
    838       _ShapesFullySpecifiedAndEqual(x, y, grad)):
    839     return grad, grad
    840   sx = array_ops.shape(x)
    841   sy = array_ops.shape(y)
    842   # pylint: disable=protected-access
    843   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
    844   # pylint: enable=protected-access
    845   return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
    846           array_ops.reshape(math_ops.reduce_sum(grad, ry), sy))
    847 
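        # A worked sketch of the broadcasting reduction used above (and in the
        # other binary-op gradients), assuming the public `tf` API: adding a
        # (3,)-shaped y to a (2, 3)-shaped x broadcasts y over axis 0, so y's
        # gradient is the upstream gradient summed over that axis:
        #
        #   x = tf.constant([[1., 2., 3.], [4., 5., 6.]])
        #   y = tf.constant([10., 20., 30.])
        #   dx, dy = tf.gradients(tf.reduce_sum(x + y), [x, y])
        #   # dx -> ones of shape (2, 3); dy -> [2., 2., 2.]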
    848 
    849 @ops.RegisterGradient("Sub")
    850 def _SubGrad(op, grad):
    851   """Gradient for Sub."""
    852   x = op.inputs[0]
    853   y = op.inputs[1]
    854   if (isinstance(grad, ops.Tensor) and
    855       _ShapesFullySpecifiedAndEqual(x, y, grad)):
    856     return grad, -grad
    857   sx = array_ops.shape(x)
    858   sy = array_ops.shape(y)
    859   # pylint: disable=protected-access
    860   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
    861   # pylint: enable=protected-access
    862   return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
    863           array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))
    864 
    865 
    866 @ops.RegisterGradient("Mul")
    867 def _MulGrad(op, grad):
    868   """The gradient of scalar multiplication."""
    869   x = op.inputs[0]
    870   y = op.inputs[1]
    871   # pylint: disable=protected-access
    872   if (isinstance(grad, ops.Tensor) and
    873       _ShapesFullySpecifiedAndEqual(x, y, grad) and
    874       grad.dtype in (dtypes.int32, dtypes.float32)):
    875     return gen_math_ops._mul(grad, y), gen_math_ops._mul(grad, x)
    876   assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
    877   sx = array_ops.shape(x)
    878   sy = array_ops.shape(y)
    879   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
    880   x = math_ops.conj(x)
    881   y = math_ops.conj(y)
    882   return (array_ops.reshape(
    883       math_ops.reduce_sum(gen_math_ops._mul(grad, y), rx), sx),
    884           array_ops.reshape(
    885               math_ops.reduce_sum(gen_math_ops._mul(x, grad), ry), sy))
    886   # pylint: enable=protected-access
    887 
    888 
    889 @ops.RegisterGradient("Div")
    890 def _DivGrad(op, grad):
    891   """The gradient for the Div operator."""
    892   x = op.inputs[0]
    893   y = op.inputs[1]
    894   sx = array_ops.shape(x)
    895   sy = array_ops.shape(y)
    896   # pylint: disable=protected-access
    897   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
    898   # pylint: enable=protected-access
    899   x = math_ops.conj(x)
    900   y = math_ops.conj(y)
    901   return (array_ops.reshape(math_ops.reduce_sum(math_ops.div(grad, y), rx), sx),
    902           array_ops.reshape(
    903               math_ops.reduce_sum(grad * math_ops.div(math_ops.div(-x, y), y),
    904                                   ry), sy))
    905 
    906 
    907 @ops.RegisterGradient("FloorDiv")
    908 def _FloorDivGrad(_, unused_grad):
    909   """The gradient for the FloorDiv operator."""
    910   return None, None
    911 
    912 
    913 @ops.RegisterGradient("FloorMod")
    914 def _FloorModGrad(op, grad):
    915   """Returns grad * (1, -floor(x/y))."""
    916   x = math_ops.conj(op.inputs[0])
    917   y = math_ops.conj(op.inputs[1])
    918 
    919   sx = array_ops.shape(x)
    920   sy = array_ops.shape(y)
    921   # pylint: disable=protected-access
    922   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
    923   # pylint: enable=protected-access
    924   floor_xy = math_ops.floor_div(x, y)
    925   gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
    926   gy = array_ops.reshape(
    927       math_ops.reduce_sum(grad * math_ops.negative(floor_xy), ry), sy)
    928   return gx, gy
    929 
    930 
    931 @ops.RegisterGradient("TruncateDiv")
    932 def _TruncateDivGrad(_, unused_grad):
    933   return None, None
    934 
    935 
    936 @ops.RegisterGradient("RealDiv")
    937 def _RealDivGrad(op, grad):
    938   """RealDiv op gradient."""
    939   x = op.inputs[0]
    940   y = op.inputs[1]
    941   sx = array_ops.shape(x)
    942   sy = array_ops.shape(y)
    943   # pylint: disable=protected-access
    944   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
    945   # pylint: enable=protected-access
    946   x = math_ops.conj(x)
    947   y = math_ops.conj(y)
    948   return (array_ops.reshape(
    949       math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx),
    950           array_ops.reshape(
    951               math_ops.reduce_sum(
    952                   grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
    953 
    954 
    955 @ops.RegisterGradient("Pow")
    956 def _PowGrad(op, grad):
    957   """Returns grad * (y*x^(y-1), z*log(x))."""
    958   x = op.inputs[0]
    959   y = op.inputs[1]
    960   z = op.outputs[0]
    961   sx = array_ops.shape(x)
    962   sy = array_ops.shape(y)
    963   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
    964   x = math_ops.conj(x)
    965   y = math_ops.conj(y)
    966   z = math_ops.conj(z)
    967   gx = array_ops.reshape(
    968       math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
    969   # Avoid false singularity at x = 0
    970   if x.dtype.is_complex:
    971     # real(x) < 0 is fine for the complex case
    972     log_x = array_ops.where(
    973         math_ops.not_equal(x, 0), math_ops.log(x), array_ops.zeros_like(x))
    974   else:
    975     # There's no sensible real value to return if x < 0, so return 0
    976     log_x = array_ops.where(x > 0, math_ops.log(x), array_ops.zeros_like(x))
    977   gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
    978   return gx, gy
    979 
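        # A worked sketch of the "false singularity" guard above (illustrative):
        # for real x = 0. and y = 2., the naive dz/dy = z * log(x) would be
        # 0 * -inf = nan, so log(x) is masked to 0 wherever x <= 0 and the
        # y-gradient becomes 0. The x-gradient is unaffected:
        # dz/dx = y * x**(y - 1) = 2. * 0. = 0.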
    980 
    981 def _MaximumMinimumGrad(op, grad, selector_op):
    982   """Factor out the code for the gradient of Maximum or Minimum."""
    983   x = op.inputs[0]
    984   y = op.inputs[1]
    985   gdtype = grad.dtype
    986   sx = array_ops.shape(x)
    987   sy = array_ops.shape(y)
    988   gradshape = array_ops.shape(grad)
    989   zeros = array_ops.zeros(gradshape, gdtype)
    990   xmask = selector_op(x, y)
    991   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
    992   xgrad = array_ops.where(xmask, grad, zeros)
    993   ygrad = array_ops.where(xmask, zeros, grad)
    994   gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
    995   gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
    996   return (gx, gy)
    997 
    998 
    999 @ops.RegisterGradient("Maximum")
   1000 def _MaximumGrad(op, grad):
   1001   """Returns grad*(x > y, x <= y) with type of grad."""
   1002   return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)
   1003 
   1004 
   1005 @ops.RegisterGradient("Minimum")
   1006 def _MinimumGrad(op, grad):
   1007   """Returns grad*(x < y, x >= y) with type of grad."""
   1008   return _MaximumMinimumGrad(op, grad, math_ops.less_equal)
   1009 
   1010 
   1011 @ops.RegisterGradient("SquaredDifference")
   1012 def _SquaredDifferenceGrad(op, grad):
   1013   """Returns the gradient for (x-y)^2."""
   1014   x = op.inputs[0]
   1015   y = op.inputs[1]
   1016   sx = array_ops.shape(x)
   1017   sy = array_ops.shape(y)
   1018   # pylint: disable=protected-access
   1019   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
   1020   # pylint: enable=protected-access
   1021   with ops.control_dependencies([grad]):
   1022     # The parens ensure that if grad is IndexedSlices, it'll get multiplied by
   1023     # Tensor (not a number like 2.0) which causes it to convert to Tensor.
   1024     x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
   1025   return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx),
   1026           -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy))
   1027 
   1028 
   1029 # Logical operations have no gradients.
   1030 ops.NotDifferentiable("Less")
   1031 ops.NotDifferentiable("LessEqual")
   1032 ops.NotDifferentiable("Greater")
   1033 ops.NotDifferentiable("GreaterEqual")
   1034 ops.NotDifferentiable("Equal")
   1035 ops.NotDifferentiable("ApproximateEqual")
   1036 ops.NotDifferentiable("NotEqual")
   1037 ops.NotDifferentiable("LogicalAnd")
   1038 ops.NotDifferentiable("LogicalOr")
   1039 ops.NotDifferentiable("LogicalNot")
   1040 
   1041 
   1042 @ops.RegisterGradient("Select")
   1043 def _SelectGrad(op, grad):
   1044   c = op.inputs[0]
   1045   x = op.inputs[1]
   1046   zeros = array_ops.zeros_like(x)
   1047   return (None, array_ops.where(c, grad, zeros), array_ops.where(
   1048       c, zeros, grad))
   1049 
   1050 
   1051 @ops.RegisterGradient("MatMul")
   1052 def _MatMulGrad(op, grad):
   1053   """Gradient for MatMul."""
   1054 
   1055   t_a = op.get_attr("transpose_a")
   1056   t_b = op.get_attr("transpose_b")
   1057   a = math_ops.conj(op.inputs[0])
   1058   b = math_ops.conj(op.inputs[1])
   1059   # pylint: disable=protected-access
   1060   if not t_a and not t_b:
   1061     grad_a = gen_math_ops._mat_mul(grad, b, transpose_b=True)
   1062     grad_b = gen_math_ops._mat_mul(a, grad, transpose_a=True)
   1063   elif not t_a and t_b:
   1064     grad_a = gen_math_ops._mat_mul(grad, b)
   1065     grad_b = gen_math_ops._mat_mul(grad, a, transpose_a=True)
   1066   elif t_a and not t_b:
   1067     grad_a = gen_math_ops._mat_mul(b, grad, transpose_b=True)
   1068     grad_b = gen_math_ops._mat_mul(a, grad)
   1069   elif t_a and t_b:
   1070     grad_a = gen_math_ops._mat_mul(b, grad, transpose_a=True, transpose_b=True)
   1071     grad_b = gen_math_ops._mat_mul(grad, a, transpose_a=True, transpose_b=True)
   1072   # pylint: enable=protected-access
   1073   return grad_a, grad_b
   1074 
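        # A worked sketch of the first (no-transpose) branch above: for
        # C = A @ B with A of shape (m, k) and B of shape (k, n), the upstream
        # gradient dC has shape (m, n) and
        #   dA = dC @ B^T   (shape (m, k))
        #   dB = A^T @ dC   (shape (k, n))
        # The remaining branches are the same identities with the transposes
        # folded into the MatMul attributes instead of materialized.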
   1075 
   1076 @ops.RegisterGradient("SparseMatMul")
   1077 def _SparseMatMulGrad(op, grad):
   1078   """Gradient for SparseMatMul."""
   1079 
   1080   t_a = op.get_attr("transpose_a")
   1081   t_b = op.get_attr("transpose_b")
   1082   is_sparse = {
   1083       op.inputs[0]: op.get_attr("a_is_sparse"),
   1084       op.inputs[1]: op.get_attr("b_is_sparse"),
   1085       # Use heuristic to figure out if grad might be sparse
   1086       grad: context.in_graph_mode() and (grad.op.type == "ReluGrad")
   1087   }
   1088 
   1089   def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
   1090     """Helper function to create SparseMatMul op."""
   1091 
   1092     assert t1 in is_sparse and t2 in is_sparse
   1093     t1_sparse = is_sparse[t1]
   1094     t2_sparse = is_sparse[t2]
   1095     if transpose_b:
   1096       t2 = array_ops.transpose(t2)
   1097       transpose_b = False
   1098     prod = math_ops.matmul(
   1099         t1,
   1100         t2,
   1101         transpose_a=transpose_a,
   1102         transpose_b=transpose_b,
   1103         a_is_sparse=t1_sparse,
   1104         b_is_sparse=t2_sparse)
   1105     if prod.dtype != out_dtype:
   1106       prod = math_ops.cast(prod, out_dtype)
   1107     return prod
   1108 
   1109   dtype_a = op.inputs[0].dtype
   1110   dtype_b = op.inputs[1].dtype
   1111   if not t_a and not t_b:
   1112     return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True),
   1113             _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True))
   1114   elif not t_a and t_b:
   1115     return (_SparseMatMul(grad, op.inputs[1], dtype_a),
   1116             _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True))
   1117   elif t_a and not t_b:
   1118     return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True),
   1119             _SparseMatMul(op.inputs[0], grad, dtype_b))
   1120   elif t_a and t_b:
   1121     return (_SparseMatMul(
   1122         op.inputs[1], grad, dtype_a, transpose_a=True, transpose_b=True),
   1123             _SparseMatMul(
   1124                 grad, op.inputs[0], dtype_b, transpose_a=True,
   1125                 transpose_b=True))
   1126 
   1127 
   1128 @ops.RegisterGradient("Floor")
   1129 def _FloorGrad(_, unused_grad):
   1130   return [None]
   1131 
   1132 
   1133 @ops.RegisterGradient("Ceil")
   1134 def _CeilGrad(_, unused_grad):
   1135   return [None]
   1136 
   1137 
   1138 @ops.RegisterGradient("Round")
   1139 def _RoundGrad(_, unused_grad):
   1140   return [None]
   1141 
   1142 
   1143 @ops.RegisterGradient("Rint")
   1144 def _RintGrad(_, unused_grad):
   1145   # the gradient of Rint is zero
   1146   return [None]
   1147 
   1148 
   1149 @ops.RegisterGradient("BatchMatMul")
   1150 def _BatchMatMul(op, grad):
   1151   """Returns the gradient of x and y given the gradient of x * y."""
   1152   x = op.inputs[0]
   1153   y = op.inputs[1]
   1154   adj_x = op.get_attr("adj_x")
   1155   adj_y = op.get_attr("adj_y")
   1156 
   1157   if not adj_x:
   1158     if not adj_y:
   1159       grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
   1160       grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
   1161     else:
   1162       grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
   1163       grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
   1164   else:
   1165     if not adj_y:
   1166       grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
   1167       grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
   1168     else:
   1169       grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
   1170       grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
   1171 
   1172   return grad_x, grad_y
   1173 
   1174 
   1175 ops.NotDifferentiable("Range")
   1176 ops.NotDifferentiable("LinSpace")
   1177 
   1178 
   1179 @ops.RegisterGradient("Complex")
   1180 def _ComplexGrad(op, grad):
   1181   """Returns the real and imaginary components of 'grad', respectively."""
   1182   x = op.inputs[0]
   1183   y = op.inputs[1]
   1184   sx = array_ops.shape(x)
   1185   sy = array_ops.shape(y)
   1186   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
   1187   return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx),
   1188           array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy))
   1189 
   1190 
   1191 @ops.RegisterGradient("Real")
   1192 def _RealGrad(_, grad):
   1193   """Returns 'grad' as the real part and set the imaginary part 0."""
   1194   zero = constant_op.constant(0, dtype=grad.dtype)
   1195   return math_ops.complex(grad, zero)
   1196 
   1197 
   1198 @ops.RegisterGradient("Imag")
   1199 def _ImagGrad(_, grad):
   1200   """Returns 'grad' as the imaginary part and set the real part 0."""
   1201   zero = constant_op.constant(0, dtype=grad.dtype)
   1202   return math_ops.complex(zero, grad)
   1203 
   1204 
   1205 @ops.RegisterGradient("Angle")
   1206 def _AngleGrad(op, grad):
   1207   """Returns -grad / (Im(x) + iRe(x))"""
   1208   x = op.inputs[0]
   1209   with ops.control_dependencies([grad]):
   1210     re = math_ops.real(x)
   1211     im = math_ops.imag(x)
   1212     z = math_ops.reciprocal(math_ops.complex(im, re))
   1213     zero = constant_op.constant(0, dtype=grad.dtype)
   1214     complex_grad = math_ops.complex(grad, zero)
   1215     return -complex_grad * z
   1216 
   1217 
   1218 @ops.RegisterGradient("Conj")
   1219 def _ConjGrad(_, grad):
   1220   """Returns the complex conjugate of grad."""
   1221   return math_ops.conj(grad)
   1222 
   1223 
   1224 @ops.RegisterGradient("ComplexAbs")
   1225 def _ComplexAbsGrad(op, grad):
   1226   """Returns the gradient of ComplexAbs."""
   1227   # TODO(b/27786104): The cast to complex could be removed once arithmetic
   1228   # supports mixtures of complex64 and real values.
   1229   return (math_ops.complex(grad, array_ops.zeros_like(grad)) * math_ops.sign(
   1230       op.inputs[0]))
   1231 
   1232 
   1233 @ops.RegisterGradient("Cast")
   1234 def _CastGrad(op, grad):
   1235   t = [
   1236       dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
   1237       dtypes.complex64, dtypes.complex128
   1238   ]
   1239   src_type = op.inputs[0].dtype.base_dtype
   1240   dst_type = grad.dtype.base_dtype
   1241   if src_type in t and dst_type in t:
   1242     return math_ops.cast(grad, src_type)
   1243   else:
   1244     return None
   1245 
   1246 
   1247 @ops.RegisterGradient("Cross")
   1248 def _CrossGrad(op, grad):
   1249   u = op.inputs[0]
   1250   v = op.inputs[1]
   1251   return (math_ops.cross(v, grad), math_ops.cross(grad, u))
   1252 
   1253 
   1254 @ops.RegisterGradient("Cumsum")
   1255 def _CumsumGrad(op, grad):
   1256   axis = op.inputs[1]
   1257   exclusive = op.get_attr("exclusive")
   1258   reverse = op.get_attr("reverse")
   1259   return [
   1260       math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=not reverse),
   1261       None
   1262   ]
   1263 
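        # A worked sketch of the Cumsum gradient above (illustrative): for
        # x = [x0, x1, x2], y = cumsum(x) = [x0, x0 + x1, x0 + x1 + x2], so an
        # upstream gradient g = [g0, g1, g2] back-propagates as
        #   dx = [g0 + g1 + g2, g1 + g2, g2]
        # which is exactly cumsum(g) with `reverse` flipped (and `exclusive`
        # preserved).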
   1264 
   1265 @ops.RegisterGradient("Cumprod")
   1266 def _CumprodGrad(op, grad):
   1267   x = op.inputs[0]
   1268   axis = op.inputs[1]
   1269   exclusive = op.get_attr("exclusive")
   1270   reverse = op.get_attr("reverse")
   1271 
   1272   # TODO This fails when x contains 0 and should be fixed
   1273   prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
   1274   out = math_ops.cumsum(
   1275       prod * grad, axis, exclusive=exclusive, reverse=not reverse)
   1276   return [out / x, None]
   1277