1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define _USE_MATH_DEFINES 17 #include <cmath> 18 19 #include "tensorflow/cc/ops/array_ops_internal.h" 20 #include "tensorflow/cc/ops/math_ops_internal.h" 21 #include "tensorflow/cc/ops/standard_ops.h" 22 23 #include "tensorflow/cc/framework/grad_op_registry.h" 24 #include "tensorflow/cc/framework/gradients.h" 25 26 namespace tensorflow { 27 namespace ops { 28 namespace { 29 30 // Logical operations have no gradients. 31 REGISTER_NO_GRADIENT_OP("Less"); 32 REGISTER_NO_GRADIENT_OP("LessEqual"); 33 REGISTER_NO_GRADIENT_OP("Greater"); 34 REGISTER_NO_GRADIENT_OP("GreaterEqual"); 35 REGISTER_NO_GRADIENT_OP("Equal"); 36 REGISTER_NO_GRADIENT_OP("ApproximateEqual"); 37 REGISTER_NO_GRADIENT_OP("NotEqual"); 38 REGISTER_NO_GRADIENT_OP("LogicalAnd"); 39 REGISTER_NO_GRADIENT_OP("LogicalOr"); 40 REGISTER_NO_GRADIENT_OP("LogicalNot"); 41 42 // Conjugate helper function returns the conjugate of an Output if it 43 // is complex valued. 44 Output ConjugateHelper(const Scope& scope, const Output& out) { 45 DataType dtype = out.type(); 46 if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { 47 return Conj(scope, out); 48 } else { 49 return out; 50 } 51 } 52 53 // TODO(andydavis) Add control dependencies to gradient functions (as needed). 
// Each gradient function below receives the forward op (`op`) and the
// incoming gradients with respect to its outputs (`grad_inputs`), and appends
// the gradients with respect to its inputs to `grad_outputs`. The common
// pattern is grad(x) = grad(y) * conj(dy/dx); conjugation keeps the formulas
// correct for complex dtypes.

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // sign(x) is piecewise constant, so its derivative is zero everywhere
  // it is defined: dx = 0 with the same shape as x.
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = - 1 / (1 - x * x)^1/2
  // dx = dy * (- 1 / (1 - x * x)^1/2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs. Given the per-element gradients gx_1 and gx_2 (already
// shaped like the broadcast result), it sums each one over the broadcast
// dimensions (computed by BroadcastGradientArgs) and reshapes it back to
// the corresponding input's shape before appending both to grad_outputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  // r0/r1 are the reduction axes needed to undo broadcasting for each input.
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

Status PowGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = ConjugateHelper(scope, op.input(0));
  auto y = ConjugateHelper(scope, op.input(1));
  auto z = ConjugateHelper(scope, op.output(0));
  auto grad = grad_inputs[0];
  // dz/dx: grad * y * pow(x, y - 1)
  auto one = Cast(scope, Const(scope, 1.0), y.type());
  auto gx_1 = Mul(scope,
                  Mul(scope, grad, y),
                  Pow(scope, x, Sub(scope, y, one)));
  // dz/dy = z * log(x); guard log(x) against the
  // false singularity at x = 0.
  DataType x_dtype = x.type();
  auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
  if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
    // real(x) < 0 is fine for the complex case
    auto log_x = Where3(scope,
                        NotEqual(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  } else {
    // There's no sensible real value to return if x < 0, so return 0
    auto log_x = Where3(scope,
                        Greater(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  }
}
REGISTER_GRADIENT_OP("Pow", PowGrad);

// MaximumMinimumGradCommon adds shared
// ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // For Maximum, the gradient flows to x_1 wherever x_1 >= x_2.
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // For Minimum, the gradient flows to x_1 wherever x_1 <= x_2.
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = Real(x); the gradient w.r.t. the complex input has the incoming
  // gradient as its real part and zero imaginary part.
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = Imag(x); the gradient w.r.t. the complex input has the incoming
  // gradient as its imaginary part and zero real part.
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ComplexGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = Complex(re, im); route the real/imaginary parts of the incoming
  // gradient back to the corresponding real-valued inputs.
  auto gx_1 = Real(scope, grad_inputs[0]);
  auto gx_2 = Imag(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Complex", ComplexGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + i*Re(x)) = -dy * z_inv
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  // Promote the real-valued incoming gradient to complex before multiplying.
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Conj is its own inverse; the gradient is the conjugate of the
  // incoming gradient.
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y >=0, but treats x/0 = x
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}

// Helper function for reduction ops.
//
// input_shape: 1-D Tensor, the shape of the Tensor being reduced.
// axes: 1-D Tensor, the reduction axes.
609 // Note that the reduction indices are in the range 610 // -rank(input_shape), rank(input_shape) 611 // returns a 1-D Tensor, the output shape as if keep_dims were set to True. 612 Output ReducedShapeHelper(const Scope& scope, const Output& input_shape, 613 const Output& reduction_axes) { 614 auto zero = Const(scope, 0); 615 auto one = Const(scope, 1); 616 617 // Running example in comments 618 // input_shape = [2, 3, 5, 7] 619 // axes = [1, 2] 620 // The result (a shape after a reduction with keep_dims=True) 621 // [2, 1, 1, 7] 622 // 623 // We can treat each entry in axes as an index into input_shape that 624 // should be replaced by 1. 625 // We use DynamicStitch to do this. 626 627 // input_rank = 4 628 auto input_rank = Size(scope, input_shape); 629 630 // Normalize any negative indices in the reduction_axes to positive 631 // values. 632 auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank); 633 634 // This [0..input_rank) range of integers is used in DynamicStitch to 635 // first copy input_shape to the result. 636 // input_rank_range = [0, 1, 2, 3] 637 auto input_rank_range = Range(scope, zero, input_rank, one); 638 639 // A 1-filled tensor with the same shape as axes. DynamicStitch will 640 // merge these 1s (using axes for indices) to the correct 641 // position in the result. 642 // axes_ones = [1, 1] 643 auto axes_ones = OnesLike(scope, axes); 644 645 // using DynamicStitch: 646 // indices = { input_rank_range, axes } 647 // = { [0, 1, 2, 3], [1, 2] } 648 // data = { input_shape, axes_ones } 649 // = { [2, 3, 5, 7], [1, 1] } 650 // The input_rank_range entry in indices first replicates the 651 // input_shape to the result. 
652 // The axes entry in indices then moves a 1 to each of its entries, 653 // resulting in 654 // [2, 1, 1, 7] 655 std::vector<Output> indices = {input_rank_range, axes}; 656 std::vector<Output> data = {input_shape, axes_ones}; 657 return DynamicStitch(scope, indices, data); 658 } 659 660 // SumGradHelper returns the gradient for the Sum operator, and is used 661 // by SumGrad and MeanGrad. 662 Output SumGradHelper(const Scope& scope, const Operation& op, 663 const std::vector<Output>& grad_inputs) { 664 // The partial derivative for any input along a "reduced" dimension 665 // is just 1, so we only need replicate the output gradient on such a 666 // dimension to its "expanded" shape. 667 // Running example: 668 // input is 669 // [[a, b, c], 670 // [d, e, f]] 671 // reduction_indices = [1] 672 // Sum = [a + b + c, d + e + f] 673 // if the gradient is [g1, g2] 674 // We want the propagated gradient to be 675 // [[g1, g1, g1], 676 // [g2, g2, g2]] 677 678 // input_shape = [2, 3] 679 auto input_shape = Shape(scope, op.input(0)); 680 681 // output_shape_kept_dims = [2, 1] 682 auto output_shape_kept_dims = 683 ReducedShapeHelper(scope, input_shape, op.input(1)); 684 685 // This step "flips" any 1s with values from the input_shape, and 686 // replaces remaining entries with 1. This creates a shape that 687 // shows how much each dimension in the incoming gradient should be 688 // replicated. 
689 // tile_scaling = [1, 3] 690 auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims); 691 692 // grad = [[g1], [g2]] 693 auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims); 694 695 // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]] 696 return Tile(scope, grad, tile_scaling); 697 } 698 699 Status SumGrad(const Scope& scope, const Operation& op, 700 const std::vector<Output>& grad_inputs, 701 std::vector<Output>* grad_outputs) { 702 grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs)); 703 704 // Stop propagation along reduction_indices 705 grad_outputs->push_back(NoGradient()); 706 return scope.status(); 707 } 708 REGISTER_GRADIENT_OP("Sum", SumGrad); 709 710 Status MeanGrad(const Scope& scope, const Operation& op, 711 const std::vector<Output>& grad_inputs, 712 std::vector<Output>* grad_outputs) { 713 // The Mean gradient is just like the Sum gradient, except that 714 // all gradients are also divided by the size of reduced groups. 715 auto sum_grad = SumGradHelper(scope, op, grad_inputs); 716 717 // The product of all entries in a tensor's shape is the total 718 // number of entries in the tensor. 
This step calculates 719 // n_input_entries/n_output_entries 720 // = group_size 721 auto input_shape = Shape(scope, op.input(0)); 722 auto output_shape = Shape(scope, op.output(0)); 723 auto zero = Const(scope, 0); 724 auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero), 725 Prod(scope, output_shape, zero)); 726 727 // propagate sum_grad/group_size 728 grad_outputs->push_back( 729 Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type()))); 730 731 // Stop propagation along reduction_indices 732 grad_outputs->push_back(NoGradient()); 733 return scope.status(); 734 } 735 REGISTER_GRADIENT_OP("Mean", MeanGrad); 736 737 Status ErfGrad(const Scope& scope, const Operation& op, 738 const std::vector<Output>& grad_inputs, 739 std::vector<Output>* grad_outputs) { 740 auto grad = grad_inputs[0]; 741 auto two_over_root_pi = Cast(scope, Const(scope, 2 / std::sqrt(M_PI)), 742 grad.type()); 743 Scope grad_scope = scope.WithControlDependencies(grad); 744 auto x = ConjugateHelper(grad_scope, op.input(0)); 745 // grad * 2/sqrt(pi) * exp(-x**2) 746 auto dx = Mul(grad_scope, 747 Mul(grad_scope, grad, two_over_root_pi), 748 Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x)))); 749 grad_outputs->push_back(dx); 750 return grad_scope.status(); 751 } 752 REGISTER_GRADIENT_OP("Erf", ErfGrad); 753 754 Status LgammaGrad(const Scope& scope, const Operation& op, 755 const std::vector<Output>& grad_inputs, 756 std::vector<Output>* grad_outputs) { 757 auto grad = grad_inputs[0]; 758 Scope grad_scope = scope.WithControlDependencies(grad); 759 auto x = ConjugateHelper(grad_scope, op.input(0)); 760 auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x)); 761 grad_outputs->push_back(dx); 762 return grad_scope.status(); 763 } 764 REGISTER_GRADIENT_OP("Lgamma", LgammaGrad); 765 766 Status MinOrMaxGrad(const Scope& scope, const Operation& op, 767 const std::vector<Output>& grad_inputs, 768 std::vector<Output>* grad_outputs) { 769 // The partial derivative for any input along 
a "reduced" dimension 770 // is 1 when it is the min (or max) and 0 everywhere else. So the 771 // gradient calculation is identical for both operators. 772 // 773 // There's a special case for propagating gradients when there are 774 // multiple minima (or maxima) - we choose to divide the gradient 775 // equally among all matching inputs. 776 // 777 // Please note this comment 778 // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063 779 // for details. 780 781 // Running example: 782 // input: [[5, 5, 5], 783 // [1, 2, -3]] 784 // reduction_indices: [1] 785 auto input = op.input(0); 786 auto reduction_indices = op.input(1); 787 788 // [2, 3] 789 auto input_shape = Shape(scope, input); 790 791 // [2, 1] 792 auto output_shape_kept_dims = 793 ReducedShapeHelper(scope, input_shape, reduction_indices); 794 795 // for op=min (say) 796 // output = [5, -3] 797 // y = [[5], 798 // [-3]] 799 auto y = Reshape(scope, op.output(0), output_shape_kept_dims); 800 801 // reshape([g1, g2], [2, 1]) = [[g1], 802 // [g2]] 803 auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims); 804 805 // indicators = equal(y, input) 806 // = equal([[5], [[5, 5, 5], 807 // [-3]], [1, 2, -3]]) 808 // = [[1, 1, 1], 809 // [0, 0, 1]] 810 auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type()); 811 812 // [[3], 813 // [1]] 814 auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices), 815 output_shape_kept_dims); 816 817 // [[1/3, 1/3, 1/3], 818 // [0, 0, 1]] 819 auto scale = Div(scope, indicators, num_selected); 820 821 // [[g1/3, g1/3, g1/3], 822 // [0, 0, g2]] 823 grad_outputs->push_back(Mul(scope, scale, grad)); 824 825 // Stop propagation along reduction_indices 826 grad_outputs->push_back(NoGradient()); 827 return scope.status(); 828 } 829 REGISTER_GRADIENT_OP("Min", MinOrMaxGrad); 830 REGISTER_GRADIENT_OP("Max", MinOrMaxGrad); 831 832 Status ProdGrad(const Scope& scope, const Operation& op, 833 const 
std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // The gradient can be expressed by dividing the product by each entry of
  // the input tensor. If our input is
  // [
  //   [3, 4],
  //   [5, 6],
  //   [7, 8]
  // ]
  // and we do a Prod operation on the axis 1, we will obtain [[105, 192]].
  // The gradient will have the same shape as the input
  //      [
  //        [105/3, 192/4],
  // dz *   [105/5, 192/6],
  //        [105/7, 192/8]
  //      ]
  // If the input contains a zero, the division is impossible but
  // if we take the calculation that gave the first gradient
  // (3 * 5 * 7)/3 is equal to 5 * 7
  // the trick will be to cumprod the elements on the axis without
  // the element at the current position (3 in the example above).
  // We will take as example:
  // [
  //   [
  //     [3.0, 4.0],
  //     [5.0, 6.0],
  //     [7.0, 8.0]
  //   ],
  //   [
  //     [3.0, 5.0],
  //     [0.0, 6.0],
  //     [5.0, 6.0]
  //   ]
  // ]

  // [2, 3, 2]
  auto input_shape = Shape(scope, op.input(0));

  // The Reshape with -1 flattens the reduction indices.
  // [1]
  auto reduction_indices = Reshape(scope, op.input(1), {-1});

  // Input shape with the reduced axes kept as size 1: [2, 1, 2]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // Tile multipliers to re-expand the reduced axes: [1, 3, 1]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // Incoming gradient, reshaped to broadcast shape:
  // [[[105, 192]], [[0, 180]]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // Tiled back up to the full input shape:
  // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
  auto grad_tiled = Tile(scope, grad, tile_scaling);

  // The shape/index bookkeeping below is tiny integer work; pin it to CPU.
  Scope cpu_scope = scope.WithDevice("/cpu:0");

  // Rank of the input, here 3.
  auto rank = Rank(cpu_scope, op.input(0));

  // Normalize any negative indices in the reduction_axes to positive values.
  auto reduction_indices_pos =
      Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);

  // Reduced axes as int32: [1]
  auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);

  // All axes: [0, 1, 2]
  auto idx = Range(cpu_scope, zero, rank, one);

  // Axes NOT being reduced: [0, 2]
  auto other = SetDiff1D(cpu_scope, idx, reduced).out;

  // Permutation that moves the reduced axes to the front: [1, 0, 2]
  auto perm =
      Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);

  // Total extent of the reduced axes: 3 => [3]
  auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);

  // Total extent of the remaining axes: 2 * 2 => [2]
  auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);

  // Transpose so the reduced axis comes first:
  // [
  //   [
  //     [ 3., 4.],
  //     [ 3., 5.]
  //   ],
  //   [
  //     [ 5., 6.],
  //     [ 0., 6.]
  //   ],
  //   [
  //     [ 7., 8.],
  //     [ 5., 6.]
  //   ]
  // ]
  auto permuted = Transpose(scope, op.input(0), perm);

  // [3, 2, 2]
  auto permuted_shape = Shape(scope, permuted);

  // Collapse to 2-D: one row per position along the reduced axis.
  // [
  //   [ 3., 4., 3., 5.],
  //   [ 5., 6., 0., 6.],
  //   [ 7., 8., 5., 6.]
  // ]
  auto reshaped = Reshape(
      scope, permuted,
      Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));

  // Exclusive cumprod from the top: product of every entry strictly above.
  // [
  //   [ 1., 1., 1., 1.],
  //   [ 3., 4., 3., 5.],
  //   [ 15., 24., 0., 30.]
  // ]
  auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));

  // Exclusive cumprod from the bottom: product of every entry strictly below.
  // [
  //   [ 35., 48., 0., 36.],
  //   [ 7., 8., 5., 6.],
  //   [ 1., 1., 1., 1.]
  // ]
  auto right =
      Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));

  // left * right = product of all entries except the current one, with no
  // division involved:
  // [
  //   [ 35., 48., 0., 36.],
  //   [ 21., 32., 15., 30.],
  //   [ 15., 24., 0., 30.]
  // ]
  // y =
  // [
  //   [
  //     [ 35., 48.],
  //     [ 0., 36.]
  //   ],
  //   [
  //     [ 21., 32.],
  //     [ 15., 30.]
  //   ],
  //   [
  //     [ 15., 24.],
  //     [  0., 30.]
  //   ]
  // ]
  auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);

  // Undo the permutation and scale the tiled incoming gradient.
  // out =
  // [
  //   [
  //     [ 35., 48.],
  //     [ 21., 32.],
  //     [ 15., 24.]
  //   ],
  //   [
  //     [ 0., 36.],
  //     [ 15., 30.],
  //     [ 0., 30.]
  //   ]
  // ]
  auto out =
      Mul(scope, grad_tiled, Transpose(scope, y, InvertPermutation(scope, perm)));

  grad_outputs->push_back(Reshape(scope, out, input_shape));

  // stop propagation along reduction_indices: integer index input, no
  // meaningful gradient.
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Prod", ProdGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
// The (x0, x1) pair produces the gradient w.r.t. the first forward input
// and the (y0, y1) pair the gradient w.r.t. the second; each adj_* flag
// requests transposition (adjoint in the batch case) of its factor. Both
// products are appended to *grad_outputs in input order.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    // Batch case: AdjX/AdjY perform the (conjugate) transposition.
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGrad common used to read and check node attr state, and determine
// proper MatMul products for gradients based on input matrix transposition
// combinations.
1037 Status MatMulGradCommon(const Scope& scope, const Operation& op, 1038 const bool is_batch, 1039 const std::vector<Output>& grad_inputs, 1040 const string& attr_adj_x, const string& attr_adj_y, 1041 std::vector<Output>* grad_outputs) { 1042 auto a = op.input(0); 1043 auto b = op.input(1); 1044 // Use conjugate of the inputs for MatMul 1045 if (is_batch == false) { 1046 a = ConjugateHelper(scope, a); 1047 b = ConjugateHelper(scope, b); 1048 } 1049 auto product = op.output(0); 1050 1051 bool ta; 1052 bool tb; 1053 TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta)); 1054 TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb)); 1055 1056 if (!ta && !tb) { 1057 return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a, 1058 true, grad_inputs[0], false, grad_outputs); 1059 } else if (!ta && tb) { 1060 return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false, 1061 grad_inputs[0], true, a, false, grad_outputs); 1062 } else if (ta && !tb) { 1063 return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a, 1064 false, grad_inputs[0], false, grad_outputs); 1065 } 1066 return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true, 1067 grad_inputs[0], true, a, true, grad_outputs); 1068 } 1069 1070 Status MatMulGrad(const Scope& scope, const Operation& op, 1071 const std::vector<Output>& grad_inputs, 1072 std::vector<Output>* grad_outputs) { 1073 return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a", 1074 "transpose_b", grad_outputs); 1075 } 1076 REGISTER_GRADIENT_OP("MatMul", MatMulGrad); 1077 1078 Status BatchMatMulGrad(const Scope& scope, const Operation& op, 1079 const std::vector<Output>& grad_inputs, 1080 std::vector<Output>* grad_outputs) { 1081 return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y", 1082 grad_outputs); 1083 } 1084 REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad); 1085 1086 } // anonymous namespace 1087 } // namespace 
ops 1088 } // namespace tensorflow 1089