1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Gradients for operators defined in math_ops.py.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.python.eager import context 23 from tensorflow.python.framework import constant_op 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.framework import ops 26 from tensorflow.python.framework import tensor_util 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import gen_array_ops 29 from tensorflow.python.ops import gen_math_ops 30 from tensorflow.python.ops import math_ops 31 32 33 def _safe_shape_div(x, y): 34 """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`.""" 35 return x // math_ops.maximum(y, 1) 36 37 38 @ops.RegisterGradient("Sum") 39 def _SumGrad(op, grad): 40 """Gradient for Sum.""" 41 # Fast path for when reducing to a scalar and ndims is known: adds only 42 # Reshape and Tile ops (and possibly a Shape). 43 input_0_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access 44 if input_0_shape is not None: 45 axes = tensor_util.constant_value(op.inputs[1]) 46 if axes is not None: 47 rank = len(input_0_shape) 48 if np.array_equal(axes, np.arange(rank)): # Reduce all dims. 49 grad = array_ops.reshape(grad, [1] * rank) 50 # If shape is not fully defined (but rank is), we use Shape. 51 if None not in input_0_shape: 52 input_shape = input_0_shape 53 else: 54 input_shape = array_ops.shape(op.inputs[0]) 55 return [array_ops.tile(grad, input_shape), None] 56 57 input_shape = array_ops.shape(op.inputs[0]) 58 # TODO(apassos) remove this once device placement for eager ops makes more 59 # sense. 60 with ops.colocate_with(input_shape): 61 output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]) 62 tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims) 63 grad = array_ops.reshape(grad, output_shape_kept_dims) 64 return [array_ops.tile(grad, tile_scaling), None] 65 66 67 def _MinOrMaxGrad(op, grad): 68 """Gradient for Min or Max. Amazingly it's precisely the same code.""" 69 input_shape = array_ops.shape(op.inputs[0]) 70 output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]) 71 y = op.outputs[0] 72 y = array_ops.reshape(y, output_shape_kept_dims) 73 grad = array_ops.reshape(grad, output_shape_kept_dims) 74 75 # Compute the number of selected (maximum or minimum) elements in each 76 # reduction dimension. If there are multiple minimum or maximum elements 77 # then the gradient will be divided between them. 78 indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype) 79 num_selected = array_ops.reshape( 80 math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims) 81 82 return [math_ops.div(indicators, num_selected) * grad, None] 83 84 85 @ops.RegisterGradient("Max") 86 def _MaxGrad(op, grad): 87 """Gradient for Max.""" 88 return _MinOrMaxGrad(op, grad) 89 90 91 @ops.RegisterGradient("Min") 92 def _MinGrad(op, grad): 93 return _MinOrMaxGrad(op, grad) 94 95 96 @ops.RegisterGradient("Mean") 97 def _MeanGrad(op, grad): 98 """Gradient for Mean.""" 99 sum_grad = _SumGrad(op, grad)[0] 100 input_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access 101 output_shape = op.outputs[0]._shape_tuple() # pylint: disable=protected-access 102 if (input_shape is not None and output_shape is not None and 103 None not in input_shape and None not in output_shape): 104 input_size = np.prod(input_shape) 105 output_size = np.prod(output_shape) 106 factor = input_size // max(output_size, 1) 107 factor = constant_op.constant(factor, dtype=sum_grad.dtype) 108 else: 109 input_shape = array_ops.shape(op.inputs[0]) 110 output_shape = array_ops.shape(op.outputs[0]) 111 factor = _safe_shape_div( 112 math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape)) 113 return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None 114 115 116 @ops.RegisterGradient("Prod") 117 def _ProdGrad(op, grad): 118 """Gradient for Prod.""" 119 # The gradient can be expressed by dividing the product by each entry of the 120 # input tensor, but this approach can't deal with zeros in the input. 121 # Here, we avoid this problem by composing the output as a product of two 122 # cumprod operations. 123 124 input_shape = array_ops.shape(op.inputs[0]) 125 # Reshape reduction indices for the case where the parameter is a scalar 126 reduction_indices = array_ops.reshape(op.inputs[1], [-1]) 127 128 # Expand grad to full input shape 129 output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]) 130 tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims) 131 grad = array_ops.reshape(grad, output_shape_kept_dims) 132 grad = array_ops.tile(grad, tile_scaling) 133 134 # Pack all reduced dimensions into a single one, so we can perform the 135 # cumprod ops. If the reduction dims list is empty, it defaults to float32, 136 # so we need to cast here. We put all the shape-related ops on CPU to avoid 137 # copying back and forth, and since listdiff is CPU only. 138 with ops.device("/cpu:0"): 139 rank = array_ops.rank(op.inputs[0]) 140 reduction_indices = (reduction_indices + rank) % rank 141 reduced = math_ops.cast(reduction_indices, dtypes.int32) 142 idx = math_ops.range(0, rank) 143 other, _ = array_ops.setdiff1d(idx, reduced) 144 perm = array_ops.concat([reduced, other], 0) 145 reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced)) 146 other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other)) 147 permuted = array_ops.transpose(op.inputs[0], perm) 148 permuted_shape = array_ops.shape(permuted) 149 reshaped = array_ops.reshape(permuted, (reduced_num, other_num)) 150 151 # Calculate product, leaving out the current entry 152 left = math_ops.cumprod(reshaped, axis=0, exclusive=True) 153 right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True) 154 y = array_ops.reshape(left * right, permuted_shape) 155 156 # Invert the transpose and reshape operations. 157 # Make sure to set the statically known shape information through a reshape. 158 out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm)) 159 return array_ops.reshape(out, input_shape), None 160 161 162 @ops.RegisterGradient("SegmentSum") 163 def _SegmentSumGrad(op, grad): 164 """Gradient for SegmentSum.""" 165 return array_ops.gather(grad, op.inputs[1]), None 166 167 168 @ops.RegisterGradient("SegmentMean") 169 def _SegmentMeanGrad(op, grad): 170 """Gradient for SegmentMean.""" 171 input_rank = array_ops.rank(op.inputs[0]) 172 ones_shape = array_ops.concat([ 173 array_ops.shape(op.inputs[1]), 174 array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1) 175 ], 0) 176 ones = array_ops.fill(ones_shape, constant_op.constant(1, dtype=grad.dtype)) 177 scaled_grad = math_ops.div(grad, math_ops.segment_sum(ones, op.inputs[1])) 178 return array_ops.gather(scaled_grad, op.inputs[1]), None 179 180 181 @ops.RegisterGradient("SparseSegmentSum") 182 def _SparseSegmentSumGrad(op, grad): 183 """Gradient for SparseSegmentSum.""" 184 input_rows = array_ops.shape(op.inputs[0])[0] 185 return (math_ops.unsorted_segment_sum( 186 array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows), None, 187 None) 188 189 190 @ops.RegisterGradient("SparseSegmentSumWithNumSegments") 191 def _SparseSegmentSumWithNumSegmentsGrad(op, grad): 192 """Gradient for SparseSegmentSumWithNumSegments.""" 193 input_rows = array_ops.shape(op.inputs[0])[0] 194 return (math_ops.unsorted_segment_sum( 195 array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows), None, 196 None, None) 197 198 199 @ops.RegisterGradient("SparseSegmentMean") 200 def _SparseSegmentMeanGrad(op, grad): 201 """Gradient for SparseSegmentMean.""" 202 dim0 = array_ops.shape(op.inputs[0])[0] 203 return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2], 204 dim0), None, None) 205 206 207 @ops.RegisterGradient("SparseSegmentMeanWithNumSegments") 208 def _SparseSegmentMeanWithNumSegmentsGrad(op, grad): 209 """Gradient for SparseSegmentMeanWithNumSegments.""" 210 dim0 = array_ops.shape(op.inputs[0])[0] 211 return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2], 212 dim0), None, None, None) 213 214 215 @ops.RegisterGradient("SparseSegmentSqrtN") 216 def _SparseSegmentSqrtNGrad(op, grad): 217 """Gradient for SparseSegmentSqrtN.""" 218 dim0 = array_ops.shape(op.inputs[0])[0] 219 return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2], 220 dim0), None, None) 221 222 223 @ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments") 224 def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad): 225 """Gradient for SparseSegmentSqrtNWithNumSegments.""" 226 dim0 = array_ops.shape(op.inputs[0])[0] 227 return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2], 228 dim0), None, None, None) 229 230 231 def _SegmentMinOrMaxGrad(op, grad): 232 """ Gradient for SegmentMin and SegmentMax. """ 233 zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype) 234 # Get the number of selected (minimum or maximum) elements in each segment. 235 gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1]) 236 is_selected = math_ops.equal(op.inputs[0], gathered_outputs) 237 num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype), 238 op.inputs[1]) 239 # Compute the gradient for each segment. The gradient for the ith segment is 240 # divided evenly among the selected elements in that segment. 241 weighted_grads = math_ops.div(grad, num_selected) 242 gathered_grads = array_ops.gather(weighted_grads, op.inputs[1]) 243 return array_ops.where(is_selected, gathered_grads, zeros), None 244 245 246 @ops.RegisterGradient("SegmentMin") 247 def _SegmentMinGrad(op, grad): 248 """Gradient for SegmentMin.""" 249 return _SegmentMinOrMaxGrad(op, grad) 250 251 252 @ops.RegisterGradient("SegmentMax") 253 def _SegmentMaxGrad(op, grad): 254 """Gradient for SegmentMax.""" 255 return _SegmentMinOrMaxGrad(op, grad) 256 257 258 def _GatherDropNegatives(params, ids, zero_clipped_indices=None, 259 is_positive=None): 260 """ Helper function for unsorted segment ops. Gathers params for 261 positive segment ids and gathers 0 for inputs with negative segment id. 262 Also returns the clipped indices and a boolean mask with the same shape 263 as ids where a positive id is masked as true. With this, the latter two 264 can be passed as arguments to this function to reuse them. 265 """ 266 if zero_clipped_indices is None: 267 zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids)) 268 gathered = array_ops.gather(params, zero_clipped_indices) 269 if is_positive is None: 270 is_positive = math_ops.greater_equal(ids, 0) 271 # tf.where(condition, x, y) requires condition to have the same shape as x 272 # and y. 273 # todo(philjd): remove this if tf.where supports broadcasting (#9284) 274 for _ in range(gathered.shape.ndims - is_positive.shape.ndims): 275 is_positive = array_ops.expand_dims(is_positive, -1) 276 is_positive = (is_positive & 277 array_ops.ones_like(gathered, dtype=dtypes.bool)) 278 # replace gathered params of negative indices with 0 279 zero_slice = array_ops.zeros_like(gathered) 280 return (array_ops.where(is_positive, gathered, zero_slice), 281 zero_clipped_indices, is_positive) 282 283 284 def _UnsortedSegmentMinOrMaxGrad(op, grad): 285 """ Gradient for UnsortedSegmentMin and UnsortedSegmentMax. """ 286 # Get the number of selected (minimum or maximum) elements in each segment. 287 gathered_outputs, zero_clipped_indices, is_positive = \ 288 _GatherDropNegatives(op.outputs[0], op.inputs[1]) 289 is_selected = math_ops.equal(op.inputs[0], gathered_outputs) 290 is_selected = math_ops.logical_and(is_selected, is_positive) 291 num_selected = math_ops.unsorted_segment_sum( 292 math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2]) 293 # Compute the gradient for each segment. The gradient for the ith segment is 294 # divided evenly among the selected elements in that segment. 295 weighted_grads = math_ops.div(grad, num_selected) 296 gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None, 297 zero_clipped_indices, 298 is_positive) 299 zeros = array_ops.zeros_like(gathered_grads) 300 return array_ops.where(is_selected, gathered_grads, zeros), None, None 301 302 303 @ops.RegisterGradient("UnsortedSegmentSum") 304 def _UnsortedSegmentSumGrad(op, grad): 305 """Gradient for UnsortedSegmentSum.""" 306 return _GatherDropNegatives(grad, op.inputs[1])[0], None, None 307 308 309 @ops.RegisterGradient("UnsortedSegmentMax") 310 def _UnsortedSegmentMaxGrad(op, grad): 311 """ Gradient for UnsortedSegmentMax. """ 312 return _UnsortedSegmentMinOrMaxGrad(op, grad) 313 314 315 @ops.RegisterGradient("UnsortedSegmentMin") 316 def _UnsortedSegmentMinGrad(op, grad): 317 """ Gradient for UnsortedSegmentMin. """ 318 return _UnsortedSegmentMinOrMaxGrad(op, grad) 319 320 321 @ops.RegisterGradient("UnsortedSegmentProd") 322 def _UnsortedSegmentProdGrad(op, grad): 323 """ Gradient for UnsortedSegmentProd. 324 The gradient can be expressed for each segment by dividing the segment's 325 product by each element of the segment input tensor, but this approach can't 326 deal with zeros in the input. 327 Unlike reduce_prod we can't use cumsum here as individual segments may have 328 a different number of elements. Therefore we consider three cases: 329 1) A segment input contains no zeros and we can safely divide by the input 330 tensor. 331 2) A segment contains exactly one zero. Then the gradient of each input of 332 the segment is zero except for the 0-input, there the gradient is 333 the product of the remaining segment entries. 334 3) A segment contains at least two zeros. The gradient is zero for all 335 segment inputs. 336 """ 337 # Note that unsorted_segment_sum will filter out the negative indices, 338 # so we don't need to do a logical_and with is_positive here 339 is_zero = math_ops.equal(op.inputs[0], 0) 340 num_zeros = gen_math_ops.unsorted_segment_sum( 341 math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2]) 342 # handle case 3 and set the gradient to 0 for segments with more than one 343 # 0 as input 344 grad = array_ops.where(math_ops.greater(num_zeros, 1), 345 array_ops.zeros_like(grad), grad) 346 # replace all zeros with ones and compute the unsorted_segment_prod 347 non_zero_data = array_ops.where(is_zero, array_ops.ones_like(op.inputs[0]), 348 op.inputs[0]) 349 non_zero_prod = gen_math_ops.unsorted_segment_prod( 350 non_zero_data, op.inputs[1], op.inputs[2]) 351 # clip the indices for gather to be positive 352 zero_clipped_indices = math_ops.maximum(op.inputs[1], 353 array_ops.zeros_like(op.inputs[1])) 354 gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices) 355 gathered_non_zero_prod = array_ops.gather(non_zero_prod, 356 zero_clipped_indices) 357 prod_divided_by_el = gathered_prod / op.inputs[0] # May contain nan/inf. 358 # Now fetch the individual results for segments containing 0 and those that 359 # don't. is_zero will also fetch results for entries with negative index 360 # but the following gather_drop_negatives sets the corresponding entry in 361 # grad to 0 for these 362 partial_derivative = array_ops.where(is_zero, gathered_non_zero_prod, 363 prod_divided_by_el) 364 gathered_grad = _GatherDropNegatives(grad, op.inputs[1], 365 zero_clipped_indices)[0] 366 return gathered_grad * partial_derivative, None, None 367 368 369 @ops.RegisterGradient("Abs") 370 def _AbsGrad(op, grad): 371 x = op.inputs[0] 372 return grad * math_ops.sign(x) 373 374 375 @ops.RegisterGradient("Neg") 376 def _NegGrad(_, grad): 377 """Returns -grad.""" 378 return -grad 379 380 381 @ops.RegisterGradient("Inv") 382 def _InvGrad(op, grad): 383 """Returns -grad * (1 / x^2).""" 384 y = op.outputs[0] # y = 1 / x 385 # pylint: disable=protected-access 386 return gen_math_ops._reciprocal_grad(y, grad) 387 388 389 @ops.RegisterGradient("Reciprocal") 390 def _ReciprocalGrad(op, grad): 391 """Returns -grad * (1 / x^2).""" 392 y = op.outputs[0] # y = 1 / x 393 # pylint: disable=protected-access 394 return gen_math_ops._reciprocal_grad(y, grad) 395 396 397 @ops.RegisterGradient("InvGrad") 398 def _InvGradGrad(op, grad): 399 b = op.inputs[1] 400 # op.output[0]: y = -b * conj(a)^2 401 with ops.control_dependencies([grad]): 402 ca = math_ops.conj(op.inputs[0]) 403 cg = math_ops.conj(grad) 404 # pylint: disable=protected-access 405 return cg * -2.0 * b * ca, gen_math_ops._reciprocal_grad(ca, grad) 406 407 408 @ops.RegisterGradient("ReciprocalGrad") 409 def _ReciprocalGradGrad(op, grad): 410 b = op.inputs[1] 411 # op.output[0]: y = -b * conj(a)^2 412 with ops.control_dependencies([grad]): 413 ca = math_ops.conj(op.inputs[0]) 414 cg = math_ops.conj(grad) 415 # pylint: disable=protected-access 416 return cg * -2.0 * b * ca, gen_math_ops._reciprocal_grad(ca, grad) 417 418 419 @ops.RegisterGradient("Square") 420 def _SquareGrad(op, grad): 421 x = op.inputs[0] 422 # Added control dependencies to prevent 2*x from being computed too early. 423 with ops.control_dependencies([grad]): 424 x = math_ops.conj(x) 425 return math_ops.multiply(grad, math_ops.multiply(x, 2.0)) 426 427 428 @ops.RegisterGradient("Sqrt") 429 def _SqrtGrad(op, grad): 430 y = op.outputs[0] # y = x^(1/2) 431 # pylint: disable=protected-access 432 return gen_math_ops._sqrt_grad(y, grad) 433 # pylint: enable=protected-access 434 435 436 @ops.RegisterGradient("SqrtGrad") 437 def _SqrtGradGrad(op, grad): 438 a = op.inputs[0] 439 y = op.outputs[0] # y = 0.5 * b / conj(a) 440 with ops.control_dependencies([grad]): 441 ga = grad / a 442 return -math_ops.conj(ga) * y, 0.5 * ga 443 444 445 @ops.RegisterGradient("Rsqrt") 446 def _RsqrtGrad(op, grad): 447 """Returns -0.5 * grad * conj(y)^3.""" 448 y = op.outputs[0] # y = x^(-1/2) 449 # pylint: disable=protected-access 450 return gen_math_ops._rsqrt_grad(y, grad) 451 # pylint: enable=protected-access 452 453 454 @ops.RegisterGradient("RsqrtGrad") 455 def _RsqrtGradGrad(op, grad): 456 """Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3.""" 457 a = op.inputs[0] # a = x^{-1/2} 458 b = op.inputs[1] # backprop gradient for a 459 with ops.control_dependencies([grad]): 460 ca = math_ops.conj(a) 461 cg = math_ops.conj(grad) 462 grad_a = -1.5 * cg * b * math_ops.square(ca) 463 # pylint: disable=protected-access 464 grad_b = gen_math_ops._rsqrt_grad(ca, grad) 465 return grad_a, grad_b 466 467 468 @ops.RegisterGradient("Exp") 469 def _ExpGrad(op, grad): 470 """Returns grad * exp(x).""" 471 y = op.outputs[0] # y = e^x 472 with ops.control_dependencies([grad]): 473 y = math_ops.conj(y) 474 return grad * y 475 476 477 @ops.RegisterGradient("Expm1") 478 def _Expm1Grad(op, grad): 479 """Returns grad * exp(x).""" 480 x = op.inputs[0] 481 with ops.control_dependencies([grad]): 482 x = math_ops.conj(x) 483 y = math_ops.exp(x) 484 return grad * y 485 486 487 @ops.RegisterGradient("Log") 488 def _LogGrad(op, grad): 489 """Returns grad * (1/x).""" 490 x = op.inputs[0] 491 with ops.control_dependencies([grad]): 492 x = math_ops.conj(x) 493 return grad * math_ops.reciprocal(x) 494 495 496 @ops.RegisterGradient("Log1p") 497 def _Log1pGrad(op, grad): 498 """Returns grad * (1/(1 + x)).""" 499 x = op.inputs[0] 500 with ops.control_dependencies([grad]): 501 x = math_ops.conj(x) 502 return grad * math_ops.reciprocal(1 + x) 503 504 505 @ops.RegisterGradient("Sinh") 506 def _SinhGrad(op, grad): 507 """Returns grad * cosh(x).""" 508 x = op.inputs[0] 509 with ops.control_dependencies([grad]): 510 x = math_ops.conj(x) 511 return grad * math_ops.cosh(x) 512 513 514 @ops.RegisterGradient("Cosh") 515 def _CoshGrad(op, grad): 516 """Returns grad * sinh(x).""" 517 x = op.inputs[0] 518 with ops.control_dependencies([grad]): 519 x = math_ops.conj(x) 520 return grad * math_ops.sinh(x) 521 522 523 @ops.RegisterGradient("Tanh") 524 def _TanhGrad(op, grad): 525 """Returns grad * (1 - tanh(x) * tanh(x)).""" 526 y = op.outputs[0] # y = tanh(x) 527 with ops.control_dependencies([grad]): 528 y = math_ops.conj(y) 529 # pylint: disable=protected-access 530 return gen_math_ops._tanh_grad(y, grad) 531 532 533 @ops.RegisterGradient("Asinh") 534 def _AsinhGrad(op, grad): 535 """Returns grad * 1/cosh(y).""" 536 y = op.outputs[0] 537 with ops.control_dependencies([grad]): 538 y = math_ops.conj(y) 539 return grad / math_ops.cosh(y) 540 541 542 @ops.RegisterGradient("Acosh") 543 def _AcoshGrad(op, grad): 544 """Returns grad * 1/sinh(y).""" 545 y = op.outputs[0] 546 with ops.control_dependencies([grad]): 547 y = math_ops.conj(y) 548 return grad / math_ops.sinh(y) 549 550 551 @ops.RegisterGradient("Atanh") 552 def _AtanhGrad(op, grad): 553 """Returns grad * 1/ (1 - x^2).""" 554 x = op.inputs[0] 555 with ops.control_dependencies([grad]): 556 x = math_ops.conj(x) 557 x2 = math_ops.square(x) 558 one = constant_op.constant(1, dtype=grad.dtype) 559 inv = math_ops.reciprocal(math_ops.subtract(one, x2)) 560 return grad * inv 561 562 563 @ops.RegisterGradient("TanhGrad") 564 def _TanhGradGrad(op, grad): 565 with ops.control_dependencies([grad]): 566 a = math_ops.conj(op.inputs[0]) 567 b = math_ops.conj(op.inputs[1]) 568 # pylint: disable=protected-access 569 return grad * -2.0 * b * a, gen_math_ops._tanh_grad(a, grad) 570 571 572 @ops.RegisterGradient("Erf") 573 def _ErfGrad(op, grad): 574 """Returns grad * 2/sqrt(pi) * exp(-x**2).""" 575 x = op.inputs[0] 576 two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype) 577 with ops.control_dependencies([grad]): 578 x = math_ops.conj(x) 579 return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x)) 580 581 582 @ops.RegisterGradient("Erfc") 583 def _ErfcGrad(op, grad): 584 """Returns -grad * 2/sqrt(pi) * exp(-x**2).""" 585 x = op.inputs[0] 586 minus_two_over_root_pi = constant_op.constant( 587 -2 / np.sqrt(np.pi), dtype=grad.dtype) 588 with ops.control_dependencies([grad]): 589 x = math_ops.conj(x) 590 return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x)) 591 592 593 @ops.RegisterGradient("Lgamma") 594 def _LgammaGrad(op, grad): 595 """Returns grad * digamma(x).""" 596 x = op.inputs[0] 597 with ops.control_dependencies([grad]): 598 x = math_ops.conj(x) 599 return grad * math_ops.digamma(x) 600 601 602 @ops.RegisterGradient("Digamma") 603 def _DigammaGrad(op, grad): 604 """Compute gradient of the digamma function with respect to its argument.""" 605 x = op.inputs[0] 606 with ops.control_dependencies([grad]): 607 x = math_ops.conj(x) 608 return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x) 609 610 611 @ops.RegisterGradient("Igamma") 612 def _IgammaGrad(op, grad): 613 """Returns gradient of igamma(a, x) with respect to x.""" 614 # TODO(ebrevdo): Perhaps add the derivative w.r.t. a 615 a = op.inputs[0] 616 x = op.inputs[1] 617 sa = array_ops.shape(a) 618 sx = array_ops.shape(x) 619 # pylint: disable=protected-access 620 unused_ra, rx = gen_array_ops._broadcast_gradient_args(sa, sx) 621 # pylint: enable=protected-access 622 623 # Perform operations in log space before summing, because Gamma(a) 624 # and Gamma'(a) can grow large. 625 partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a)) 626 # TODO(b/36815900): Mark None return values as NotImplemented 627 return (None, array_ops.reshape( 628 math_ops.reduce_sum(partial_x * grad, rx), sx)) 629 630 631 @ops.RegisterGradient("Igammac") 632 def _IgammacGrad(op, grad): 633 """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. x.""" 634 _, igamma_grad_x = _IgammaGrad(op, grad) 635 return None, -igamma_grad_x 636 637 638 @ops.RegisterGradient("Betainc") 639 def _BetaincGrad(op, grad): 640 """Returns gradient of betainc(a, b, x) with respect to x.""" 641 # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b 642 a, b, x = op.inputs 643 644 # two cases: x is a scalar and a/b are same-shaped tensors, or vice 645 # versa; so its sufficient to check against shape(a). 646 sa = array_ops.shape(a) 647 sx = array_ops.shape(x) 648 # pylint: disable=protected-access 649 _, rx = gen_array_ops._broadcast_gradient_args(sa, sx) 650 # pylint: enable=protected-access 651 652 # Perform operations in log space before summing, because terms 653 # can grow large. 654 log_beta = ( 655 gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) - 656 gen_math_ops.lgamma(a + b)) 657 partial_x = math_ops.exp((b - 1) * math_ops.log(1 - x) + 658 (a - 1) * math_ops.log(x) - log_beta) 659 660 # TODO(b/36815900): Mark None return values as NotImplemented 661 return ( 662 None, # da 663 None, # db 664 array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) 665 666 667 @ops.RegisterGradient("Zeta") 668 def _ZetaGrad(op, grad): 669 """Returns gradient of zeta(x, q) with respect to x and q.""" 670 # TODO(tillahoffmann): Add derivative with respect to x 671 x = op.inputs[0] 672 q = op.inputs[1] 673 # Broadcast gradients 674 sx = array_ops.shape(x) 675 sq = array_ops.shape(q) 676 # pylint: disable=protected-access 677 unused_rx, rq = gen_array_ops._broadcast_gradient_args(sx, sq) 678 # pylint: enable=protected-access 679 # Evaluate gradient 680 with ops.control_dependencies([grad]): 681 x = math_ops.conj(x) 682 q = math_ops.conj(q) 683 partial_q = -x * math_ops.zeta(x + 1, q) 684 # TODO(b/36815900): Mark None return values as NotImplemented 685 return (None, 686 array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq)) 687 688 689 @ops.RegisterGradient("Polygamma") 690 def _PolygammaGrad(op, grad): 691 """Returns gradient of psi(n, x) with respect to n and x.""" 692 # TODO(tillahoffmann): Add derivative with respect to n 693 n = op.inputs[0] 694 x = op.inputs[1] 695 # Broadcast gradients 696 sn = array_ops.shape(n) 697 sx = array_ops.shape(x) 698 # pylint: disable=protected-access 699 unused_rn, rx = gen_array_ops._broadcast_gradient_args(sn, sx) 700 # pylint: enable=protected-access 701 # Evaluate gradient 702 with ops.control_dependencies([grad]): 703 n = math_ops.conj(n) 704 x = math_ops.conj(x) 705 partial_x = math_ops.polygamma(n + 1, x) 706 # TODO(b/36815900): Mark None return values as NotImplemented 707 return (None, 708 array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) 709 710 711 @ops.RegisterGradient("Sigmoid") 712 def _SigmoidGrad(op, grad): 713 """Returns grad * sigmoid(x) * (1 - sigmoid(x)).""" 714 y = op.outputs[0] # y = sigmoid(x) 715 with ops.control_dependencies([grad]): 716 y = math_ops.conj(y) 717 # pylint: disable=protected-access 718 return gen_math_ops._sigmoid_grad(y, grad) 719 720 721 @ops.RegisterGradient("SigmoidGrad") 722 def _SigmoidGradGrad(op, grad): 723 with ops.control_dependencies([grad]): 724 a = math_ops.conj(op.inputs[0]) 725 b = math_ops.conj(op.inputs[1]) 726 gb = grad * b 727 # pylint: disable=protected-access 728 return gb - 2.0 * gb * a, gen_math_ops._sigmoid_grad(a, grad) 729 730 731 @ops.RegisterGradient("Sign") 732 def _SignGrad(op, _): 733 """Returns 0.""" 734 x = op.inputs[0] 735 return array_ops.zeros(array_ops.shape(x), dtype=x.dtype) 736 737 738 @ops.RegisterGradient("Sin") 739 def _SinGrad(op, grad): 740 """Returns grad * cos(x).""" 741 x = op.inputs[0] 742 with ops.control_dependencies([grad]): 743 x = math_ops.conj(x) 744 return grad * math_ops.cos(x) 745 746 747 @ops.RegisterGradient("Cos") 748 def _CosGrad(op, grad): 749 """Returns grad * -sin(x).""" 750 x = op.inputs[0] 751 with ops.control_dependencies([grad]): 752 x = math_ops.conj(x) 753 return -grad * math_ops.sin(x) 754 755 756 @ops.RegisterGradient("Tan") 757 def _TanGrad(op, grad): 758 """Returns grad * 1/sec^2(x).""" 759 x = op.inputs[0] 760 with ops.control_dependencies([grad]): 761 x = math_ops.conj(x) 762 secx = math_ops.reciprocal(math_ops.cos(x)) 763 secx2 = math_ops.square(secx) 764 return grad * secx2 765 766 767 @ops.RegisterGradient("Asin") 768 def _AsinGrad(op, grad): 769 """Returns grad * 1/sqrt(1-x^2).""" 770 x = op.inputs[0] 771 with ops.control_dependencies([grad]): 772 x = math_ops.conj(x) 773 x2 = math_ops.square(x) 774 one = constant_op.constant(1, dtype=grad.dtype) 775 den = math_ops.sqrt(math_ops.subtract(one, x2)) 776 inv = math_ops.reciprocal(den) 777 return grad * inv 778 779 780 @ops.RegisterGradient("Acos") 781 def _AcosGrad(op, grad): 782 """Returns grad * -1/sqrt(1-x^2).""" 783 x = op.inputs[0] 784 with ops.control_dependencies([grad]): 785 x = math_ops.conj(x) 786 x2 = math_ops.square(x) 787 one = constant_op.constant(1, dtype=grad.dtype) 788 den = math_ops.sqrt(math_ops.subtract(one, x2)) 789 inv = math_ops.reciprocal(den) 790 return -grad * inv 791 792 793 @ops.RegisterGradient("Atan") 794 def _AtanGrad(op, grad): 795 """Returns grad * 1/ (1 + x^2).""" 796 x = op.inputs[0] 797 with ops.control_dependencies([grad]): 798 x = math_ops.conj(x) 799 x2 = math_ops.square(x) 800 one = constant_op.constant(1, dtype=grad.dtype) 801 inv = math_ops.reciprocal(math_ops.add(one, x2)) 802 return grad * inv 803 804 805 @ops.RegisterGradient("Atan2") 806 def _Atan2Grad(op, grad): 807 """Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2).""" 808 y = op.inputs[0] 809 x = op.inputs[1] 810 with ops.control_dependencies([grad]): 811 grad_inv = grad / (math_ops.square(x) + math_ops.square(y)) 812 return x * grad_inv, -y * grad_inv 813 814 815 @ops.RegisterGradient("AddN") 816 def _AddNGrad(op, grad): 817 """Copies the gradient to all inputs.""" 818 # Not broadcasting. 819 return [grad] * len(op.inputs) 820 821 822 def _ShapesFullySpecifiedAndEqual(x, y, grad): 823 # pylint: disable=protected-access 824 x_shape = x._shape_tuple() 825 y_shape = y._shape_tuple() 826 grad_shape = grad._shape_tuple() 827 # pylint: enable=protected-access 828 return (x_shape == y_shape and x_shape == grad_shape and 829 x_shape is not None and None not in x_shape) 830 831 832 @ops.RegisterGradient("Add") 833 def _AddGrad(op, grad): 834 """Gradient for Add.""" 835 x = op.inputs[0] 836 y = op.inputs[1] 837 if (isinstance(grad, ops.Tensor) and 838 _ShapesFullySpecifiedAndEqual(x, y, grad)): 839 return grad, grad 840 sx = array_ops.shape(x) 841 sy = array_ops.shape(y) 842 # pylint: disable=protected-access 843 rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) 844 # pylint: enable=protected-access 845 return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx), 846 array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)) 847 848 849 @ops.RegisterGradient("Sub") 850 def _SubGrad(op, grad): 851 """Gradient for Sub.""" 852 x = op.inputs[0] 853 y = op.inputs[1] 854 if (isinstance(grad, ops.Tensor) and 855 _ShapesFullySpecifiedAndEqual(x, y, grad)): 856 return grad, -grad 857 sx = array_ops.shape(x) 858 sy = array_ops.shape(y) 859 # pylint: disable=protected-access 860 rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) 861 # pylint: enable=protected-access 862 return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx), 863 array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy)) 864 865 866 @ops.RegisterGradient("Mul") 867 def _MulGrad(op, grad): 868 """The gradient of scalar multiplication.""" 869 x = op.inputs[0] 870 y = op.inputs[1] 871 # pylint: disable=protected-access 872 if (isinstance(grad, ops.Tensor) and 873 _ShapesFullySpecifiedAndEqual(x, y, grad) and 874 grad.dtype in (dtypes.int32, dtypes.float32)): 875 return gen_math_ops._mul(grad, y), gen_math_ops._mul(grad, x) 876 assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype) 877 sx = array_ops.shape(x) 878 sy = array_ops.shape(y) 879 rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) 880 x = math_ops.conj(x) 881 y = math_ops.conj(y) 882 return (array_ops.reshape( 883 math_ops.reduce_sum(gen_math_ops._mul(grad, y), rx), sx), 884 array_ops.reshape( 885 math_ops.reduce_sum(gen_math_ops._mul(x, grad), ry), sy)) 886 # pylint: enable=protected-access 887 888 889 @ops.RegisterGradient("Div") 890 def _DivGrad(op, grad): 891 """The gradient for the Div operator.""" 892 x = op.inputs[0] 893 y = op.inputs[1] 894 sx = array_ops.shape(x) 895 sy = array_ops.shape(y) 896 # pylint: disable=protected-access 897 rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) 898 # pylint: enable=protected-access 899 x = math_ops.conj(x) 900 y = math_ops.conj(y) 901 return (array_ops.reshape(math_ops.reduce_sum(math_ops.div(grad, y), rx), sx), 902 array_ops.reshape( 903 math_ops.reduce_sum(grad * math_ops.div(math_ops.div(-x, y), y), 904 ry), sy)) 905 906 907 @ops.RegisterGradient("FloorDiv") 908 def _FloorDivGrad(_, unused_grad): 909 """The gradient for the FloorDiv operator.""" 910 return None, None 911 912 913 @ops.RegisterGradient("FloorMod") 914 def _FloorModGrad(op, grad): 915 """Returns grad * (1, -floor(x/y)).""" 916 x = math_ops.conj(op.inputs[0]) 917 y = math_ops.conj(op.inputs[1]) 918 919 sx = array_ops.shape(x) 920 sy = array_ops.shape(y) 921 # pylint: disable=protected-access 922 rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) 923 # pylint: enable=protected-access 924 floor_xy = math_ops.floor_div(x, y) 925 gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx) 926 gy = array_ops.reshape( 927 math_ops.reduce_sum(grad * math_ops.negative(floor_xy), ry), sy) 928 return gx, gy 929 930 931 @ops.RegisterGradient("TruncateDiv") 932 def _TruncateDivGrad(_, unused_grad): 933 return None, None 934 935 936 @ops.RegisterGradient("RealDiv") 937 def _RealDivGrad(op, grad): 938 """RealDiv op gradient.""" 939 x = op.inputs[0] 940 y = op.inputs[1] 941 sx = array_ops.shape(x) 942 sy = array_ops.shape(y) 943 # pylint: disable=protected-access 944 rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) 945 # pylint: enable=protected-access 946 x = math_ops.conj(x) 947 y = math_ops.conj(y) 948 return (array_ops.reshape( 949 math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx), 950 array_ops.reshape( 951 math_ops.reduce_sum( 952 grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy)) 953 954 955 @ops.RegisterGradient("Pow") 956 def _PowGrad(op, grad): 957 """Returns grad * (y*x^(y-1), z*log(x)).""" 958 x = op.inputs[0] 959 y = op.inputs[1] 960 z = op.outputs[0] 961 sx = array_ops.shape(x) 962 sy = array_ops.shape(y) 963 rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) 964 x = math_ops.conj(x) 965 y = math_ops.conj(y) 966 z = math_ops.conj(z) 967 gx = array_ops.reshape( 968 math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx) 969 # Avoid false singularity at x = 0 970 if x.dtype.is_complex: 971 # real(x) < 0 is fine for the complex case 972 log_x = array_ops.where( 973 math_ops.not_equal(x, 0), math_ops.log(x), array_ops.zeros_like(x)) 974 else: 975 # There's no sensible real value to return if x < 0, so return 0 976 log_x = array_ops.where(x > 0, math_ops.log(x), array_ops.zeros_like(x)) 977 gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy) 978 return gx, gy 979 980 981 def _MaximumMinimumGrad(op, grad, selector_op): 982 """Factor out the code for the gradient of Maximum or Minimum.""" 983 x = op.inputs[0] 984 y = op.inputs[1] 985 gdtype = grad.dtype 986 sx = array_ops.shape(x) 987 sy = array_ops.shape(y) 988 gradshape = array_ops.shape(grad) 989 zeros = array_ops.zeros(gradshape, gdtype) 990 xmask = selector_op(x, y) 991 rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) 992 xgrad = array_ops.where(xmask, grad, zeros) 993 ygrad = array_ops.where(xmask, zeros, grad) 994 gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx) 995 gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy) 996 return (gx, gy) 997 998 999 @ops.RegisterGradient("Maximum") 1000 def _MaximumGrad(op, grad): 1001 """Returns grad*(x > y, x <= y) with type of grad.""" 1002 return _MaximumMinimumGrad(op, grad, math_ops.greater_equal) 1003 1004 1005 @ops.RegisterGradient("Minimum") 1006 def _MinimumGrad(op, grad): 1007 """Returns grad*(x < y, x >= y) with type of grad.""" 1008 return _MaximumMinimumGrad(op, grad, math_ops.less_equal) 1009 1010 1011 @ops.RegisterGradient("SquaredDifference") 1012 def _SquaredDifferenceGrad(op, grad): 1013 """Returns the gradient for (x-y)^2.""" 1014 x = op.inputs[0] 1015 y = op.inputs[1] 1016 sx = array_ops.shape(x) 1017 sy = array_ops.shape(y) 1018 # pylint: disable=protected-access 1019 rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) 1020 # pylint: enable=protected-access 1021 with ops.control_dependencies([grad]): 1022 # The parens ensure that if grad is IndexedSlices, it'll get multiplied by 1023 # Tensor (not a number like 2.0) which causes it to convert to Tensor. 1024 x_grad = math_ops.scalar_mul(2.0, grad) * (x - y) 1025 return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx), 1026 -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)) 1027 1028 1029 # Logical operations have no gradients. 1030 ops.NotDifferentiable("Less") 1031 ops.NotDifferentiable("LessEqual") 1032 ops.NotDifferentiable("Greater") 1033 ops.NotDifferentiable("GreaterEqual") 1034 ops.NotDifferentiable("Equal") 1035 ops.NotDifferentiable("ApproximateEqual") 1036 ops.NotDifferentiable("NotEqual") 1037 ops.NotDifferentiable("LogicalAnd") 1038 ops.NotDifferentiable("LogicalOr") 1039 ops.NotDifferentiable("LogicalNot") 1040 1041 1042 @ops.RegisterGradient("Select") 1043 def _SelectGrad(op, grad): 1044 c = op.inputs[0] 1045 x = op.inputs[1] 1046 zeros = array_ops.zeros_like(x) 1047 return (None, array_ops.where(c, grad, zeros), array_ops.where( 1048 c, zeros, grad)) 1049 1050 1051 @ops.RegisterGradient("MatMul") 1052 def _MatMulGrad(op, grad): 1053 """Gradient for MatMul.""" 1054 1055 t_a = op.get_attr("transpose_a") 1056 t_b = op.get_attr("transpose_b") 1057 a = math_ops.conj(op.inputs[0]) 1058 b = math_ops.conj(op.inputs[1]) 1059 # pylint: disable=protected-access 1060 if not t_a and not t_b: 1061 grad_a = gen_math_ops._mat_mul(grad, b, transpose_b=True) 1062 grad_b = gen_math_ops._mat_mul(a, grad, transpose_a=True) 1063 elif not t_a and t_b: 1064 grad_a = gen_math_ops._mat_mul(grad, b) 1065 grad_b = gen_math_ops._mat_mul(grad, a, transpose_a=True) 1066 elif t_a and not t_b: 1067 grad_a = gen_math_ops._mat_mul(b, grad, transpose_b=True) 1068 grad_b = gen_math_ops._mat_mul(a, grad) 1069 elif t_a and t_b: 1070 grad_a = gen_math_ops._mat_mul(b, grad, transpose_a=True, transpose_b=True) 1071 grad_b = gen_math_ops._mat_mul(grad, a, transpose_a=True, transpose_b=True) 1072 # pylint: enable=protected-access 1073 return grad_a, grad_b 1074 1075 1076 @ops.RegisterGradient("SparseMatMul") 1077 def _SparseMatMulGrad(op, grad): 1078 """Gradient for SparseMatMul.""" 1079 1080 t_a = op.get_attr("transpose_a") 1081 t_b = op.get_attr("transpose_b") 1082 is_sparse = { 1083 op.inputs[0]: op.get_attr("a_is_sparse"), 1084 op.inputs[1]: op.get_attr("b_is_sparse"), 1085 # Use heuristic to figure out if grad might be sparse 1086 grad: context.in_graph_mode() and (grad.op.type == "ReluGrad") 1087 } 1088 1089 def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False): 1090 """Helper function to create SparseMatMul op.""" 1091 1092 assert t1 in is_sparse and t2 in is_sparse 1093 t1_sparse = is_sparse[t1] 1094 t2_sparse = is_sparse[t2] 1095 if transpose_b: 1096 t2 = array_ops.transpose(t2) 1097 transpose_b = False 1098 prod = math_ops.matmul( 1099 t1, 1100 t2, 1101 transpose_a=transpose_a, 1102 transpose_b=transpose_b, 1103 a_is_sparse=t1_sparse, 1104 b_is_sparse=t2_sparse) 1105 if prod.dtype != out_dtype: 1106 prod = math_ops.cast(prod, out_dtype) 1107 return prod 1108 1109 dtype_a = op.inputs[0].dtype 1110 dtype_b = op.inputs[1].dtype 1111 if not t_a and not t_b: 1112 return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True), 1113 _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True)) 1114 elif not t_a and t_b: 1115 return (_SparseMatMul(grad, op.inputs[1], dtype_a), 1116 _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True)) 1117 elif t_a and not t_b: 1118 return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True), 1119 _SparseMatMul(op.inputs[0], grad, dtype_b)) 1120 elif t_a and t_b: 1121 return (_SparseMatMul( 1122 op.inputs[1], grad, dtype_a, transpose_a=True, transpose_b=True), 1123 _SparseMatMul( 1124 grad, op.inputs[0], dtype_b, transpose_a=True, 1125 transpose_b=True)) 1126 1127 1128 @ops.RegisterGradient("Floor") 1129 def _FloorGrad(_, unused_grad): 1130 return [None] 1131 1132 1133 @ops.RegisterGradient("Ceil") 1134 def _CeilGrad(_, unused_grad): 1135 return [None] 1136 1137 1138 @ops.RegisterGradient("Round") 1139 def _RoundGrad(_, unused_grad): 1140 return [None] 1141 1142 1143 @ops.RegisterGradient("Rint") 1144 def _RintGrad(_, unused_grad): 1145 # the gradient of Rint is zero 1146 return [None] 1147 1148 1149 @ops.RegisterGradient("BatchMatMul") 1150 def _BatchMatMul(op, grad): 1151 """Returns the gradient of x and y given the gradient of x * y.""" 1152 x = op.inputs[0] 1153 y = op.inputs[1] 1154 adj_x = op.get_attr("adj_x") 1155 adj_y = op.get_attr("adj_y") 1156 1157 if not adj_x: 1158 if not adj_y: 1159 grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True) 1160 grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False) 1161 else: 1162 grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False) 1163 grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False) 1164 else: 1165 if not adj_y: 1166 grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True) 1167 grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False) 1168 else: 1169 grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True) 1170 grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True) 1171 1172 return grad_x, grad_y 1173 1174 1175 ops.NotDifferentiable("Range") 1176 ops.NotDifferentiable("LinSpace") 1177 1178 1179 @ops.RegisterGradient("Complex") 1180 def _ComplexGrad(op, grad): 1181 """Returns the real and imaginary components of 'grad', respectively.""" 1182 x = op.inputs[0] 1183 y = op.inputs[1] 1184 sx = array_ops.shape(x) 1185 sy = array_ops.shape(y) 1186 rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) 1187 return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx), 1188 array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy)) 1189 1190 1191 @ops.RegisterGradient("Real") 1192 def _RealGrad(_, grad): 1193 """Returns 'grad' as the real part and set the imaginary part 0.""" 1194 zero = constant_op.constant(0, dtype=grad.dtype) 1195 return math_ops.complex(grad, zero) 1196 1197 1198 @ops.RegisterGradient("Imag") 1199 def _ImagGrad(_, grad): 1200 """Returns 'grad' as the imaginary part and set the real part 0.""" 1201 zero = constant_op.constant(0, dtype=grad.dtype) 1202 return math_ops.complex(zero, grad) 1203 1204 1205 @ops.RegisterGradient("Angle") 1206 def _AngleGrad(op, grad): 1207 """Returns -grad / (Im(x) + iRe(x))""" 1208 x = op.inputs[0] 1209 with ops.control_dependencies([grad]): 1210 re = math_ops.real(x) 1211 im = math_ops.imag(x) 1212 z = math_ops.reciprocal(math_ops.complex(im, re)) 1213 zero = constant_op.constant(0, dtype=grad.dtype) 1214 complex_grad = math_ops.complex(grad, zero) 1215 return -complex_grad * z 1216 1217 1218 @ops.RegisterGradient("Conj") 1219 def _ConjGrad(_, grad): 1220 """Returns the complex conjugate of grad.""" 1221 return math_ops.conj(grad) 1222 1223 1224 @ops.RegisterGradient("ComplexAbs") 1225 def _ComplexAbsGrad(op, grad): 1226 """Returns the gradient of ComplexAbs.""" 1227 # TODO(b/27786104): The cast to complex could be removed once arithmetic 1228 # supports mixtures of complex64 and real values. 1229 return (math_ops.complex(grad, array_ops.zeros_like(grad)) * math_ops.sign( 1230 op.inputs[0])) 1231 1232 1233 @ops.RegisterGradient("Cast") 1234 def _CastGrad(op, grad): 1235 t = [ 1236 dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16, 1237 dtypes.complex64, dtypes.complex128 1238 ] 1239 src_type = op.inputs[0].dtype.base_dtype 1240 dst_type = grad.dtype.base_dtype 1241 if src_type in t and dst_type in t: 1242 return math_ops.cast(grad, src_type) 1243 else: 1244 return None 1245 1246 1247 @ops.RegisterGradient("Cross") 1248 def _CrossGrad(op, grad): 1249 u = op.inputs[0] 1250 v = op.inputs[1] 1251 return (math_ops.cross(v, grad), math_ops.cross(grad, u)) 1252 1253 1254 @ops.RegisterGradient("Cumsum") 1255 def _CumsumGrad(op, grad): 1256 axis = op.inputs[1] 1257 exclusive = op.get_attr("exclusive") 1258 reverse = op.get_attr("reverse") 1259 return [ 1260 math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=not reverse), 1261 None 1262 ] 1263 1264 1265 @ops.RegisterGradient("Cumprod") 1266 def _CumprodGrad(op, grad): 1267 x = op.inputs[0] 1268 axis = op.inputs[1] 1269 exclusive = op.get_attr("exclusive") 1270 reverse = op.get_attr("reverse") 1271 1272 # TODO This fails when x contains 0 and should be fixed 1273 prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse) 1274 out = math_ops.cumsum( 1275 prod * grad, axis, exclusive=exclusive, reverse=not reverse) 1276 return [out / x, None] 1277