/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define _USE_MATH_DEFINES
#include <cmath>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

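// The gradient functions in this file are looked up through GradOpRegistry
// when a client asks for symbolic gradients. A minimal usage sketch
// (illustrative only; the names "root", "x", "y" and "grads" are
// hypothetical):
//
//   Scope root = Scope::NewRootScope();
//   auto x = Placeholder(root, DT_FLOAT);
//   auto y = Square(root, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(root, {y}, {x}, &grads));
//   // grads[0] is dy/dx, built from the "Square" gradient registered below.
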
// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

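// Throughout this file the per-op comments use the convention
// grad(x) = grad(y) * conj(dy/dx); for real dtypes ConjugateHelper is a
// no-op, so the same formulas cover both the real and the complex case.
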
// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
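  // sign(x) is piecewise constant (and undefined at x = 0), so its
  // derivative is 0 wherever it exists; propagate a zero tensor with the
  // shape of the input.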
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = - 1 / (1 - x * x)^1/2
  // dx = dy * (- 1 / (1 - x * x)^1/2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
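// Example (illustrative): if x_1 has shape [2, 3] and x_2 has shape [3],
// BroadcastGradientArgs returns r0 = [] and r1 = [0], so dx_2 is the
// incoming gradient summed over axis 0 and reshaped back to [3], while
// dx_1 keeps its shape.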
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

Status PowGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = ConjugateHelper(scope, op.input(0));
  auto y = ConjugateHelper(scope, op.input(1));
  auto z = ConjugateHelper(scope, op.output(0));
  auto grad = grad_inputs[0];
  // grad * y * pow(x, y - 1)
  auto one = Cast(scope, Const(scope, 1.0), y.type());
  auto gx_1 = Mul(scope,
                  Mul(scope, grad, y),
                  Pow(scope, x, Sub(scope, y, one)));
  // Avoid false singularity at x = 0
  DataType x_dtype = x.type();
  auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
  if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
    // real(x) < 0 is fine for the complex case
    auto log_x = Where3(scope,
                        NotEqual(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  } else {
    // There's no sensible real value to return if x < 0, so return 0
    auto log_x = Where3(scope,
                        Greater(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  }
}
REGISTER_GRADIENT_OP("Pow", PowGrad);
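
// Worked example for PowGrad (illustrative): for scalar x = 2, y = 3,
// z = x^y = 8, the returned gradients are
//   dz/dx = y * x^(y - 1) = 3 * 4 = 12
//   dz/dy = z * log(x) = 8 * ln(2)
// and the log(x) term is masked to 0 where x <= 0 (real case) or x == 0
// (complex case), as implemented above.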

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
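  // Example (illustrative): for Maximum with x_1 = [3, 1] and x_2 = [2, 2],
  // comparator = [true, false], so gx_1 = [g_1, 0] and gx_2 = [0, g_2].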
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ComplexGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto gx_1 = Real(scope, grad_inputs[0]);
  auto gx_2 = Imag(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Complex", ComplexGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + iRe(x)) = -dy * z_inv
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y >=0, but treats x/0 = x
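// Example (illustrative): SafeDivHelper on shape vectors [2, 3, 5] and
// [2, 1, 5] yields [1, 3, 1].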
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}

// Helper function for reduction ops.
//
// input_shape: 1-D Tensor, the shape of the Tensor being reduced.
// axes: 1-D Tensor, the reduction axes.
//   Note that the reduction indices are in the range
//   -rank(input_shape), rank(input_shape)
// returns a 1-D Tensor, the output shape as if keep_dims were set to True.
Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
                          const Output& reduction_axes) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // Running example in comments
  // input_shape = [2, 3, 5, 7]
  // axes = [1, 2]
  // The result (a shape after a reduction with keep_dims=True)
  // [2, 1, 1, 7]
  //
  // We can treat each entry in axes as an index into input_shape that
  // should be replaced by 1.
  // We use DynamicStitch to do this.

  // input_rank = 4
  auto input_rank = Size(scope, input_shape);

  // Normalize any negative indices in the reduction_axes to positive
  // values.
  auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);

  // This [0..input_rank) range of integers is used in DynamicStitch to
  // first copy input_shape to the result.
  // input_rank_range = [0, 1, 2, 3]
  auto input_rank_range = Range(scope, zero, input_rank, one);

  // A 1-filled tensor with the same shape as axes. DynamicStitch will
  // merge these 1s (using axes for indices) to the correct
  // position in the result.
  // axes_ones = [1, 1]
  auto axes_ones = OnesLike(scope, axes);

  // using DynamicStitch:
  // indices = { input_rank_range, axes }
  //         = { [0, 1, 2, 3], [1, 2] }
  // data = { input_shape, axes_ones }
  //      = { [2, 3, 5, 7], [1, 1] }
  // The input_rank_range entry in indices first replicates the
  // input_shape to the result.
  // The axes entry in indices then moves a 1 to each of its entries,
  // resulting in
  // [2, 1, 1, 7]
  std::vector<Output> indices = {input_rank_range, axes};
  std::vector<Output> data = {input_shape, axes_ones};
  return DynamicStitch(scope, indices, data);
}

// SumGradHelper returns the gradient for the Sum operator, and is used
// by SumGrad and MeanGrad.
Output SumGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is just 1, so we only need to replicate the output gradient on such a
  // dimension to its "expanded" shape.
  // Running example:
  // input is
  // [[a, b, c],
  //  [d, e, f]]
  // reduction_indices = [1]
  // Sum = [a + b + c, d + e + f]
  // if the gradient is [g1, g2]
  // We want the propagated gradient to be
  // [[g1, g1, g1],
  //  [g2, g2, g2]]

  // input_shape = [2, 3]
  auto input_shape = Shape(scope, op.input(0));

  // output_shape_kept_dims = [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, op.input(1));

  // This step "flips" any 1s with values from the input_shape, and
  // replaces remaining entries with 1. This creates a shape that
  // shows how much each dimension in the incoming gradient should be
  // replicated.
  // tile_scaling = [1, 3]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // grad = [[g1], [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
  return Tile(scope, grad, tile_scaling);
}

Status SumGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Sum", SumGrad);

Status MeanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The Mean gradient is just like the Sum gradient, except that
  // all gradients are also divided by the size of reduced groups.
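  // Example (illustrative): for a [2, 3] input reduced over axis 1,
  // group_size below is 6 / 2 = 3, so each incoming gradient entry is
  // propagated as g_i / 3.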
  auto sum_grad = SumGradHelper(scope, op, grad_inputs);

  // The product of all entries in a tensor's shape is the total
  // number of entries in the tensor. This step calculates
  // n_input_entries/n_output_entries
  // = group_size
  auto input_shape = Shape(scope, op.input(0));
  auto output_shape = Shape(scope, op.output(0));
  auto zero = Const(scope, 0);
  auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
                                  Prod(scope, output_shape, zero));

  // propagate sum_grad/group_size
  grad_outputs->push_back(
      Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);

Status ErfGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  auto two_over_root_pi = Cast(scope, Const(scope, 2 / std::sqrt(M_PI)),
                               grad.type());
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  // grad * 2/sqrt(pi) * exp(-x**2)
  auto dx = Mul(grad_scope,
                Mul(grad_scope, grad, two_over_root_pi),
                Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x))));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Erf", ErfGrad);

Status LgammaGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);

Status MinOrMaxGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is 1 when it is the min (or max) and 0 everywhere else. So the
  // gradient calculation is identical for both operators.
  //
  // There's a special case for propagating gradients when there are
  // multiple minima (or maxima) - we choose to divide the gradient
  // equally among all matching inputs.
  //
  // Please note this comment
  // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
  // for details.

  // Running example:
  // input: [[5, 5, 5],
  //         [1, 2, -3]]
  // reduction_indices: [1]
  auto input = op.input(0);
  auto reduction_indices = op.input(1);

  // [2, 3]
  auto input_shape = Shape(scope, input);

  // [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // for op=min (say)
  // output = [5, -3]
  // y = [[5],
  //      [-3]]
  auto y = Reshape(scope, op.output(0), output_shape_kept_dims);

  // reshape([g1, g2], [2, 1]) = [[g1],
  //                              [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // indicators = equal(y, input)
  //  = equal([[5],   [[5, 5, 5],
  //           [-3]],  [1, 2, -3]])
  //  = [[1, 1, 1],
  //     [0, 0, 1]]
  auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());

  // [[3],
  //  [1]]
  auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
                              output_shape_kept_dims);

  // [[1/3, 1/3, 1/3],
  //  [0, 0, 1]]
  auto scale = Div(scope, indicators, num_selected);

  // [[g1/3, g1/3, g1/3],
  //  [0, 0, g2]]
  grad_outputs->push_back(Mul(scope, scale, grad));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);

Status ProdGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // The gradient can be expressed by dividing the product by each entry of
  // the input tensor. If our input is
  // [
  //  [3, 4],
  //  [5, 6],
  //  [7, 8]
  // ]
  // and we do a Prod operation on axis 0, we will obtain [105, 192].
  // The gradient will have the same shape as the input
  //     [
  //       [105/3, 192/4],
  // dz *  [105/5, 192/6],
  //       [105/7, 192/8]
  //     ]
  // If the input contains a zero, the division is impossible, but
  // if we take the calculation that gave the first gradient entry,
  // (3 * 5 * 7)/3 is equal to 5 * 7,
  // so the trick is to cumprod the elements on the axis excluding
  // the element at the current position (3 in the example above).
  // We will take as example:
  // [
  //   [
  //     [3.0, 4.0],
  //     [5.0, 6.0],
  //     [7.0, 8.0]
  //   ],
  //   [
  //     [3.0, 5.0],
  //     [0.0, 6.0],
  //     [5.0, 6.0]
  //   ]
  // ]

  // [2, 3, 2]
  auto input_shape = Shape(scope, op.input(0));

  // The Reshape with -1 flattens the reduction indices.
  // [1]
  auto reduction_indices = Reshape(scope, op.input(1), {-1});

  // [2, 1, 2]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // [1, 3, 1]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // [[[105, 192]], [[0, 180]]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
  auto grad_tiled = Tile(scope, grad, tile_scaling);

  Scope cpu_scope = scope.WithDevice("/cpu:0");

  // [3]
  auto rank = Rank(cpu_scope, op.input(0));

  // Normalize any negative indices in the reduction_axes to positive values.
  auto reduction_indices_pos =
      Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);

  // [1]
  auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);

  // [0, 1, 2]
  auto idx = Range(cpu_scope, zero, rank, one);

  // [0, 2]
  auto other = SetDiff1D(cpu_scope, idx, reduced).out;

  // [1, 0, 2]
  auto perm =
      Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);

  // 3 => [3]
  auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);

  // 2 * 2 => [2]
  auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);

  // [
  //    [
  //       [ 3.,  4.],
  //       [ 3.,  5.]
  //   ],
  //   [
  //       [ 5.,  6.],
  //       [ 0.,  6.]
  //   ],
  //   [
  //       [ 7.,  8.],
  //       [ 5.,  6.]
  //   ]
  // ]
  auto permuted = Transpose(scope, op.input(0), perm);

  // [3, 2, 2]
  auto permuted_shape = Shape(scope, permuted);

  // [
  //   [ 3.,  4.,  3.,  5.],
  //   [ 5.,  6.,  0.,  6.],
  //   [ 7.,  8.,  5.,  6.]
  // ]
  auto reshaped = Reshape(
      scope, permuted,
      Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));

  // [
  //   [ 1.,  1.,  1.,  1.],
  //   [ 3.,  4.,  3.,  5.],
  //   [ 15.,  24.,  0.,  30.]
  // ]
  auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));

  // [
  //   [ 35.,  48.,  0.,  36.],
  //   [  7.,   8.,   5.,   6.],
  //   [  1.,   1.,   1.,   1.]
  // ]
  auto right =
      Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));

  // left * right =
  // [
  //   [ 35.,  48.,  0.,  36.],
  //   [ 21.,  32.,  15.,  30.],
  //   [ 15.,  24.,  0.,  30.]
  // ]
  // y =
  // [
  //   [
  //     [ 35.,  48.],
  //     [ 0.,  36.]
  //   ],
  //   [
  //     [ 21.,  32.],
  //     [ 15.,  30.]
  //   ],
  //   [
  //     [ 15.,  24.],
  //     [ 0.,  30.]
  //   ]
  // ]
  auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);

  // out =
  // [
  //   [
  //     [ 35.,  48.],
  //     [ 21.,  32.],
  //     [ 15.,  24.]
  //   ],
  //   [
  //     [ 0.,   36.],
  //     [ 15.,  30.],
  //     [ 0.,  30.]
  //   ]
  // ]
  auto out =
      Mul(scope, grad_tiled, Transpose(scope, y, InvertPermutation(scope, perm)));

  grad_outputs->push_back(Reshape(scope, out, input_shape));

  // stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Prod", ProdGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks node attr state, and determines the
// proper MatMul products for gradients based on input matrix transposition
// combinations.
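// Example (illustrative): for C = MatMul(A, B) with no transposes, the first
// branch below produces dA = dC * B^T and dB = A^T * dC.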
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  auto a = op.input(0);
  auto b = op.input(1);
  // Use conjugate of the inputs for MatMul
  if (is_batch == false) {
    a = ConjugateHelper(scope, a);
    b = ConjugateHelper(scope, b);
  }
  auto product = op.output(0);

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a,
                            true, grad_inputs[0], false, grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false,
                            grad_inputs[0], true, a, false, grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a,
                            false, grad_inputs[0], false, grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true,
                          grad_inputs[0], true, a, true, grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow