Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include <limits>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
     32 
     33 namespace xla {
     34 namespace {
     35 
// Fixture used by the XLA_TEST_F cases in this file; it is a plain alias of
// ClientLibraryTestBase and adds no state of its own.
using TriangularSolveTest = ClientLibraryTestBase;
// NOTE(review): no test in the visible part of this file uses this alias --
// confirm whether it is still needed.
using TriangularSolveLeftLookingTest = ClientLibraryTestBase;
     38 
     39 static constexpr float kNan = std::numeric_limits<float>::quiet_NaN();
     40 
     41 Array2D<float> AValsLower() {
     42   return {{2, kNan, kNan, kNan},
     43           {3, 6, kNan, kNan},
     44           {4, 7, 9, kNan},
     45           {5, 8, 10, 11}};
     46 }
     47 
     48 Array2D<float> AValsUpper() {
     49   return {{2, 3, 4, 5},
     50           {kNan, 6, 7, 8},
     51           {kNan, kNan, 9, 10},
     52           {kNan, kNan, kNan, 11}};
     53 }
     54 
     55 Array2D<float> AValsLowerUnitDiagonal() {
     56   return {{kNan, kNan, kNan, kNan},
     57           {3, kNan, kNan, kNan},
     58           {4, 7, kNan, kNan},
     59           {5, 8, 10, kNan}};
     60 }
     61 
     62 Array2D<float> AValsUpperUnitDiagonal() {
     63   return {{kNan, 3, 4, 5},
     64           {kNan, kNan, 7, 8},
     65           {kNan, kNan, kNan, 10},
     66           {kNan, kNan, kNan, kNan}};
     67 }
     68 
     69 Array2D<float> BValsRight() {
     70   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
     71 }
     72 
     73 Array2D<float> BValsLeft() {
     74   return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
     75 }
     76 
     77 static constexpr complex64 kNanC64 = complex64(kNan, kNan);
     78 
     79 Array2D<complex64> AValsLowerComplex() {
     80   return {{2, kNanC64, kNanC64, kNanC64},
     81           {complex64(3, 1), 6, kNanC64, kNanC64},
     82           {4, complex64(7, 2), 9, kNanC64},
     83           {5, 8, complex64(10, 3), 11}};
     84 }
     85 
     86 Array2D<complex64> AValsUpperComplex() {
     87   return {{2, 3, complex64(4, 3), 5},
     88           {kNanC64, 6, complex64(7, 2), 8},
     89           {kNanC64, kNanC64, complex64(9, 1), 10},
     90           {kNanC64, kNanC64, kNanC64, 11}};
     91 }
     92 
     93 Array2D<complex64> BValsRightComplex() {
     94   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
     95 }
     96 
     97 Array2D<complex64> BValsLeftComplex() {
     98   return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
     99 }
    100 
    101 XLA_TEST_F(TriangularSolveTest, EmptyArrays) {
    102   XlaBuilder builder(TestName());
    103 
    104   XlaOp a, b;
    105   auto a_data =
    106       CreateR2Parameter<float>(Array2D<float>(0, 0), 0, "a", &builder, &a);
    107   auto b_data =
    108       CreateR2Parameter<float>(Array2D<float>(0, 10), 1, "b", &builder, &b);
    109   TriangularSolve(a, b,
    110                   /*left_side=*/true, /*lower=*/true,
    111                   /*unit_diagonal=*/false,
    112                   /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
    113 
    114   ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 10),
    115                              {a_data.get(), b_data.get()});
    116 }
    117 
    118 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
    119   XlaBuilder builder(TestName());
    120 
    121   XlaOp a, b;
    122   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
    123   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
    124   TriangularSolve(a, b,
    125                   /*left_side=*/false, /*lower=*/true,
    126                   /*unit_diagonal=*/false,
    127                   /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
    128 
    129   Array2D<float> expected({
    130       {0.5, 0.08333334, 0.04629629, 0.03367003},
    131       {2.5, -0.25, -0.1388889, -0.1010101},
    132       {4.5, -0.58333331, -0.32407406, -0.23569024},
    133   });
    134 
    135   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    136                              ErrorSpec(1e-2, 1e-2));
    137 }
    138 
    139 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
    140   XlaBuilder builder(TestName());
    141 
    142   XlaOp a, b;
    143   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
    144   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
    145   TriangularSolve(a, b,
    146                   /*left_side=*/false, /*lower=*/true,
    147                   /*unit_diagonal=*/false,
    148                   /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);
    149 
    150   Array2D<float> expected({
    151       {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
    152       {0.64393939, 0.06565657, -0.03030303, 0.72727273},
    153       {1.4520202, 0.2003367, 0.01010101, 1.09090909},
    154   });
    155 
    156   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    157                              ErrorSpec(1e-2, 1e-2));
    158 }
    159 
    160 XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
    161   XlaBuilder builder(TestName());
    162 
    163   XlaOp a, b;
    164   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
    165   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
    166   TriangularSolve(a, b,
    167                   /*left_side=*/false, /*lower=*/false,
    168                   /*unit_diagonal=*/false,
    169                   /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
    170 
    171   Array2D<float> expected({
    172       {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
    173       {0.64393939, 0.06565657, -0.03030303, 0.72727273},
    174       {1.4520202, 0.2003367, 0.01010101, 1.09090909},
    175   });
    176 
    177   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    178                              ErrorSpec(1e-2, 1e-2));
    179 }
    180 
    181 XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
    182   XlaBuilder builder(TestName());
    183 
    184   XlaOp a, b;
    185   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
    186   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
    187   TriangularSolve(a, b,
    188                   /*left_side=*/false, /*lower=*/false,
    189                   /*unit_diagonal=*/false,
    190                   /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);
    191 
    192   Array2D<float> expected({
    193       {0.5, 0.08333334, 0.04629629, 0.03367003},
    194       {2.5, -0.25, -0.1388889, -0.1010101},
    195       {4.5, -0.58333331, -0.32407406, -0.23569024},
    196   });
    197 
    198   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    199                              ErrorSpec(1e-2, 1e-2));
    200 }
    201 
    202 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
    203   XlaBuilder builder(TestName());
    204 
    205   XlaOp a, b;
    206   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
    207   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    208   TriangularSolve(a, b,
    209                   /*left_side=*/true, /*lower=*/true,
    210                   /*unit_diagonal=*/false,
    211                   /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
    212 
    213   Array2D<float> expected({
    214       {-0.89646465, -0.69444444, -0.49242424},
    215       {-0.27441077, -0.24074074, -0.20707071},
    216       {-0.23232323, -0.22222222, -0.21212121},
    217       {0.90909091, 1., 1.09090909},
    218   });
    219 
    220   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    221                              ErrorSpec(1e-2, 1e-2));
    222 }
    223 
    224 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
    225   XlaBuilder builder(TestName());
    226 
    227   XlaOp a, b;
    228   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
    229   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    230   TriangularSolve(a, b,
    231                   /*left_side=*/true, /*lower=*/true,
    232                   /*unit_diagonal=*/false,
    233                   /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);
    234 
    235   Array2D<float> expected({
    236       {0.5, 1.0, 1.5},
    237       {0.41666667, 0.33333333, 0.25},
    238       {0.23148148, 0.18518519, 0.13888889},
    239       {0.16835017, 0.13468013, 0.1010101},
    240   });
    241 
    242   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    243                              ErrorSpec(1e-2, 1e-2));
    244 }
    245 
    246 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) {
    247   XlaBuilder builder(TestName());
    248 
    249   XlaOp a, b;
    250   auto a_data =
    251       CreateR2Parameter<float>(AValsLowerUnitDiagonal(), 0, "a", &builder, &a);
    252   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    253   TriangularSolve(a, b,
    254                   /*left_side=*/true, /*lower=*/true,
    255                   /*unit_diagonal=*/true,
    256                   /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);
    257 
    258   Array2D<float> expected(
    259       {{1., 2., 3.}, {1., -1., -3.}, {-4., 7., 18.}, {37., -61., -159.}});
    260 
    261   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    262                              ErrorSpec(1e-2, 1e-2));
    263 }
    264 
    265 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) {
    266   XlaBuilder builder(TestName());
    267 
    268   XlaOp a, b;
    269   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
    270   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    271   TriangularSolve(a, b,
    272                   /*left_side=*/true, /*lower=*/true,
    273                   /*unit_diagonal=*/false,
    274                   /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);
    275 
    276   Array2D<float> expected({
    277       {0.5, 1.0, 1.5},
    278       {0.41666667, 0.33333333, 0.25},
    279       {0.23148148, 0.18518519, 0.13888889},
    280       {0.16835017, 0.13468013, 0.1010101},
    281   });
    282 
    283   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    284                              ErrorSpec(1e-2, 1e-2));
    285 }
    286 
    287 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
    288   XlaBuilder builder(TestName());
    289 
    290   XlaOp a, b;
    291   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
    292   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    293   TriangularSolve(a, b,
    294                   /*left_side=*/true, /*lower=*/false,
    295                   /*unit_diagonal=*/false,
    296                   /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
    297 
    298   Array2D<float> expected({
    299       {0.5, 1.0, 1.5},
    300       {0.41666667, 0.33333333, 0.25},
    301       {0.23148148, 0.18518519, 0.13888889},
    302       {0.16835017, 0.13468013, 0.1010101},
    303   });
    304 
    305   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    306                              ErrorSpec(1e-2, 1e-2));
    307 }
    308 
    309 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
    310   XlaBuilder builder(TestName());
    311 
    312   XlaOp a, b;
    313   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
    314   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    315   TriangularSolve(a, b,
    316                   /*left_side=*/true, /*lower=*/false,
    317                   /*unit_diagonal=*/false,
    318                   /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);
    319 
    320   Array2D<float> expected({
    321       {-0.89646465, -0.69444444, -0.49242424},
    322       {-0.27441077, -0.24074074, -0.20707071},
    323       {-0.23232323, -0.22222222, -0.21212121},
    324       {0.90909091, 1., 1.09090909},
    325   });
    326 
    327   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    328                              ErrorSpec(1e-2, 1e-2));
    329 }
    330 
    331 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) {
    332   XlaBuilder builder(TestName());
    333 
    334   XlaOp a, b;
    335   auto a_data =
    336       CreateR2Parameter<float>(AValsUpperUnitDiagonal(), 0, "a", &builder, &a);
    337   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    338   TriangularSolve(a, b,
    339                   /*left_side=*/true, /*lower=*/false,
    340                   /*unit_diagonal=*/true,
    341                   /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);
    342 
    343   Array2D<float> expected({{-1402., -1538., -1674.},
    344                            {575., 631., 687.},
    345                            {-93., -102., -111.},
    346                            {10., 11., 12.}});
    347 
    348   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    349                              ErrorSpec(1e-2, 1e-2));
    350 }
    351 
    352 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
    353   XlaBuilder builder(TestName());
    354 
    355   XlaOp a, b;
    356   auto a_data =
    357       CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
    358   auto b_data =
    359       CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b);
    360   TriangularSolve(a, b,
    361                   /*left_side=*/false, /*lower=*/true,
    362                   /*unit_diagonal=*/false,
    363                   /*transpose_a=*/TriangularSolveOptions::ADJOINT);
    364 
    365   Array2D<complex64> expected({
    366       {0.5, complex64(0.08333333, 0.08333333),
    367        complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)},
    368       {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963),
    369        complex64(0.08670034, -0.02104377)},
    370       {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296),
    371        complex64(0.11026936, -0.03114478)},
    372   });
    373 
    374   ComputeAndCompareR2<complex64>(
    375       &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2));
    376 }
    377 
    378 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
    379   XlaBuilder builder(TestName());
    380 
    381   XlaOp a, b;
    382   auto a_data =
    383       CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
    384   auto b_data =
    385       CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b);
    386   TriangularSolve(a, b,
    387                   /*left_side=*/true, /*lower=*/false,
    388                   /*unit_diagonal=*/false,
    389                   /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
    390 
    391   Array2D<complex64> expected({
    392       {0.5, 1., 1.5},
    393       {0.41666667, 0.33333333, 0.25},
    394       {complex64(0.20020325, -2.81504065e-01),
    395        complex64(0.13821138, -4.22764228e-01),
    396        complex64(0.07621951, -5.64024390e-01)},
    397       {complex64(0.19678492, 2.55912786e-01),
    398        complex64(0.17738359, 3.84331116e-01),
    399        complex64(0.15798226, 5.12749446e-01)},
    400   });
    401 
    402   ComputeAndCompareR2<complex64>(
    403       &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2));
    404 }
    405 
    406 XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) {
    407   XlaBuilder builder(TestName());
    408 
    409   Array3D<float> bvals(7, 5, 5);
    410   bvals.FillIota(1.);
    411 
    412   // Set avals to the upper triangle of bvals.
    413   Array3D<float> avals = bvals;
    414   avals.Each([](absl::Span<const int64> indices, float* value) {
    415     if (indices[1] > indices[2]) {
    416       *value = 0;
    417     }
    418   });
    419 
    420   XlaOp a, b;
    421   auto a_data = CreateR3Parameter<float>(avals, 0, "a", &builder, &a);
    422   auto b_data = CreateR3Parameter<float>(bvals, 1, "b", &builder, &b);
    423   BatchDot(
    424       ConstantR3FromArray3D(&builder, avals),
    425       TriangularSolve(a, b,
    426                       /*left_side=*/true, /*lower=*/false,
    427                       /*unit_diagonal=*/false,
    428                       /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE));
    429 
    430   ComputeAndCompareR3<float>(&builder, bvals, {a_data.get(), b_data.get()},
    431                              ErrorSpec(1e-2, 1e-2));
    432 }
    433 
    434 struct TriangularSolveTestSpec {
    435   int m, n;  // A is mxm, B is mxn
    436   bool left_side;
    437   bool lower;
    438   TriangularSolveOptions::Transpose transpose_a;
    439 };
    440 
// Fixture for the value-parameterized Random test: combines the client
// library test base with a TriangularSolveTestSpec parameter that the test
// reads through GetParam().
class TriangularSolveParametricTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<TriangularSolveTestSpec> {};
    444 
    445 XLA_TEST_P(TriangularSolveParametricTest, Random) {
    446   TriangularSolveTestSpec spec = GetParam();
    447 
    448   XlaBuilder builder(TestName());
    449 
    450   Array2D<float> avals(spec.m, spec.m);
    451   avals.FillRandom(1.0);
    452   for (int i = 0; i < spec.m; ++i) {
    453     avals(i, i) += 10;
    454   }
    455 
    456   std::pair<int, int> bdims = spec.left_side ? std::make_pair(spec.m, spec.n)
    457                                              : std::make_pair(spec.n, spec.m);
    458   Array2D<float> bvals(bdims.first, bdims.second);
    459   bvals.FillRandom(1.0);
    460 
    461   XlaOp a, b;
    462   auto a_data = CreateR2Parameter<float>(avals, 0, "a", &builder, &a);
    463   auto b_data = CreateR2Parameter<float>(bvals, 1, "b", &builder, &b);
    464   auto x = TriangularSolve(a, b, spec.left_side, spec.lower,
    465                            /*unit_diagonal=*/false, spec.transpose_a);
    466   auto a_tri = Triangle(a, spec.lower);
    467   a_tri = MaybeTransposeInMinorDims(
    468       a_tri, spec.transpose_a != TriangularSolveOptions::NO_TRANSPOSE);
    469   if (spec.left_side) {
    470     BatchDot(a_tri, x);
    471   } else {
    472     BatchDot(x, a_tri);
    473   }
    474 
    475   ComputeAndCompareR2<float>(&builder, bvals, {a_data.get(), b_data.get()},
    476                              ErrorSpec(1e-2, 1e-2));
    477 }
    478 
    479 std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
    480   std::vector<TriangularSolveTestSpec> specs;
    481   for (int m : {5, 10}) {
    482     for (int n : {5, 10}) {
    483       for (bool left_side : {false, true}) {
    484         for (bool lower : {false, true}) {
    485           for (TriangularSolveOptions::Transpose transpose_a :
    486                {TriangularSolveOptions::NO_TRANSPOSE,
    487                 TriangularSolveOptions::TRANSPOSE}) {
    488             specs.push_back({m, n, left_side, lower, transpose_a});
    489           }
    490         }
    491       }
    492     }
    493   }
    494   return specs;
    495 }
    496 
// Registers TriangularSolveParametricTest once for each spec returned by
// TriangularSolveTests().
INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation,
                         TriangularSolveParametricTest,
                         ::testing::ValuesIn(TriangularSolveTests()));
    500 
    501 }  // namespace
    502 }  // namespace xla
    503