# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in nn_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops


@ops.RegisterGradient("Conv2DBackpropInput")
def _Conv2DBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  return [
      None,
      nn_ops.conv2d_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()),
      nn_ops.conv2d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("Conv2DBackpropFilter")
def _Conv2DBackpropFilterGrad(op, grad):
  return [
      nn_ops.conv2d_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()), None,
      nn_ops.conv2d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  return [
      None,
      nn_ops.depthwise_conv2d_native_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=op.get_attr("data_format")),
      nn_ops.depthwise_conv2d_native(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterGrad(op, grad):
  return [
      nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=op.get_attr("data_format")), None,
      nn_ops.depthwise_conv2d_native(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d_backprop_filter_v2(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      None,
      nn_ops.conv3d_backprop_filter_v2(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format), None,
      nn_ops.conv3d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("AvgPool3D")
def _AvgPool3DGrad(op, grad):
  return gen_nn_ops.avg_pool3d_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool3d(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3D")
def _MaxPool3DGrad(op, grad):
  return gen_nn_ops.max_pool3d_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool3d_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool3d_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("Softmax")
def _SoftmaxGrad(op, grad_softmax):
  """The derivative of the softmax nonlinearity.

  We assume that probs is of shape [batch_size, dim].
  The formula for dsoftmax / dx is (diag(softmax) - softmax * softmax').
  This matrix is diagonal minus a rank one matrix, so it is easy to implement
  as follows:

    grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax

  Args:
     op: the Softmax op.
     grad_softmax:  the tensor representing the gradient w.r.t. the softmax
       output.

  Returns:
     gradient w.r.t. the input to the softmax

  """
  softmax = op.outputs[0]
  sum_channels = math_ops.reduce_sum(grad_softmax * softmax, -1, keepdims=True)
  return (grad_softmax - sum_channels) * softmax


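# Illustrative sketch (not part of the registered gradient above): the dense
# Jacobian described in _SoftmaxGrad's docstring, J = diag(softmax) -
# softmax * softmax'.  Applying J directly gives the same result as the
# cheaper expression used above; the helper name is hypothetical and exists
# only as a cross-check.
def _softmax_grad_dense_reference(softmax, grad_softmax):
  """Applies the explicit [dim, dim] softmax Jacobian to grad_softmax."""
  # jacobian has shape [batch_size, dim, dim]: diagonal minus a rank-1 term.
  jacobian = array_ops.matrix_diag(softmax) - math_ops.matmul(
      array_ops.expand_dims(softmax, -1), array_ops.expand_dims(softmax, -2))
  # J is symmetric, so J @ grad_softmax equals the vector-Jacobian product.
  return array_ops.squeeze(
      math_ops.matmul(jacobian, array_ops.expand_dims(grad_softmax, -1)),
      axis=[-1])

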
@ops.RegisterGradient("LogSoftmax")
def _LogSoftmaxGrad(op, grad):
  """The gradient for log_softmax.

      log_softmax = input - log(sum(exp(input)))
      dlog_softmax/dinput = identity - softmax(input)

  Args:
    op: The log softmax op.
    grad: The tensor representing the gradient w.r.t. the output.

  Returns:
    The gradients w.r.t. the input.
  """
  softmax = math_ops.exp(op.outputs[0])
  return grad - math_ops.reduce_sum(grad, -1, keepdims=True) * softmax


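# Worked form of the expression above (illustrative note, not original code):
# with s = softmax(input) = exp(log_softmax), the Jacobian of log_softmax is
# d(log_softmax_j) / d(input_i) = delta_ij - s_i, so the vector-Jacobian
# product is
#
#   grad_x_i = grad_i - s_i * sum_j grad_j
#
# which is exactly `grad - reduce_sum(grad, -1, keepdims=True) * softmax`.

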
@ops.RegisterGradient("BiasAdd")
def _BiasAddGrad(op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of the BiasAdd op is the tensor t, and its gradient is
  just the gradient the op received.

  The second input of the BiasAdd op is the bias vector, which has a single
  dimension (the feature dimension).  Its gradient is the received gradient
  summed over every dimension except the feature dimension.

  Args:
    op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor.  The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  return (received_grad,
          gen_nn_ops.bias_add_grad(
              out_backprop=received_grad, data_format=data_format))


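# A minimal sketch (illustrative, assuming a 4-D input for the NCHW branch):
# the reduction that gen_nn_ops.bias_add_grad performs, written with plain
# reduce_sum calls.  The helper name is hypothetical and is not used by the
# registered gradient above.
def _bias_add_grad_reference(received_grad, data_format=None):
  """Sums `received_grad` over every dimension except the feature dimension."""
  if data_format == b"NCHW":
    # Feature dimension is axis 1; sum over batch and spatial dimensions.
    return math_ops.reduce_sum(received_grad, axis=[0, 2, 3])
  # NHWC (and the default): the feature dimension is the last axis.
  return math_ops.reduce_sum(
      received_grad, axis=math_ops.range(array_ops.rank(received_grad) - 1))

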
@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
  """Gradient for the BiasAddGrad op.

  Args:
    op: BiasAddGrad op for which we are calculating gradients.
    received_grad: The gradients passed to the BiasAddGrad op.

  Returns:
    A single gradient Tensor for the input to BiasAddGrad (which
    is the gradient of the bias term in BiasAdd)
  """

  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  shape = array_ops.shape(op.inputs[0])
  bias_shape = array_ops.shape(received_grad)

  if data_format == b"NCHW":
    expanded_shape = array_ops.concat([
        array_ops.ones_like(shape[:1]), bias_shape,
        array_ops.ones_like(shape[2:])
    ], 0)
    tile_mults = array_ops.concat([shape[:1], [1], shape[2:]], 0)
  else:
    expanded_shape = array_ops.concat(
        [array_ops.ones_like(shape[:-1]), bias_shape], 0)
    tile_mults = array_ops.concat([shape[:-1], [1]], 0)

  expanded_grad = array_ops.reshape(received_grad, expanded_shape)
  return array_ops.tile(expanded_grad, tile_mults)


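# Worked shape example for _BiasAddGradGrad (illustrative note, not original
# code): with an NHWC input of shape [2, 3, 3, 5], `received_grad` has the
# bias shape [5], `expanded_shape` is [1, 1, 1, 5] and `tile_mults` is
# [2, 3, 3, 1], so the bias-shaped gradient is broadcast back to the full
# input shape [2, 3, 3, 5].  For an NCHW input of shape [2, 5, 3, 3] the
# corresponding values are expanded_shape = [1, 5, 1, 1] and
# tile_mults = [2, 1, 3, 3].

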
@ops.RegisterGradient("BiasAddV1")
def _BiasAddGradV1(unused_bias_op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of unused_bias_op is the tensor t, and its gradient is
  just the gradient the unused_bias_op received.

  The second input of unused_bias_op is the bias vector, which has one fewer
  dimension than "received_grad" (the batch dimension).  Its gradient is the
  received gradient summed over every dimension except the last one.

  Args:
    unused_bias_op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor.  The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1)
  return (received_grad, math_ops.reduce_sum(received_grad,
                                             reduction_dim_tensor))


@ops.RegisterGradient("Relu")
def _ReluGrad(op, grad):
  return gen_nn_ops.relu_grad(grad, op.outputs[0])


@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
  elu_x = op.inputs[1]
  return (gen_nn_ops.elu_grad(grad, op.outputs[0]),
          array_ops.where(
              elu_x < 0, grad * op.inputs[0],
              array_ops.zeros(shape=array_ops.shape(elu_x), dtype=elu_x.dtype)))


@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  x = op.inputs[1]
  scale_alpha = 1.7580993408473768599402175208123
  return (gen_nn_ops.elu_grad(grad, op.outputs[0]),
          array_ops.where(
              x < 0., gen_nn_ops.elu_grad(grad, op.outputs[0] + scale_alpha),
              array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))


@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
  return gen_nn_ops.relu6_grad(grad, op.outputs[0])


@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu6_grad(grad, x),
          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))


@ops.RegisterGradient("LeakyRelu")
def _LeakyReluGrad(op, grad):
  x = op.inputs[0]
  alpha = op.get_attr("alpha")
  return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)


@ops.RegisterGradient("LeakyReluGrad")
def _LeakyReluGradGrad(op, grad):
  x = op.inputs[1]
  alpha = op.get_attr("alpha")
  return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha),
          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))


@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
  return gen_nn_ops.elu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
  return gen_nn_ops.selu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
  return gen_nn_ops.softplus_grad(grad, op.inputs[0])


@ops.RegisterGradient("SoftplusGrad")
def _SoftplusGradGrad(op, grad):
  # Let:
  #   y = tf.nn.softplus(x)
  #   dx = gen_nn_ops.softplus_grad(dy, x) = dy / (1 + exp(-x))
  # This op computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx.
  dy, x = op.inputs
  with ops.control_dependencies([grad]):
    ddy = gen_nn_ops.softplus_grad(grad, x)
    d2x = grad * dy / (math_ops.exp(-x) + 2.0 + math_ops.exp(x))
    return (ddy, d2x)


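# A minimal sketch (illustrative, not used by the gradient above): the
# second-order term d2x can equivalently be written with sigmoids, since
# d/dx [1 / (1 + exp(-x))] = sigmoid(x) * sigmoid(-x)
#                          = 1 / (exp(-x) + 2 + exp(x)).
# The helper name is hypothetical.
def _softplus_grad_grad_reference(grad, dy, x):
  """Sigmoid form of the `d2x` term computed in _SoftplusGradGrad."""
  return grad * dy * math_ops.sigmoid(x) * math_ops.sigmoid(-x)

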
@ops.RegisterGradient("Softsign")
def _SoftsignGrad(op, grad):
  return gen_nn_ops.softsign_grad(grad, op.inputs[0])


@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu_grad(grad, x),
          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))


def _BroadcastMul(vec, mat):
  """Multiply after broadcasting vec to match dimensions of mat.

  Args:
    vec: A 1-D tensor of dimension [D0]
    mat: A 2-D tensor of dimension [D0, D1]

  Returns:
    A tensor of dimension [D0, D1], the result of vec * mat
  """
  # Reshape vec to [D0, 1]
  vec = array_ops.expand_dims(vec, -1)
  return vec * mat


@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # grad_grad is the backprop for softmax gradient.
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  def IsZero(g):
    # Some introspection to check if the gradient is feeding zeros
    if context.executing_eagerly():
      # TODO(apassos) add an efficient way to detect eager zeros here.
      return False
    if g.op.type in ("ZerosLike", "Zeros"):
      return True
    const_fill_value = tensor_util.constant_value(g)
    return const_fill_value is not None and (const_fill_value == 0).all()

  logits = op.inputs[0]
  if grad_grad is not None and not IsZero(grad_grad):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))


@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
  """Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
  # grad_0 is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # There is no gradient for the labels
  #
  # Currently there is no way to take the second derivative of this op
  # due to the fused implementation's interaction with tf.gradients(),
  # so we make sure we prevent silently incorrect results by raising
  # an error if the second derivative is requested via prevent_gradient.
  sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
      op.outputs[1],
      message="Currently there is no way to take the second "
      "derivative of sparse_softmax_cross_entropy_with_logits due to the fused "
      "implementation's interaction with tf.gradients()")
  return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None


@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
  """Gradient function for Conv2D."""
  dilations = op.get_attr("dilations")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  explicit_paddings = op.get_attr("explicit_paddings")
  use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu")
  data_format = op.get_attr("data_format")
  shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])

  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. gen_nn_ops functions take
  # an `explicit_paddings` parameter, but nn_ops functions do not. So if we
  # were to use the nn_ops functions, we would have to convert `padding` and
  # `explicit_paddings` into a single `padding` parameter, increasing overhead
  # in Eager mode.
  return [
      gen_nn_ops.conv2d_backprop_input(
          shape_0,
          op.inputs[1],
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format),
      gen_nn_ops.conv2d_backprop_filter(
          op.inputs[0],
          shape_1,
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format)
  ]


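# Illustrative sketch (not used): when `explicit_paddings` is empty, the same
# input gradient could be expressed with the public wrapper used elsewhere in
# this file, at the cost of the extra Python overhead mentioned above:
#
#   nn_ops.conv2d_backprop_input(
#       shape_0, op.inputs[1], grad,
#       dilations=dilations, strides=strides, padding=padding,
#       use_cudnn_on_gpu=use_cudnn_on_gpu,
#       data_format=data_format.decode())

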
@ops.RegisterGradient("DepthwiseConv2dNative")
def _DepthwiseConv2dNativeGrad(op, grad):
  return [
      nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          op.get_attr("strides"),
          op.get_attr("padding"),
          data_format=op.get_attr("data_format")),
      nn_ops.depthwise_conv2d_native_backprop_filter(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          op.get_attr("strides"),
          op.get_attr("padding"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
  return [
      nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
                                       op.get_attr("strides"),
                                       op.get_attr("rates"),
                                       op.get_attr("padding")),
      nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
                                        op.get_attr("strides"),
                                        op.get_attr("rates"),
                                        op.get_attr("padding"))
  ]


@ops.RegisterGradient("LRN")
def _LRNGrad(op, grad):
  depth_radius = op.get_attr("depth_radius")
  bias = op.get_attr("bias")
  alpha = op.get_attr("alpha")
  beta = op.get_attr("beta")
  return [
      gen_nn_ops.lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, bias,
                          alpha, beta)
  ]


@ops.RegisterGradient("AvgPool")
def _AvgPoolGrad(op, grad):
  return gen_nn_ops.avg_pool_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      op.get_attr("padding"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
  return gen_nn_ops.max_pool_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
  ksize = op.inputs[1]
  strides = op.inputs[2]
  return gen_nn_ops.max_pool_grad_v2(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format")), None, None


@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
  del unused_argmax_grad
  return gen_nn_ops.max_pool_grad_with_argmax(
      op.inputs[0],
      grad,
      op.outputs[1],
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      include_batch_in_index=op.get_attr("include_batch_in_index"))


@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPoolGradV2")
def _MaxPoolGradGradV2(op, grad):
  ksize = op.inputs[3]
  strides = op.inputs[4]
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool_grad_grad_v2(
              op.inputs[0],
              op.inputs[1],
              grad,
              ksize,
              strides,
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")), None, None)


@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalMaxPool.

  Since FractionalMaxPool has three outputs, there are three gradients passed in
  for each of the outputs. Only the first one is useful, the other two gradients
  are empty.

  Args:
    op: The FractionalMaxPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalMaxPool op.
  """
  return gen_nn_ops.fractional_max_pool_grad(
      op.inputs[0], op.outputs[0], grad_0, op.outputs[1], op.outputs[2],
      op.get_attr("overlapping"))


@ops.RegisterGradient("FractionalAvgPool")
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalAvgPool.

  Since FractionalAvgPool has three outputs, there are three gradients passed in
  for each of the outputs. Only the first one is useful, the other two gradients
  are empty.

  Args:
    op: The FractionalAvgPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalAvgPool op.
  """
  return gen_nn_ops.fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0,
                                             op.outputs[1], op.outputs[2],
                                             op.get_attr("overlapping"))


@ops.RegisterGradient("BatchNormWithGlobalNormalization")
def _BatchNormWithGlobalNormalizationGrad(op, grad):
  """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.

  We do not backprop anything for the mean and var intentionally as they are
  not being trained with backprop in the operation.

  Args:
    op: The BatchNormOp for which we need to generate gradients.
    grad: Tensor.  The gradients passed to the BatchNormOp.

  Returns:
    dx: Backprop for input, which is (grad * (g * rsqrt(v + epsilon)))
    dm: Backprop for mean, which is
        sum_over_rest(grad * g) * (-1 / rsqrt(v + epsilon))
    dv: Backprop for variance, which is
        sum_over_rest(grad * g * (x - m)) * (-1/2) * (v + epsilon) ^ (-3/2)
    db: Backprop for beta, which is grad reduced in all except the
        last dimension.
    dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon)))
  """
  dx, dm, dv, db, dg = gen_nn_ops.batch_norm_with_global_normalization_grad(
      op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad,
      op.get_attr("variance_epsilon"), op.get_attr("scale_after_normalization"))
  return dx, dm, dv, db, dg


def _BaseFusedBatchNormGrad(op, use_v2, *grad):
  """Return the gradients for the 3 inputs of BatchNorm.

  Args:
    op: The BatchNormOp for which we need to compute gradients.
    use_v2: Boolean indicating whether to use the V2 version of the fused batch
      norm gradient.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_y.

  Returns:
    grad_x: gradient for x, which is scale * rsqrt(variance + epsilon) *
            [grad_y - mean(grad_y) - (x - mean(x)) *
            mean(grad_y * (x - mean(x))) / (variance + epsilon)]
            in training mode; grad_y * scale * rsqrt(pop_variance + epsilon)
            in freeze mode.

    grad_scale: gradient for scale, which is sum(grad_y * (x - mean(x)) *
                rsqrt(variance + epsilon)) in training mode;
                sum(grad_y * (x - pop_mean) * rsqrt(pop_variance + epsilon))
                in freeze mode.

    grad_offset: gradient for offset, which is sum(grad_y) in training mode;
                 sum(grad_y) in freeze mode.
  """
  x = op.inputs[0]
  grad_y = grad[0]
  scale = op.inputs[1]
  epsilon = op.get_attr("epsilon")
  data_format = op.get_attr("data_format")
  is_training = op.get_attr("is_training")
  grad_fun = (
      gen_nn_ops.fused_batch_norm_grad_v2
      if use_v2 else gen_nn_ops.fused_batch_norm_grad)
  if is_training:
    return grad_fun(
        grad_y,
        x,
        scale,
        op.outputs[3],
        op.outputs[4],
        epsilon=epsilon,
        data_format=data_format,
        is_training=is_training)
  else:
    pop_mean = op.inputs[3]
    pop_var = op.inputs[4]
    if data_format == b"NCHW":
      x = array_ops.transpose(x, [0, 2, 3, 1])
      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
    dx, dscale, doffset, _, _ = grad_fun(
        grad_y,
        x,
        scale,
        pop_mean,
        pop_var,
        epsilon=epsilon,
        data_format="NHWC",
        is_training=is_training)
    if data_format == b"NCHW":
      dx = array_ops.transpose(dx, [0, 3, 1, 2])
    return dx, dscale, doffset, None, None


@ops.RegisterGradient("FusedBatchNorm")
def _FusedBatchNormGrad(op, *grad):
  return _BaseFusedBatchNormGrad(op, False, *grad)


@ops.RegisterGradient("FusedBatchNormV2")
def _FusedBatchNormV2Grad(op, *grad):
  return _BaseFusedBatchNormGrad(op, True, *grad)


def _BatchNormGrad(grad_y,
                   x,
                   scale,
                   pop_mean,
                   pop_var,
                   epsilon,
                   data_format,
                   is_training=True):
  """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 dimensions for gradient for y.
    x: A `Tensor` of 4 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. Either b"NHWC" or b"NCHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    else:
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW":
      grad_scale = array_ops.squeeze(grad_scale)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    else:
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset


@ops.RegisterGradient("FusedBatchNormGrad")
def _FusedBatchNormGradGrad(op, *grad):
  """Returns the gradients for the 3 inputs of FusedBatchNormGrad.

  Args:
    op: The FusedBatchNormGradOp for which we need to compute gradients.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_grad_x, grad[1] as grad_grad_scale, grad[2] as
      grad_grad_offset.

  Returns:
    A tuple (grad_grad_y, grad_x, grad_scale, None, None), where grad_grad_y
    is the gradient for grad_y, grad_x the gradient for x, grad_scale the
    gradient for scale.
  """
  data_format = op.get_attr("data_format")
  epsilon = op.get_attr("epsilon")
  is_training = op.get_attr("is_training")
  grad_y = op.inputs[0]
  x = op.inputs[1]
  scale = op.inputs[2]
  pop_mean = op.inputs[3]
  pop_var = op.inputs[4]
  grad_grad_x = grad[0]
  grad_grad_scale = grad[1]
  grad_grad_offset = grad[2]
  with backprop.GradientTape() as tape:
    tape.watch(grad_y)
    tape.watch(x)
    tape.watch(scale)
    grad_x, grad_scale, grad_offset = _BatchNormGrad(
        grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
    grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset]
  grad_grad_y, grad_x, grad_scale = tape.gradient(
      [grad_x, grad_scale, grad_offset], [grad_y, x, scale], grad_initial)
  return grad_grad_y, grad_x, grad_scale, None, None


@ops.RegisterGradient("FusedBatchNormGradV2")
def _FusedBatchNormGradGradV2(op, *grad):
  return _FusedBatchNormGradGrad(op, *grad)


@ops.RegisterGradient("L2Loss")
def _L2LossGrad(op, grad):
  """Return the gradients for L2Loss.

  Args:
    op: The L2LossOp for which we need to generate gradients.
    grad: Tensor containing a single number.

  Returns:
    The gradient, which is (x * grad), since l2_loss(x) = sum(x ** 2) / 2 and
    therefore d(l2_loss)/dx = x.
  """
  return op.inputs[0] * grad


@ops.RegisterGradient("TopK")
@ops.RegisterGradient("TopKV2")
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input of
    TopK, and the second being the gradient w.r.t. the indices (all zero).
  """
  in_shape = array_ops.shape(op.inputs[0])
  ind_shape = array_ops.shape(op.outputs[1])

  # int32 is not supported on GPU hence up-casting
  ind_lastdim = array_ops.gather(
      math_ops.cast(ind_shape, dtypes.int64),
      array_ops.size(ind_shape) - 1)
  # Flatten indices to 2D.
  ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))

  in_lastdim = array_ops.gather(
      math_ops.cast(in_shape, dtypes.int64),
      array_ops.size(in_shape) - 1)
  outerdim = array_ops.shape(ind_2d)[0]
  # Compute linear indices (flattened to 1D).
  ind = array_ops.reshape(
      ind_2d + math_ops.cast(
          array_ops.expand_dims(
              math_ops.range(0,
                             math_ops.cast(outerdim, dtypes.int64) * in_lastdim,
                             in_lastdim), -1), dtypes.int32), [-1])

  # Substitute grad to appropriate locations and fill the rest with zeros,
  # finally reshaping it to the original input shape.
  return [
      array_ops.reshape(
          array_ops.scatter_nd(
              array_ops.expand_dims(ind, -1), array_ops.reshape(grad, [-1]),
              [math_ops.reduce_prod(in_shape)]), in_shape),
      array_ops.zeros([], dtype=dtypes.int32)
  ]


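# Worked example for _TopKGrad (illustrative note, not original code): with an
# input of shape [2, 4], k=2 and indices op.outputs[1] == [[3, 1], [0, 2]],
# ind_2d is [[3, 1], [0, 2]], in_lastdim is 4 and the per-row offsets are
# [0, 4], so the flattened linear indices become [3, 1, 4, 6].  scatter_nd
# then places the flattened grad values at those positions in a zero vector of
# length 8, which is finally reshaped back to [2, 4].

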
@ops.RegisterGradient("NthElement")
def _NthElementGrad(op, grad):
  """Return the gradients for NthElement.

  Args:
    op: The NthElementOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the NthElementOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input,
    the second being the gradient w.r.t. the N (None).
  """
  input = op.inputs[0]  # pylint: disable=redefined-builtin
  output = op.outputs[0]

  # Compute the number of elements that are equal to the output in each
  # reduction dimension. If there are multiple such elements then the gradient
  # will be divided between them.
  indicators = math_ops.cast(
      math_ops.equal(array_ops.expand_dims(output, -1), input), grad.dtype)

  grad = array_ops.expand_dims(grad, -1)
  num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)

  return [math_ops.div(indicators, num_selected) * grad, None]
   1114