1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/framework/grad_op_registry.h" 17 #include "tensorflow/cc/framework/gradient_checker.h" 18 #include "tensorflow/cc/framework/testutil.h" 19 #include "tensorflow/cc/gradients/grad_testutil.h" 20 #include "tensorflow/cc/ops/standard_ops.h" 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/lib/core/status_test_util.h" 23 #include "tensorflow/core/lib/random/random.h" 24 25 namespace tensorflow { 26 namespace { 27 28 using ops::Abs; 29 using ops::Add; 30 using ops::AddN; 31 using ops::BatchMatMul; 32 using ops::Const; 33 using ops::Div; 34 using ops::Greater; 35 using ops::MatMul; 36 using ops::Max; 37 using ops::Maximum; 38 using ops::Mean; 39 using ops::Min; 40 using ops::Minimum; 41 using ops::Mul; 42 using ops::Placeholder; 43 using ops::Pow; 44 using ops::Prod; 45 using ops::RealDiv; 46 using ops::SquaredDifference; 47 using ops::Sub; 48 using ops::Sum; 49 using ops::Where3; 50 51 // TODO(andydavis) Test gradient function against numeric gradients output. 52 // TODO(andydavis) As more gradients are added move common test functions 53 // to a testutil library. 54 55 class CWiseUnaryGradTest : public ::testing::Test { 56 protected: 57 CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 58 59 enum UnaryOpType { 60 ABS, 61 NEG, 62 INV, 63 SQUARE, 64 SQRT, 65 RSQRT, 66 EXP, 67 EXPM1, 68 LOG, 69 LOG1P, 70 SINH, 71 COSH, 72 TANH, 73 ASINH, 74 ACOSH, 75 ATANH, 76 SIGMOID, 77 SIGN, 78 SIN, 79 COS, 80 ASIN, 81 ACOS, 82 TAN, 83 ATAN, 84 REAL, 85 IMAG, 86 CONJ, 87 COMPLEX, 88 ANGLE, 89 LGAMMA, 90 ERF 91 }; 92 93 template <typename X_T, typename Y_T> 94 void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) { 95 TF_ASSERT_OK(scope_.status()); 96 DataType x_type = DataTypeToEnum<X_T>::v(); 97 TensorShape shape({2, 3, 2}); 98 auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape)); 99 Tensor x_data(x_type, shape); 100 auto x_data_flat = x_data.flat<X_T>(); 101 for (int i = 0; i < x_data_flat.size(); ++i) { 102 x_data_flat(i) = x_fn(i); 103 } 104 105 Output y; 106 switch (op_type) { 107 using namespace ops; // NOLINT(build/namespaces) 108 case ABS: 109 y = Abs(scope_, x); 110 break; 111 case NEG: 112 y = Neg(scope_, x); 113 break; 114 case INV: 115 y = Reciprocal(scope_, x); 116 break; 117 case SQUARE: 118 y = Square(scope_, x); 119 break; 120 case SQRT: 121 y = Sqrt(scope_, x); 122 break; 123 case RSQRT: 124 y = Rsqrt(scope_, x); 125 break; 126 case EXP: 127 y = Exp(scope_, x); 128 break; 129 case EXPM1: 130 y = Expm1(scope_, x); 131 break; 132 case LOG: 133 y = Log(scope_, x); 134 break; 135 case LOG1P: 136 y = Log1p(scope_, x); 137 break; 138 case SINH: 139 y = Sinh(scope_, x); 140 break; 141 case COSH: 142 y = Cosh(scope_, x); 143 break; 144 case TANH: 145 y = Tanh(scope_, x); 146 break; 147 case ASINH: 148 y = Asinh(scope_, x); 149 break; 150 case ACOSH: 151 y = Acosh(scope_, x); 152 break; 153 case ATANH: 154 y = Atanh(scope_, x); 155 break; 156 case SIGMOID: 157 y = Sigmoid(scope_, x); 158 break; 159 case SIGN: 160 y = Sign(scope_, x); 161 break; 162 case SIN: 163 y = Sin(scope_, x); 164 break; 165 case COS: 166 y = Cos(scope_, x); 167 break; 168 case ASIN: 169 y = Asin(scope_, x); 170 break; 171 case ACOS: 172 y = Acos(scope_, x); 173 break; 174 case TAN: 175 y = Tan(scope_, x); 176 break; 177 case ATAN: 178 y = Atan(scope_, x); 179 break; 180 case REAL: 181 y = Real(scope_, x); 182 break; 183 case IMAG: 184 y = Imag(scope_, x); 185 break; 186 case CONJ: 187 y = Conj(scope_, x); 188 break; 189 case COMPLEX: 190 y = Complex(scope_, x, x); 191 break; 192 case ANGLE: 193 y = Angle(scope_, x); 194 break; 195 case LGAMMA: 196 y = Lgamma(scope_, x); 197 break; 198 case ERF: 199 y = Erf(scope_, x); 200 break; 201 } 202 203 float max_error; 204 TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y, 205 shape, &max_error))); 206 EXPECT_LT(max_error, 1e-3f); 207 } 208 209 float RV(const std::vector<float>& v) { 210 return v[random::New64() % v.size()]; 211 } 212 213 complex64 CRV(const std::vector<complex64>& v) { 214 return v[random::New64() % v.size()]; 215 } 216 217 complex64 conjugate(const complex64& val) { 218 return complex64(val.real(), -val.imag()); 219 } 220 221 Scope scope_; 222 }; 223 224 TEST_F(CWiseUnaryGradTest, Abs) { 225 auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; 226 TestCWiseGrad<float, float>(ABS, x_fn); 227 } 228 229 TEST_F(CWiseUnaryGradTest, Neg) { 230 auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; 231 TestCWiseGrad<float, float>(NEG, x_fn); 232 } 233 234 TEST_F(CWiseUnaryGradTest, Reciprocal) { 235 auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; 236 TestCWiseGrad<float, float>(INV, x_fn); 237 } 238 239 TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) { 240 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 241 TestCWiseGrad<complex64, complex64>(INV, x_fn); 242 } 243 244 TEST_F(CWiseUnaryGradTest, Square) { 245 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 246 TestCWiseGrad<float, float>(SQUARE, x_fn); 247 } 248 249 TEST_F(CWiseUnaryGradTest, Square_Complex) { 250 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 251 TestCWiseGrad<complex64, complex64>(SQUARE, x_fn); 252 } 253 254 TEST_F(CWiseUnaryGradTest, Sqrt) { 255 auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); }; 256 TestCWiseGrad<float, float>(SQRT, x_fn); 257 } 258 259 TEST_F(CWiseUnaryGradTest, Sqrt_Complex) { 260 auto x_fn = [this](const int i) { 261 return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}}); 262 }; 263 TestCWiseGrad<complex64, complex64>(SQRT, x_fn); 264 } 265 266 TEST_F(CWiseUnaryGradTest, Rsqrt) { 267 auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); }; 268 TestCWiseGrad<float, float>(RSQRT, x_fn); 269 } 270 271 TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) { 272 auto x_fn = [this](const int i) { 273 return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}}); 274 }; 275 TestCWiseGrad<complex64, complex64>(RSQRT, x_fn); 276 } 277 278 TEST_F(CWiseUnaryGradTest, Exp) { 279 auto x_fn = [this](const int i) { 280 return RV({0, -1, 1, -1.5f, 1.5f, -2, 2}); 281 }; 282 TestCWiseGrad<float, float>(EXP, x_fn); 283 } 284 285 TEST_F(CWiseUnaryGradTest, Exp_Complex) { 286 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 287 TestCWiseGrad<complex64, complex64>(EXP, x_fn); 288 } 289 290 TEST_F(CWiseUnaryGradTest, Expm1) { 291 auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); }; 292 TestCWiseGrad<float, float>(EXPM1, x_fn); 293 } 294 295 TEST_F(CWiseUnaryGradTest, Expm1_Complex) { 296 auto x_fn = [this](const int i) { 297 return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}}); 298 }; 299 TestCWiseGrad<complex64, complex64>(EXPM1, x_fn); 300 } 301 302 TEST_F(CWiseUnaryGradTest, Log) { 303 auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4}); }; 304 TestCWiseGrad<float, float>(LOG, x_fn); 305 } 306 307 TEST_F(CWiseUnaryGradTest, Log_Complex) { 308 auto x_fn = [this](const int i) { 309 return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}}); 310 }; 311 TestCWiseGrad<complex64, complex64>(LOG, x_fn); 312 } 313 314 TEST_F(CWiseUnaryGradTest, Log1p) { 315 auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); }; 316 TestCWiseGrad<float, float>(LOG1P, x_fn); 317 } 318 319 TEST_F(CWiseUnaryGradTest, Log1p_Complex) { 320 auto x_fn = [this](const int i) { 321 return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}}); 322 }; 323 TestCWiseGrad<complex64, complex64>(LOG1P, x_fn); 324 } 325 326 TEST_F(CWiseUnaryGradTest, Sinh) { 327 auto x_fn = [this](const int i) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); }; 328 TestCWiseGrad<float, float>(SINH, x_fn); 329 } 330 331 TEST_F(CWiseUnaryGradTest, Sinh_Complex) { 332 auto x_fn = [this](const int i) { 333 return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}}); 334 }; 335 TestCWiseGrad<complex64, complex64>(SINH, x_fn); 336 } 337 338 TEST_F(CWiseUnaryGradTest, Cosh) { 339 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 340 TestCWiseGrad<float, float>(COSH, x_fn); 341 } 342 343 TEST_F(CWiseUnaryGradTest, Cosh_Complex) { 344 auto x_fn = [this](const int i) { 345 return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}}); 346 }; 347 TestCWiseGrad<complex64, complex64>(COSH, x_fn); 348 } 349 350 TEST_F(CWiseUnaryGradTest, Tanh) { 351 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 352 TestCWiseGrad<float, float>(TANH, x_fn); 353 } 354 355 TEST_F(CWiseUnaryGradTest, Tanh_Complex) { 356 auto x_fn = [this](const int i) { 357 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 358 }; 359 TestCWiseGrad<complex64, complex64>(TANH, x_fn); 360 } 361 362 TEST_F(CWiseUnaryGradTest, Asinh) { 363 auto x_fn = [this](const int i) { return RV({0.5, 1, -1, -1.5, 1.5}); }; 364 TestCWiseGrad<float, float>(ASINH, x_fn); 365 } 366 367 TEST_F(CWiseUnaryGradTest, Asinh_Complex) { 368 auto x_fn = [this](const int i) { 369 return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}}); 370 }; 371 TestCWiseGrad<complex64, complex64>(ASINH, x_fn); 372 } 373 374 TEST_F(CWiseUnaryGradTest, Acosh) { 375 auto x_fn = [this](const int i) { return RV({1.5, 2, 2.5}); }; 376 TestCWiseGrad<float, float>(ACOSH, x_fn); 377 } 378 379 TEST_F(CWiseUnaryGradTest, Acosh_Complex) { 380 auto x_fn = [this](const int i) { 381 return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}}); 382 }; 383 TestCWiseGrad<complex64, complex64>(ACOSH, x_fn); 384 } 385 386 TEST_F(CWiseUnaryGradTest, Atanh) { 387 auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); }; 388 TestCWiseGrad<float, float>(ATANH, x_fn); 389 } 390 391 TEST_F(CWiseUnaryGradTest, Atanh_Complex) { 392 auto x_fn = [this](const int i) { 393 return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}}); 394 }; 395 TestCWiseGrad<complex64, complex64>(ATANH, x_fn); 396 } 397 398 TEST_F(CWiseUnaryGradTest, Sigmoid) { 399 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 400 TestCWiseGrad<float, float>(SIGMOID, x_fn); 401 } 402 403 TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) { 404 auto x_fn = [this](const int i) { 405 return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}}); 406 }; 407 TestCWiseGrad<complex64, complex64>(SIGMOID, x_fn); 408 } 409 410 TEST_F(CWiseUnaryGradTest, Sign) { 411 auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3}); }; 412 TestCWiseGrad<float, float>(SIGN, x_fn); 413 } 414 415 TEST_F(CWiseUnaryGradTest, Sin) { 416 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 417 TestCWiseGrad<float, float>(SIN, x_fn); 418 } 419 420 TEST_F(CWiseUnaryGradTest, Sin_Complex) { 421 auto x_fn = [this](const int i) { 422 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}}); 423 }; 424 TestCWiseGrad<complex64, complex64>(SIN, x_fn); 425 } 426 427 TEST_F(CWiseUnaryGradTest, Cos) { 428 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 429 TestCWiseGrad<float, float>(COS, x_fn); 430 } 431 432 TEST_F(CWiseUnaryGradTest, Cos_Complex) { 433 auto x_fn = [this](const int i) { 434 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}}); 435 }; 436 TestCWiseGrad<complex64, complex64>(COS, x_fn); 437 } 438 439 TEST_F(CWiseUnaryGradTest, Asin) { 440 auto x_fn = [this](const int i) { return RV({0, 0.25, -0.25, -0.5, 0.5}); }; 441 TestCWiseGrad<float, float>(ASIN, x_fn); 442 } 443 444 TEST_F(CWiseUnaryGradTest, Asin_Complex) { 445 auto x_fn = [this](const int i) { 446 return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}}); 447 }; 448 // TODO(kbsriram) 449 // Enable test when the asin kernel supports complex numbers 450 if (false) { 451 TestCWiseGrad<complex64, complex64>(ASIN, x_fn); 452 } 453 } 454 455 TEST_F(CWiseUnaryGradTest, Acos) { 456 auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.75, 0.75}); }; 457 TestCWiseGrad<float, float>(ACOS, x_fn); 458 } 459 460 TEST_F(CWiseUnaryGradTest, Acos_Complex) { 461 auto x_fn = [this](const int i) { 462 return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}}); 463 }; 464 // TODO(kbsriram) 465 // Add test when the acos kernel supports complex numbers 466 if (false) { 467 TestCWiseGrad<complex64, complex64>(ACOS, x_fn); 468 } 469 } 470 471 TEST_F(CWiseUnaryGradTest, Tan) { 472 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 473 TestCWiseGrad<float, float>(TAN, x_fn); 474 } 475 476 TEST_F(CWiseUnaryGradTest, Tan_Complex) { 477 auto x_fn = [this](const int i) { 478 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 479 }; 480 // TODO(kbsriram) 481 // Enable when tan kernel supports complex inputs 482 if (false) { 483 TestCWiseGrad<complex64, complex64>(TAN, x_fn); 484 } 485 } 486 487 TEST_F(CWiseUnaryGradTest, Atan) { 488 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 489 TestCWiseGrad<float, float>(ATAN, x_fn); 490 } 491 492 TEST_F(CWiseUnaryGradTest, Atan_Complex) { 493 auto x_fn = [this](const int i) { 494 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 495 }; 496 // TODO(kbsriram) 497 // Add test when the atan kernel supports complex numbers 498 if (false) { 499 TestCWiseGrad<complex64, complex64>(ATAN, x_fn); 500 } 501 } 502 503 TEST_F(CWiseUnaryGradTest, Real) { 504 auto x_fn = [this](const int i) { 505 return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}}); 506 }; 507 TestCWiseGrad<complex64, float>(REAL, x_fn); 508 } 509 510 TEST_F(CWiseUnaryGradTest, Imag) { 511 auto x_fn = [this](const int i) { 512 return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}}); 513 }; 514 TestCWiseGrad<complex64, float>(IMAG, x_fn); 515 } 516 517 TEST_F(CWiseUnaryGradTest, Conj) { 518 auto x_fn = [this](const int i) { 519 return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}}); 520 }; 521 TestCWiseGrad<complex64, complex64>(CONJ, x_fn); 522 } 523 524 TEST_F(CWiseUnaryGradTest, Complex) { 525 auto x_fn = [this](const int i) { return RV({1, -1, 2, -2, 3, -3}); }; 526 TestCWiseGrad<float, complex64>(COMPLEX, x_fn); 527 } 528 529 TEST_F(CWiseUnaryGradTest, Angle) { 530 auto x_fn = [this](const int i) { 531 return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}}); 532 }; 533 TestCWiseGrad<complex64, float>(ANGLE, x_fn); 534 } 535 536 TEST_F(CWiseUnaryGradTest, Lgamma) { 537 auto x_fn = [this](const int i) { 538 return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5}); 539 }; 540 TestCWiseGrad<float, float>(LGAMMA, x_fn); 541 } 542 543 TEST_F(CWiseUnaryGradTest, Lgamma_Complex) { 544 auto x_fn = [this](const int i) { 545 return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}}); 546 }; 547 // TODO(kbsriram) 548 // Add test when the lgamma kernel supports complex numbers 549 if (false) { 550 TestCWiseGrad<complex64, complex64>(LGAMMA, x_fn); 551 } 552 } 553 554 TEST_F(CWiseUnaryGradTest, Erf) { 555 auto x_fn = [this](const int i) { 556 return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3}); 557 }; 558 TestCWiseGrad<float, float>(ERF, x_fn); 559 } 560 561 TEST_F(CWiseUnaryGradTest, Erf_Complex) { 562 auto x_fn = [this](const int i) { 563 return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}}); 564 }; 565 // TODO(kbsriram) 566 // Add test when the erf kernel supports complex numbers 567 if (false) { 568 TestCWiseGrad<complex64, complex64>(ERF, x_fn); 569 } 570 } 571 572 class MathGradTest : public ::testing::Test { 573 protected: 574 MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 575 576 template <typename T> 577 void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) { 578 TF_ASSERT_OK(root_.status()); 579 // Generate random (but compatible) shapes for matrix multiplication. 580 std::vector<TensorShape> shapes; 581 RandMatMulShapes(is_batch, t_x, t_y, &shapes); 582 TensorShape x_shape = shapes[0]; 583 TensorShape y_shape = shapes[1]; 584 TensorShape z_shape = shapes[2]; 585 auto x = 586 Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape)); 587 auto y = 588 Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape)); 589 Output z; 590 if (is_batch) { 591 z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); 592 } else { 593 z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); 594 } 595 596 float max_error; 597 TF_ASSERT_OK((ComputeGradientError<T, T, float>( 598 root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error))); 599 EXPECT_LT(max_error, 1e-3); 600 } 601 602 void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty, 603 std::vector<TensorShape>* shapes) { 604 // Choose a random batch size in [1, 4] 605 const int b = 1 + (random::New64() % 4); 606 // z = MatMul(x, y) 607 const int m = Rand(); 608 const int k = Rand(); 609 const int n = Rand(); 610 611 TensorShape x_shape; 612 if (is_batch) { 613 // x.shape = [b, m, k] 614 x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k}); 615 } else { 616 // x.shape = [m, k] 617 x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k}); 618 } 619 shapes->push_back(x_shape); 620 621 TensorShape y_shape; 622 if (is_batch) { 623 // y.shape = [b, k, n] 624 y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n}); 625 } else { 626 // y.shape = [k, n] 627 y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n}); 628 } 629 shapes->push_back(y_shape); 630 631 TensorShape z_shape; 632 if (is_batch) { 633 // z.shape = [b, m, n] 634 z_shape = TensorShape({b, m, n}); 635 } else { 636 // z.shape = [m, n] 637 z_shape = TensorShape({m, n}); 638 } 639 shapes->push_back(z_shape); 640 } 641 642 int Rand() { return 1 + (random::New64() % 10); } 643 644 Scope root_; 645 }; 646 647 TEST_F(MathGradTest, MatMulGrad_NoTranspose) { 648 TestMatMulGrad<float>(false, false, false); 649 } 650 651 TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) { 652 TestMatMulGrad<complex64>(false, false, false); 653 } 654 655 TEST_F(MathGradTest, MatMulGrad_TransposeX) { 656 TestMatMulGrad<float>(false, true, false); 657 } 658 659 TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) { 660 TestMatMulGrad<complex64>(false, true, false); 661 } 662 663 TEST_F(MathGradTest, MatMulGrad_TransposeY) { 664 TestMatMulGrad<float>(false, false, true); 665 } 666 667 TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) { 668 TestMatMulGrad<complex64>(false, false, true); 669 } 670 671 TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) { 672 TestMatMulGrad<float>(false, true, true); 673 } 674 675 TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) { 676 TestMatMulGrad<complex64>(false, true, true); 677 } 678 679 TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) { 680 TestMatMulGrad<float>(true, false, false); 681 } 682 683 TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) { 684 TestMatMulGrad<complex64>(true, false, false); 685 } 686 687 TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) { 688 TestMatMulGrad<float>(true, true, false); 689 } 690 691 TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) { 692 TestMatMulGrad<complex64>(true, true, false); 693 } 694 695 TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) { 696 TestMatMulGrad<float>(true, false, true); 697 } 698 699 TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) { 700 TestMatMulGrad<complex64>(true, false, true); 701 } 702 703 TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) { 704 TestMatMulGrad<float>(true, true, true); 705 } 706 707 TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) { 708 TestMatMulGrad<complex64>(true, true, true); 709 } 710 711 class NaryGradTest : public ::testing::Test { 712 protected: 713 NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 714 715 void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes, 716 const OutputList& ys, const std::vector<TensorShape>& y_shapes) { 717 TF_ASSERT_OK(scope_.status()); 718 float max_error; 719 TF_ASSERT_OK((ComputeGradientError<float, float, float>( 720 scope_, xs, x_shapes, ys, y_shapes, &max_error))); 721 EXPECT_LT(max_error, 1e-3); 722 } 723 724 void RunTest(const Output& x, const Tensor& x_init_value, const Output& y, 725 const TensorShape& y_shape) { 726 TF_ASSERT_OK(scope_.status()); 727 float max_error; 728 TF_ASSERT_OK((ComputeGradientError<float, float, float>( 729 scope_, x, x_init_value, y, y_shape, &max_error))); 730 EXPECT_LT(max_error, 1e-3); 731 } 732 733 Scope scope_; 734 }; 735 736 TEST_F(NaryGradTest, Sum) { 737 TensorShape x_shape({2, 3, 5, 7}); 738 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 739 auto y = Sum(scope_, x, {1, -1}); 740 // y's shape is the result of reducing x along axes 1 and -1 (= 3) 741 TensorShape y_shape({2, 5}); 742 RunTest({x}, {x_shape}, {y}, {y_shape}); 743 } 744 745 TEST_F(NaryGradTest, Mean) { 746 TensorShape x_shape({2, 3, 5, 7}); 747 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 748 auto y = Mean(scope_, x, {1, -1}); 749 // y's shape is the result of reducing x along axes 1 and -1 (= 3) 750 TensorShape y_shape({2, 5}); 751 RunTest({x}, {x_shape}, {y}, {y_shape}); 752 } 753 754 TEST_F(NaryGradTest, Min) { 755 TensorShape x_shape({2, 3}); 756 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 757 auto y = Min(scope_, x, {-1}); 758 // y's shape is the result of reducing x along axes -1 (= 1) 759 TensorShape y_shape({2}); 760 Tensor x_init_value = 761 test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape); 762 RunTest(x, x_init_value, y, y_shape); 763 } 764 765 TEST_F(NaryGradTest, Max) { 766 TensorShape x_shape({2, 3}); 767 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 768 auto y = Max(scope_, x, {-1}); 769 // y's shape is the result of reducing x along axes -1 (= 1) 770 TensorShape y_shape({2}); 771 Tensor x_init_value = 772 test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape); 773 RunTest(x, x_init_value, y, y_shape); 774 } 775 776 TEST_F(NaryGradTest, MinMulti) { 777 // Test gradient when there are multiple minima. 778 // Note that we cannot directly use a test Tensor with multiple 779 // minima, as the numeric estimator will calculate incorrect 780 // gradients when perturbing each entry in the Tensor (which then 781 // changes how many minima exist.) 782 // Instead, we use a single input that broadcast-multiplies a larger 783 // tensor with equal values, and apply reduce_min to the multiplied 784 // result. 785 TensorShape x_shape({1}); 786 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 787 auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x); 788 auto y = Min(scope_, all_same, {0}); 789 // y is a [3] shaped tensor reduced along dimension 0, so it is [1] shaped 790 TensorShape y_shape({1}); 791 RunTest({x}, {x_shape}, {y}, {y_shape}); 792 } 793 794 TEST_F(NaryGradTest, MaxMulti) { 795 TensorShape x_shape({1}); 796 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 797 auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x); 798 auto y = Max(scope_, all_same, {0}); 799 TensorShape y_shape({1}); 800 RunTest({x}, {x_shape}, {y}, {y_shape}); 801 } 802 803 TEST_F(NaryGradTest, AddN) { 804 TensorShape shape({3, 2, 5}); 805 std::vector<Output> xs; 806 xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape))); 807 xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape))); 808 xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape))); 809 auto y = AddN(scope_, xs); 810 RunTest(xs, {shape, shape, shape}, {y}, {shape}); 811 } 812 813 TEST_F(NaryGradTest, Add) { 814 TensorShape x1_shape({3, 2, 5}); 815 TensorShape x2_shape({2, 5}); 816 auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape)); 817 auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape)); 818 auto y = Add(scope_, x1, x2); 819 RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape}); 820 } 821 822 TEST_F(NaryGradTest, Sub) { 823 TensorShape x1_shape({3, 2, 5}); 824 TensorShape x2_shape({2, 5}); 825 auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape)); 826 auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape)); 827 auto y = Sub(scope_, x1, x2); 828 RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape}); 829 } 830 831 TEST_F(NaryGradTest, Mul) { 832 TensorShape x1_shape({3, 2, 5}); 833 TensorShape x2_shape({2, 5}); 834 auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape)); 835 auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape)); 836 auto y = Mul(scope_, x1, x2); 837 RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape}); 838 } 839 840 TEST_F(NaryGradTest, Div) { 841 TensorShape x_shape({3, 2, 5}); 842 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 843 // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large 844 // division errors in the numeric estimator used by the gradient checker. 845 auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x))); 846 RunTest({x}, {x_shape}, {y}, {x_shape}); 847 } 848 849 TEST_F(NaryGradTest, RealDiv) { 850 TensorShape x_shape({3, 2, 5}); 851 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 852 // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large 853 // division errors in the numeric estimator used by the gradient checker. 854 auto y = 855 RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x))); 856 RunTest({x}, {x_shape}, {y}, {x_shape}); 857 } 858 859 TEST_F(NaryGradTest, SquaredDifference) { 860 TensorShape x1_shape({3, 2, 5}); 861 TensorShape x2_shape({2, 5}); 862 auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape)); 863 auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape)); 864 auto y = SquaredDifference(scope_, x1, x2); 865 RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape}); 866 } 867 868 TEST_F(NaryGradTest, Pow) { 869 TensorShape shape({3}); 870 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); 871 // fix exponent to avoid overflow 872 auto y = Pow(scope_, x, Const(scope_, {1.f, 2.f, 3.f})); 873 RunTest({x}, {shape}, {y}, {shape}); 874 } 875 876 TEST_F(NaryGradTest, Maximum) { 877 TensorShape shape({3, 2}); 878 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); 879 auto y = Maximum(scope_, x, Const(scope_, 1.0f)); 880 // Select values away from 1.0f to avoid instability when computing 881 // finite differences. 882 Tensor x_init_value = 883 test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2}); 884 RunTest(x, x_init_value, y, shape); 885 } 886 887 TEST_F(NaryGradTest, Minimum) { 888 TensorShape shape({3, 2}); 889 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); 890 auto y = Minimum(scope_, x, Const(scope_, 1.0f)); 891 // Select values away from 1.0f to avoid instability when computing 892 // finite differences. 893 Tensor x_init_value = 894 test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2}); 895 RunTest(x, x_init_value, y, shape); 896 } 897 898 TEST_F(NaryGradTest, Prod) { 899 TensorShape x_shape({2, 3, 2}); 900 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 901 auto y = Prod(scope_, x, {1}); 902 // y's shape is the result of reducing x along axes 1 903 TensorShape y_shape({2, 1, 2}); 904 RunTest({x}, {x_shape}, {y}, {y_shape}); 905 } 906 907 } // namespace 908 } // namespace tensorflow 909