1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 #include <numeric> 18 #include <vector> 19 20 #include "tensorflow/compiler/xla/array2d.h" 21 #include "tensorflow/compiler/xla/client/lib/math.h" 22 #include "tensorflow/compiler/xla/client/lib/matrix.h" 23 #include "tensorflow/compiler/xla/client/xla_builder.h" 24 #include "tensorflow/compiler/xla/literal.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/compiler/xla/test.h" 27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 28 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 29 #include "tensorflow/compiler/xla/tests/test_macros.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/core/lib/core/status_test_util.h" 32 33 namespace xla { 34 namespace { 35 36 using TriangularSolveTest = ClientLibraryTestBase; 37 using TriangularSolveLeftLookingTest = ClientLibraryTestBase; 38 39 static constexpr float kNan = std::numeric_limits<float>::quiet_NaN(); 40 41 Array2D<float> AValsLower() { 42 return {{2, kNan, kNan, kNan}, 43 {3, 6, kNan, kNan}, 44 {4, 7, 9, kNan}, 45 {5, 8, 10, 11}}; 46 } 47 48 Array2D<float> AValsUpper() { 49 return {{2, 3, 4, 5}, 50 {kNan, 6, 7, 8}, 51 {kNan, kNan, 9, 10}, 52 {kNan, kNan, kNan, 11}}; 53 } 54 55 Array2D<float> AValsLowerUnitDiagonal() { 56 return {{kNan, kNan, kNan, kNan}, 57 {3, kNan, kNan, kNan}, 58 {4, 7, kNan, kNan}, 59 {5, 8, 10, kNan}}; 60 } 61 62 Array2D<float> AValsUpperUnitDiagonal() { 63 return {{kNan, 3, 4, 5}, 64 {kNan, kNan, 7, 8}, 65 {kNan, kNan, kNan, 10}, 66 {kNan, kNan, kNan, kNan}}; 67 } 68 69 Array2D<float> BValsRight() { 70 return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; 71 } 72 73 Array2D<float> BValsLeft() { 74 return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; 75 } 76 77 static constexpr complex64 kNanC64 = complex64(kNan, kNan); 78 79 Array2D<complex64> AValsLowerComplex() { 80 return {{2, kNanC64, kNanC64, kNanC64}, 81 {complex64(3, 1), 6, kNanC64, kNanC64}, 82 {4, complex64(7, 2), 9, kNanC64}, 83 {5, 8, complex64(10, 3), 11}}; 84 } 85 86 Array2D<complex64> AValsUpperComplex() { 87 return {{2, 3, complex64(4, 3), 5}, 88 {kNanC64, 6, complex64(7, 2), 8}, 89 {kNanC64, kNanC64, complex64(9, 1), 10}, 90 {kNanC64, kNanC64, kNanC64, 11}}; 91 } 92 93 Array2D<complex64> BValsRightComplex() { 94 return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; 95 } 96 97 Array2D<complex64> BValsLeftComplex() { 98 return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; 99 } 100 101 XLA_TEST_F(TriangularSolveTest, EmptyArrays) { 102 XlaBuilder builder(TestName()); 103 104 XlaOp a, b; 105 auto a_data = 106 CreateR2Parameter<float>(Array2D<float>(0, 0), 0, "a", &builder, &a); 107 auto b_data = 108 CreateR2Parameter<float>(Array2D<float>(0, 10), 1, "b", &builder, &b); 109 TriangularSolve(a, b, 110 /*left_side=*/true, /*lower=*/true, 111 /*unit_diagonal=*/false, 112 /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); 113 114 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 10), 115 {a_data.get(), b_data.get()}); 116 } 117 118 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { 119 XlaBuilder builder(TestName()); 120 121 XlaOp a, b; 122 auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); 123 auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); 124 TriangularSolve(a, b, 125 /*left_side=*/false, /*lower=*/true, 126 /*unit_diagonal=*/false, 127 /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); 128 129 Array2D<float> expected({ 130 {0.5, 0.08333334, 0.04629629, 0.03367003}, 131 {2.5, -0.25, -0.1388889, -0.1010101}, 132 {4.5, -0.58333331, -0.32407406, -0.23569024}, 133 }); 134 135 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 136 ErrorSpec(1e-2, 1e-2)); 137 } 138 139 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { 140 XlaBuilder builder(TestName()); 141 142 XlaOp a, b; 143 auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); 144 auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); 145 TriangularSolve(a, b, 146 /*left_side=*/false, /*lower=*/true, 147 /*unit_diagonal=*/false, 148 /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); 149 150 Array2D<float> expected({ 151 {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, 152 {0.64393939, 0.06565657, -0.03030303, 0.72727273}, 153 {1.4520202, 0.2003367, 0.01010101, 1.09090909}, 154 }); 155 156 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 157 ErrorSpec(1e-2, 1e-2)); 158 } 159 160 XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { 161 XlaBuilder builder(TestName()); 162 163 XlaOp a, b; 164 auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); 165 auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); 166 TriangularSolve(a, b, 167 /*left_side=*/false, /*lower=*/false, 168 /*unit_diagonal=*/false, 169 /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); 170 171 Array2D<float> expected({ 172 {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, 173 {0.64393939, 0.06565657, -0.03030303, 0.72727273}, 174 {1.4520202, 0.2003367, 0.01010101, 1.09090909}, 175 }); 176 177 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 178 ErrorSpec(1e-2, 1e-2)); 179 } 180 181 XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { 182 XlaBuilder builder(TestName()); 183 184 XlaOp a, b; 185 auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); 186 auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); 187 TriangularSolve(a, b, 188 /*left_side=*/false, /*lower=*/false, 189 /*unit_diagonal=*/false, 190 /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); 191 192 Array2D<float> expected({ 193 {0.5, 0.08333334, 0.04629629, 0.03367003}, 194 {2.5, -0.25, -0.1388889, -0.1010101}, 195 {4.5, -0.58333331, -0.32407406, -0.23569024}, 196 }); 197 198 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 199 ErrorSpec(1e-2, 1e-2)); 200 } 201 202 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { 203 XlaBuilder builder(TestName()); 204 205 XlaOp a, b; 206 auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); 207 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 208 TriangularSolve(a, b, 209 /*left_side=*/true, /*lower=*/true, 210 /*unit_diagonal=*/false, 211 /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); 212 213 Array2D<float> expected({ 214 {-0.89646465, -0.69444444, -0.49242424}, 215 {-0.27441077, -0.24074074, -0.20707071}, 216 {-0.23232323, -0.22222222, -0.21212121}, 217 {0.90909091, 1., 1.09090909}, 218 }); 219 220 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 221 ErrorSpec(1e-2, 1e-2)); 222 } 223 224 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { 225 XlaBuilder builder(TestName()); 226 227 XlaOp a, b; 228 auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); 229 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 230 TriangularSolve(a, b, 231 /*left_side=*/true, /*lower=*/true, 232 /*unit_diagonal=*/false, 233 /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); 234 235 Array2D<float> expected({ 236 {0.5, 1.0, 1.5}, 237 {0.41666667, 0.33333333, 0.25}, 238 {0.23148148, 0.18518519, 0.13888889}, 239 {0.16835017, 0.13468013, 0.1010101}, 240 }); 241 242 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 243 ErrorSpec(1e-2, 1e-2)); 244 } 245 246 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) { 247 XlaBuilder builder(TestName()); 248 249 XlaOp a, b; 250 auto a_data = 251 CreateR2Parameter<float>(AValsLowerUnitDiagonal(), 0, "a", &builder, &a); 252 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 253 TriangularSolve(a, b, 254 /*left_side=*/true, /*lower=*/true, 255 /*unit_diagonal=*/true, 256 /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); 257 258 Array2D<float> expected( 259 {{1., 2., 3.}, {1., -1., -3.}, {-4., 7., 18.}, {37., -61., -159.}}); 260 261 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 262 ErrorSpec(1e-2, 1e-2)); 263 } 264 265 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { 266 XlaBuilder builder(TestName()); 267 268 XlaOp a, b; 269 auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); 270 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 271 TriangularSolve(a, b, 272 /*left_side=*/true, /*lower=*/true, 273 /*unit_diagonal=*/false, 274 /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); 275 276 Array2D<float> expected({ 277 {0.5, 1.0, 1.5}, 278 {0.41666667, 0.33333333, 0.25}, 279 {0.23148148, 0.18518519, 0.13888889}, 280 {0.16835017, 0.13468013, 0.1010101}, 281 }); 282 283 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 284 ErrorSpec(1e-2, 1e-2)); 285 } 286 287 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { 288 XlaBuilder builder(TestName()); 289 290 XlaOp a, b; 291 auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); 292 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 293 TriangularSolve(a, b, 294 /*left_side=*/true, /*lower=*/false, 295 /*unit_diagonal=*/false, 296 /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); 297 298 Array2D<float> expected({ 299 {0.5, 1.0, 1.5}, 300 {0.41666667, 0.33333333, 0.25}, 301 {0.23148148, 0.18518519, 0.13888889}, 302 {0.16835017, 0.13468013, 0.1010101}, 303 }); 304 305 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 306 ErrorSpec(1e-2, 1e-2)); 307 } 308 309 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { 310 XlaBuilder builder(TestName()); 311 312 XlaOp a, b; 313 auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); 314 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 315 TriangularSolve(a, b, 316 /*left_side=*/true, /*lower=*/false, 317 /*unit_diagonal=*/false, 318 /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); 319 320 Array2D<float> expected({ 321 {-0.89646465, -0.69444444, -0.49242424}, 322 {-0.27441077, -0.24074074, -0.20707071}, 323 {-0.23232323, -0.22222222, -0.21212121}, 324 {0.90909091, 1., 1.09090909}, 325 }); 326 327 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 328 ErrorSpec(1e-2, 1e-2)); 329 } 330 331 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) { 332 XlaBuilder builder(TestName()); 333 334 XlaOp a, b; 335 auto a_data = 336 CreateR2Parameter<float>(AValsUpperUnitDiagonal(), 0, "a", &builder, &a); 337 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 338 TriangularSolve(a, b, 339 /*left_side=*/true, /*lower=*/false, 340 /*unit_diagonal=*/true, 341 /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); 342 343 Array2D<float> expected({{-1402., -1538., -1674.}, 344 {575., 631., 687.}, 345 {-93., -102., -111.}, 346 {10., 11., 12.}}); 347 348 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 349 ErrorSpec(1e-2, 1e-2)); 350 } 351 352 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { 353 XlaBuilder builder(TestName()); 354 355 XlaOp a, b; 356 auto a_data = 357 CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a); 358 auto b_data = 359 CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b); 360 TriangularSolve(a, b, 361 /*left_side=*/false, /*lower=*/true, 362 /*unit_diagonal=*/false, 363 /*transpose_a=*/TriangularSolveOptions::ADJOINT); 364 365 Array2D<complex64> expected({ 366 {0.5, complex64(0.08333333, 0.08333333), 367 complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, 368 {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), 369 complex64(0.08670034, -0.02104377)}, 370 {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296), 371 complex64(0.11026936, -0.03114478)}, 372 }); 373 374 ComputeAndCompareR2<complex64>( 375 &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); 376 } 377 378 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { 379 XlaBuilder builder(TestName()); 380 381 XlaOp a, b; 382 auto a_data = 383 CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a); 384 auto b_data = 385 CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b); 386 TriangularSolve(a, b, 387 /*left_side=*/true, /*lower=*/false, 388 /*unit_diagonal=*/false, 389 /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); 390 391 Array2D<complex64> expected({ 392 {0.5, 1., 1.5}, 393 {0.41666667, 0.33333333, 0.25}, 394 {complex64(0.20020325, -2.81504065e-01), 395 complex64(0.13821138, -4.22764228e-01), 396 complex64(0.07621951, -5.64024390e-01)}, 397 {complex64(0.19678492, 2.55912786e-01), 398 complex64(0.17738359, 3.84331116e-01), 399 complex64(0.15798226, 5.12749446e-01)}, 400 }); 401 402 ComputeAndCompareR2<complex64>( 403 &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); 404 } 405 406 XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) { 407 XlaBuilder builder(TestName()); 408 409 Array3D<float> bvals(7, 5, 5); 410 bvals.FillIota(1.); 411 412 // Set avals to the upper triangle of bvals. 413 Array3D<float> avals = bvals; 414 avals.Each([](absl::Span<const int64> indices, float* value) { 415 if (indices[1] > indices[2]) { 416 *value = 0; 417 } 418 }); 419 420 XlaOp a, b; 421 auto a_data = CreateR3Parameter<float>(avals, 0, "a", &builder, &a); 422 auto b_data = CreateR3Parameter<float>(bvals, 1, "b", &builder, &b); 423 BatchDot( 424 ConstantR3FromArray3D(&builder, avals), 425 TriangularSolve(a, b, 426 /*left_side=*/true, /*lower=*/false, 427 /*unit_diagonal=*/false, 428 /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE)); 429 430 ComputeAndCompareR3<float>(&builder, bvals, {a_data.get(), b_data.get()}, 431 ErrorSpec(1e-2, 1e-2)); 432 } 433 434 struct TriangularSolveTestSpec { 435 int m, n; // A is mxm, B is mxn 436 bool left_side; 437 bool lower; 438 TriangularSolveOptions::Transpose transpose_a; 439 }; 440 441 class TriangularSolveParametricTest 442 : public ClientLibraryTestBase, 443 public ::testing::WithParamInterface<TriangularSolveTestSpec> {}; 444 445 XLA_TEST_P(TriangularSolveParametricTest, Random) { 446 TriangularSolveTestSpec spec = GetParam(); 447 448 XlaBuilder builder(TestName()); 449 450 Array2D<float> avals(spec.m, spec.m); 451 avals.FillRandom(1.0); 452 for (int i = 0; i < spec.m; ++i) { 453 avals(i, i) += 10; 454 } 455 456 std::pair<int, int> bdims = spec.left_side ? std::make_pair(spec.m, spec.n) 457 : std::make_pair(spec.n, spec.m); 458 Array2D<float> bvals(bdims.first, bdims.second); 459 bvals.FillRandom(1.0); 460 461 XlaOp a, b; 462 auto a_data = CreateR2Parameter<float>(avals, 0, "a", &builder, &a); 463 auto b_data = CreateR2Parameter<float>(bvals, 1, "b", &builder, &b); 464 auto x = TriangularSolve(a, b, spec.left_side, spec.lower, 465 /*unit_diagonal=*/false, spec.transpose_a); 466 auto a_tri = Triangle(a, spec.lower); 467 a_tri = MaybeTransposeInMinorDims( 468 a_tri, spec.transpose_a != TriangularSolveOptions::NO_TRANSPOSE); 469 if (spec.left_side) { 470 BatchDot(a_tri, x); 471 } else { 472 BatchDot(x, a_tri); 473 } 474 475 ComputeAndCompareR2<float>(&builder, bvals, {a_data.get(), b_data.get()}, 476 ErrorSpec(1e-2, 1e-2)); 477 } 478 479 std::vector<TriangularSolveTestSpec> TriangularSolveTests() { 480 std::vector<TriangularSolveTestSpec> specs; 481 for (int m : {5, 10}) { 482 for (int n : {5, 10}) { 483 for (bool left_side : {false, true}) { 484 for (bool lower : {false, true}) { 485 for (TriangularSolveOptions::Transpose transpose_a : 486 {TriangularSolveOptions::NO_TRANSPOSE, 487 TriangularSolveOptions::TRANSPOSE}) { 488 specs.push_back({m, n, left_side, lower, transpose_a}); 489 } 490 } 491 } 492 } 493 } 494 return specs; 495 } 496 497 INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation, 498 TriangularSolveParametricTest, 499 ::testing::ValuesIn(TriangularSolveTests())); 500 501 } // namespace 502 } // namespace xla 503