/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
namespace {

using ops::Abs;
using ops::Add;
using ops::AddN;
using ops::BatchMatMul;
using ops::Const;
using ops::Div;
using ops::Greater;
using ops::MatMul;
using ops::Max;
using ops::Maximum;
using ops::Mean;
using ops::Min;
using ops::Minimum;
using ops::Mul;
using ops::Placeholder;
using ops::Pow;
using ops::Prod;
using ops::RealDiv;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
using ops::Where3;
// TODO(andydavis) Test gradient functions against numeric gradient output.
// TODO(andydavis) As more gradients are added, move common test functions
// to a testutil library.

class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    SINH,
    COSH,
    TANH,
    ASINH,
    ACOSH,
    ATANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN,
    REAL,
    IMAG,
    CONJ,
    COMPLEX,
    ANGLE,
    LGAMMA,
    ERF
  };

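  // Builds y = op(x) for the op selected by op_type, fills x with values
  // produced by x_fn, and asks the gradient checker to compare the
  // registered symbolic gradient with a numeric estimate. Roughly (a
  // sketch; see gradient_checker.h for the exact contract), the checker
  // forms a central difference
  //   dy/dx ~= (f(x + h) - f(x - h)) / (2h)
  // per entry and reports the largest elementwise difference from the
  // symbolic gradient as max_error.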
  template <typename X_T, typename Y_T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) {
    TF_ASSERT_OK(scope_.status());
    DataType x_type = DataTypeToEnum<X_T>::v();
    TensorShape shape({2, 3, 2});
    auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape));
    Tensor x_data(x_type, shape);
    auto x_data_flat = x_data.flat<X_T>();
    for (int i = 0; i < x_data_flat.size(); ++i) {
      x_data_flat(i) = x_fn(i);
    }

    Output y;
    switch (op_type) {
      using namespace ops;  // NOLINT(build/namespaces)
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      case INV:
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case SINH:
        y = Sinh(scope_, x);
        break;
      case COSH:
        y = Cosh(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case ASINH:
        y = Asinh(scope_, x);
        break;
      case ACOSH:
        y = Acosh(scope_, x);
        break;
      case ATANH:
        y = Atanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
      case REAL:
        y = Real(scope_, x);
        break;
      case IMAG:
        y = Imag(scope_, x);
        break;
      case CONJ:
        y = Conj(scope_, x);
        break;
      case COMPLEX:
        y = Complex(scope_, x, x);
        break;
      case ANGLE:
        y = Angle(scope_, x);
        break;
      case LGAMMA:
        y = Lgamma(scope_, x);
        break;
      case ERF:
        y = Erf(scope_, x);
        break;
    }

    float max_error;
    TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y,
                                                        shape, &max_error)));
    EXPECT_LT(max_error, 1e-3f);
  }

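  // RV and CRV return an entry of the candidate vector chosen uniformly at
  // random, so each element of the test tensor independently samples a
  // point at which the op under test (and its gradient) is well defined.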
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  Scope scope_;
};

TEST_F(CWiseUnaryGradTest, Abs) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  TestCWiseGrad<float, float>(ABS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Neg) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  TestCWiseGrad<float, float>(NEG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  TestCWiseGrad<float, float>(INV, x_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(INV, x_fn);
}

TEST_F(CWiseUnaryGradTest, Square) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SQUARE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Square_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(SQUARE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); };
  TestCWiseGrad<float, float>(SQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(SQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt) {
  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
  TestCWiseGrad<float, float>(RSQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(RSQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Exp) {
  auto x_fn = [this](const int i) {
    return RV({0, -1, 1, -1.5f, 1.5f, -2, 2});
  };
  TestCWiseGrad<float, float>(EXP, x_fn);
}

TEST_F(CWiseUnaryGradTest, Exp_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(EXP, x_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); };
  TestCWiseGrad<float, float>(EXPM1, x_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}});
  };
  TestCWiseGrad<complex64, complex64>(EXPM1, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4}); };
  TestCWiseGrad<float, float>(LOG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(LOG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p) {
  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
  TestCWiseGrad<float, float>(LOG1P, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(LOG1P, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh) {
  auto x_fn = [this](const int i) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); };
  TestCWiseGrad<float, float>(SINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(SINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(COSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(COSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(TANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(TANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, -1, -1.5, 1.5}); };
  TestCWiseGrad<float, float>(ASINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(ASINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acosh) {
  auto x_fn = [this](const int i) { return RV({1.5, 2, 2.5}); };
  TestCWiseGrad<float, float>(ACOSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(ACOSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
  TestCWiseGrad<float, float>(ATANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
  };
  TestCWiseGrad<complex64, complex64>(ATANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIGMOID, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(SIGMOID, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sign) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIGN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sin) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
  };
  TestCWiseGrad<complex64, complex64>(SIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cos) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(COS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
  };
  TestCWiseGrad<complex64, complex64>(COS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asin) {
  auto x_fn = [this](const int i) { return RV({0, 0.25, -0.25, -0.5, 0.5}); };
  TestCWiseGrad<float, float>(ASIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
  };
  // TODO(kbsriram)
  // Enable test when the asin kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ASIN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Acos) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.75, 0.75}); };
  TestCWiseGrad<float, float>(ACOS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
  };
  // TODO(kbsriram)
  // Add test when the acos kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ACOS, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Tan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(TAN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  // TODO(kbsriram)
  // Enable when tan kernel supports complex inputs
  if (false) {
    TestCWiseGrad<complex64, complex64>(TAN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Atan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(ATAN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  // TODO(kbsriram)
  // Add test when the atan kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ATAN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Real) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, float>(REAL, x_fn);
}

TEST_F(CWiseUnaryGradTest, Imag) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, float>(IMAG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Conj) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, complex64>(CONJ, x_fn);
}

TEST_F(CWiseUnaryGradTest, Complex) {
  auto x_fn = [this](const int i) { return RV({1, -1, 2, -2, 3, -3}); };
  TestCWiseGrad<float, complex64>(COMPLEX, x_fn);
}

TEST_F(CWiseUnaryGradTest, Angle) {
  auto x_fn = [this](const int i) {
    return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}});
  };
  TestCWiseGrad<complex64, float>(ANGLE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Lgamma) {
  auto x_fn = [this](const int i) {
    return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5});
  };
  TestCWiseGrad<float, float>(LGAMMA, x_fn);
}

TEST_F(CWiseUnaryGradTest, Lgamma_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}});
  };
  // TODO(kbsriram)
  // Add test when the lgamma kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(LGAMMA, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Erf) {
  auto x_fn = [this](const int i) {
    return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3});
  };
  TestCWiseGrad<float, float>(ERF, x_fn);
}

TEST_F(CWiseUnaryGradTest, Erf_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}});
  };
  // TODO(kbsriram)
  // Add test when the erf kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ERF, x_fn);
  }
}

class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  template <typename T>
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    TF_ASSERT_OK(root_.status());
    // Generate random (but compatible) shapes for matrix multiplication.
    std::vector<TensorShape> shapes;
    RandMatMulShapes(is_batch, t_x, t_y, &shapes);
    TensorShape x_shape = shapes[0];
    TensorShape y_shape = shapes[1];
    TensorShape z_shape = shapes[2];
    auto x =
        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape));
    auto y =
        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape));
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }

    float max_error;
    TF_ASSERT_OK((ComputeGradientError<T, T, float>(
        root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

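  // Fills *shapes with {x_shape, y_shape, z_shape} for z = MatMul(x, y).
  // For illustration (these particular numbers are made up): with m = 3,
  // k = 4, n = 5 and tx = ty = false this yields x:[3,4], y:[4,5], z:[3,5];
  // with tx = true, x is stored pre-transposed as [4,3] so that MatMul with
  // transpose_a = true still produces the same [3,5] product.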
  void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty,
                        std::vector<TensorShape>* shapes) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    shapes->push_back(x_shape);

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    shapes->push_back(y_shape);

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    shapes->push_back(z_shape);
  }

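  // Returns a random matrix dimension in [1, 10].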
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};

TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(false, false, false);
}

TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(false, false, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad<float>(false, true, false);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(false, true, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad<float>(false, false, true);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(false, false, true);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(false, true, true);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(false, true, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad<float>(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad<float>(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(true, true, true);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(true, true, true);
}

class NaryGradTest : public ::testing::Test {
 protected:
  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, xs, x_shapes, ys, y_shapes, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
               const TensorShape& y_shape) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, x, x_init_value, y, y_shape, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

  Scope scope_;
};

TEST_F(NaryGradTest, Sum) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Sum(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (i.e., axis 3).
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, Mean) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Mean(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (i.e., axis 3).
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, Min) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Min(scope_, x, {-1});
  // y's shape is the result of reducing x along axis -1 (i.e., axis 1).
  TensorShape y_shape({2});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, Max) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Max(scope_, x, {-1});
  // y's shape is the result of reducing x along axis -1 (i.e., axis 1).
  TensorShape y_shape({2});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, MinMulti) {
  // Test gradient when there are multiple minima.
  // Note that we cannot directly use a test Tensor with multiple
  // minima, as the numeric estimator will calculate incorrect
  // gradients when perturbing each entry in the Tensor (which then
  // changes how many minima exist).
  // Instead, we use a single input that broadcast-multiplies a larger
  // tensor with equal values, and apply reduce_min to the multiplied
  // result.
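  // A concrete illustration of the failure (numbers here are only
  // illustrative): for a tensor {1, 1, 1}, the symbolic Min gradient
  // splits the incoming gradient across the three tied minima (1/3 per
  // entry), while the central difference (f(x_i + h) - f(x_i - h)) / (2h)
  // = (1 - (1 - h)) / (2h) = 1/2 at every entry, because only the
  // downward perturbation moves the minimum.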
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Min(scope_, all_same, {0});
  // all_same has shape [3]; reducing it along dimension 0 leaves a single
  // element, so y is [1] shaped.
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, MaxMulti) {
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Max(scope_, all_same, {0});
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, AddN) {
  TensorShape shape({3, 2, 5});
  std::vector<Output> xs;
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  auto y = AddN(scope_, xs);
  RunTest(xs, {shape, shape, shape}, {y}, {shape});
}

TEST_F(NaryGradTest, Add) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Add(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Sub) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Sub(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Mul) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Mul(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Div) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
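  // With this formulation the denominator is always at least 1, so neither
  // the function value nor its finite-difference estimate can blow up.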
  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, RealDiv) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y =
      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, SquaredDifference) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = SquaredDifference(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Pow) {
  TensorShape shape({3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  // Fix the exponent to avoid overflow.
  auto y = Pow(scope_, x, Const(scope_, {1.f, 2.f, 3.f}));
  RunTest({x}, {shape}, {y}, {shape});
}

TEST_F(NaryGradTest, Maximum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Minimum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Prod) {
  TensorShape x_shape({2, 3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Prod(scope_, x, {1});
  // y's shape is the result of reducing x along axis 1.
  TensorShape y_shape({2, 1, 2});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

}  // namespace
}  // namespace tensorflow