Home | History | Annotate | Download | only in lib
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
     17 
     18 #include <memory>
     19 #include <numeric>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/xla/array2d.h"
     23 #include "tensorflow/compiler/xla/client/computation_builder.h"
     24 #include "tensorflow/compiler/xla/literal_util.h"
     25 #include "tensorflow/compiler/xla/statusor.h"
     26 #include "tensorflow/compiler/xla/test.h"
     27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     29 #include "tensorflow/compiler/xla/tests/test_macros.h"
     30 #include "tensorflow/compiler/xla/types.h"
     31 #include "tensorflow/core/lib/core/status_test_util.h"
     32 
     33 namespace tensorflow {
     34 namespace {
     35 
     36 using TriangularSolveTest = xla::ClientLibraryTestBase;
     37 using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase;
     38 using complex64 = xla::complex64;
     39 
     40 xla::Array2D<float> AValsLower() {
     41   return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}};
     42 }
     43 
     44 xla::Array2D<float> AValsUpper() {
     45   return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}};
     46 }
     47 
     48 xla::Array2D<float> BValsRight() {
     49   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
     50 }
     51 
     52 xla::Array2D<float> BValsLeft() {
     53   return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
     54 }
     55 
     56 xla::Array2D<complex64> AValsLowerComplex() {
     57   return {{2, 0, 0, 0},
     58           {complex64(3, 1), 6, 0, 0},
     59           {4, complex64(7, 2), 9, 0},
     60           {5, 8, complex64(10, 3), 11}};
     61 }
     62 
     63 xla::Array2D<complex64> AValsUpperComplex() {
     64   return {{2, 3, complex64(4, 3), 5},
     65           {0, 6, complex64(7, 2), 8},
     66           {0, 0, complex64(9, 1), 10},
     67           {0, 0, 0, 11}};
     68 }
     69 
     70 xla::Array2D<complex64> BValsRightComplex() {
     71   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
     72 }
     73 
     74 xla::Array2D<complex64> BValsLeftComplex() {
     75   return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
     76 }
     77 
     78 xla::Array2D<float> AValsFull() {
     79   return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}};
     80 }
     81 
     82 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
     83   xla::ComputationBuilder builder(client_, TestName());
     84 
     85   xla::ComputationDataHandle a, b;
     86   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
     87   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
     88   auto result = TriangularSolve(&builder, a, b,
     89                                 /*left_side=*/false, /*lower=*/true,
     90                                 /*transpose_a=*/true, /*conjugate_a=*/false,
     91                                 /*block_size=*/2);
     92   TF_ASSERT_OK(result.status());
     93 
     94   xla::Array2D<float> expected({
     95       {0.5, 0.08333334, 0.04629629, 0.03367003},
     96       {2.5, -0.25, -0.1388889, -0.1010101},
     97       {4.5, -0.58333331, -0.32407406, -0.23569024},
     98   });
     99 
    100   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    101                              xla::ErrorSpec(1e-2, 1e-2));
    102 }
    103 
    104 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
    105   xla::ComputationBuilder builder(client_, TestName());
    106 
    107   xla::ComputationDataHandle a, b;
    108   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
    109   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
    110   auto result = TriangularSolve(&builder, a, b,
    111                                 /*left_side=*/false, /*lower=*/true,
    112                                 /*transpose_a=*/false, /*conjugate_a=*/false,
    113                                 /*block_size=*/2);
    114   TF_ASSERT_OK(result.status());
    115 
    116   xla::Array2D<float> expected({
    117       {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
    118       {0.64393939, 0.06565657, -0.03030303, 0.72727273},
    119       {1.4520202, 0.2003367, 0.01010101, 1.09090909},
    120   });
    121 
    122   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    123                              xla::ErrorSpec(1e-2, 1e-2));
    124 }
    125 
    126 XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
    127   xla::ComputationBuilder builder(client_, TestName());
    128 
    129   xla::ComputationDataHandle a, b;
    130   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
    131   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
    132   auto result = TriangularSolve(&builder, a, b,
    133                                 /*left_side=*/false, /*lower=*/false,
    134                                 /*transpose_a=*/true, /*conjugate_a=*/false,
    135                                 /*block_size=*/2);
    136   TF_ASSERT_OK(result.status());
    137 
    138   xla::Array2D<float> expected({
    139       {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
    140       {0.64393939, 0.06565657, -0.03030303, 0.72727273},
    141       {1.4520202, 0.2003367, 0.01010101, 1.09090909},
    142   });
    143 
    144   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    145                              xla::ErrorSpec(1e-2, 1e-2));
    146 }
    147 
    148 XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
    149   xla::ComputationBuilder builder(client_, TestName());
    150 
    151   xla::ComputationDataHandle a, b;
    152   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
    153   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
    154   auto result = TriangularSolve(&builder, a, b,
    155                                 /*left_side=*/false, /*lower=*/false,
    156                                 /*transpose_a=*/false, /*conjugate_a=*/false,
    157                                 /*block_size=*/2);
    158   TF_ASSERT_OK(result.status());
    159 
    160   xla::Array2D<float> expected({
    161       {0.5, 0.08333334, 0.04629629, 0.03367003},
    162       {2.5, -0.25, -0.1388889, -0.1010101},
    163       {4.5, -0.58333331, -0.32407406, -0.23569024},
    164   });
    165 
    166   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    167                              xla::ErrorSpec(1e-2, 1e-2));
    168 }
    169 
    170 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
    171   xla::ComputationBuilder builder(client_, TestName());
    172 
    173   xla::ComputationDataHandle a, b;
    174   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
    175   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    176   auto result = TriangularSolve(&builder, a, b,
    177                                 /*left_side=*/true, /*lower=*/true,
    178                                 /*transpose_a=*/true, /*conjugate_a=*/false,
    179                                 /*block_size=*/2);
    180   TF_ASSERT_OK(result.status());
    181 
    182   xla::Array2D<float> expected({
    183       {-0.89646465, -0.69444444, -0.49242424},
    184       {-0.27441077, -0.24074074, -0.20707071},
    185       {-0.23232323, -0.22222222, -0.21212121},
    186       {0.90909091, 1., 1.09090909},
    187   });
    188 
    189   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    190                              xla::ErrorSpec(1e-2, 1e-2));
    191 }
    192 
    193 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
    194   xla::ComputationBuilder builder(client_, TestName());
    195 
    196   xla::ComputationDataHandle a, b;
    197   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
    198   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    199   auto result = TriangularSolve(&builder, a, b,
    200                                 /*left_side=*/true, /*lower=*/true,
    201                                 /*transpose_a=*/false, /*conjugate_a=*/false,
    202                                 /*block_size=*/2);
    203   TF_ASSERT_OK(result.status());
    204 
    205   xla::Array2D<float> expected({
    206       {0.5, 1.0, 1.5},
    207       {0.41666667, 0.33333333, 0.25},
    208       {0.23148148, 0.18518519, 0.13888889},
    209       {0.16835017, 0.13468013, 0.1010101},
    210   });
    211 
    212   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    213                              xla::ErrorSpec(1e-2, 1e-2));
    214 }
    215 
    216 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
    217   xla::ComputationBuilder builder(client_, TestName());
    218 
    219   xla::ComputationDataHandle a, b;
    220   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
    221   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    222   auto result = TriangularSolve(&builder, a, b,
    223                                 /*left_side=*/true, /*lower=*/false,
    224                                 /*transpose_a=*/true, /*conjugate_a=*/false,
    225                                 /*block_size=*/2);
    226   TF_ASSERT_OK(result.status());
    227 
    228   xla::Array2D<float> expected({
    229       {0.5, 1.0, 1.5},
    230       {0.41666667, 0.33333333, 0.25},
    231       {0.23148148, 0.18518519, 0.13888889},
    232       {0.16835017, 0.13468013, 0.1010101},
    233   });
    234 
    235   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    236                              xla::ErrorSpec(1e-2, 1e-2));
    237 }
    238 
    239 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
    240   xla::ComputationBuilder builder(client_, TestName());
    241 
    242   xla::ComputationDataHandle a, b;
    243   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
    244   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    245   auto result = TriangularSolve(&builder, a, b,
    246                                 /*left_side=*/true, /*lower=*/false,
    247                                 /*transpose_a=*/false, /*conjugate_a=*/false,
    248                                 /*block_size=*/2);
    249   TF_ASSERT_OK(result.status());
    250 
    251   xla::Array2D<float> expected({
    252       {-0.89646465, -0.69444444, -0.49242424},
    253       {-0.27441077, -0.24074074, -0.20707071},
    254       {-0.23232323, -0.22222222, -0.21212121},
    255       {0.90909091, 1., 1.09090909},
    256   });
    257 
    258   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    259                              xla::ErrorSpec(1e-2, 1e-2));
    260 }
    261 
    262 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
    263   xla::ComputationBuilder builder(client_, TestName());
    264 
    265   xla::ComputationDataHandle a, b;
    266   auto a_data =
    267       CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
    268   auto b_data =
    269       CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b);
    270   auto result = TriangularSolve(&builder, a, b,
    271                                 /*left_side=*/false, /*lower=*/true,
    272                                 /*transpose_a=*/true, /*conjugate_a=*/true,
    273                                 /*block_size=*/2);
    274   TF_ASSERT_OK(result.status());
    275 
    276   xla::Array2D<complex64> expected({
    277       {0.5, complex64(0.08333333, 0.08333333),
    278        complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)},
    279       {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963),
    280        complex64(0.08670034, -0.02104377)},
    281       {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296),
    282        complex64(0.11026936, -0.03114478)},
    283   });
    284 
    285   ComputeAndCompareR2<complex64>(&builder, expected,
    286                                  {a_data.get(), b_data.get()},
    287                                  xla::ErrorSpec(1e-2, 1e-2));
    288 }
    289 
    290 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
    291   xla::ComputationBuilder builder(client_, TestName());
    292 
    293   xla::ComputationDataHandle a, b;
    294   auto a_data =
    295       CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
    296   auto b_data =
    297       CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b);
    298   auto result = TriangularSolve(&builder, a, b,
    299                                 /*left_side=*/true, /*lower=*/false,
    300                                 /*transpose_a=*/true, /*conjugate_a=*/false,
    301                                 /*block_size=*/2);
    302   TF_ASSERT_OK(result.status());
    303 
    304   xla::Array2D<complex64> expected({
    305       {0.5, 1., 1.5},
    306       {0.41666667, 0.33333333, 0.25},
    307       {complex64(0.20020325, -2.81504065e-01),
    308        complex64(0.13821138, -4.22764228e-01),
    309        complex64(0.07621951, -5.64024390e-01)},
    310       {complex64(0.19678492, 2.55912786e-01),
    311        complex64(0.17738359, 3.84331116e-01),
    312        complex64(0.15798226, 5.12749446e-01)},
    313   });
    314 
    315   ComputeAndCompareR2<complex64>(&builder, expected,
    316                                  {a_data.get(), b_data.get()},
    317                                  xla::ErrorSpec(1e-2, 1e-2));
    318 }
    319 
    320 XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
    321   xla::ComputationBuilder builder(client_, TestName());
    322 
    323   xla::ComputationDataHandle a, b;
    324   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
    325   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    326   auto result = TriangularSolveLeftLooking(&builder, a, b,
    327                                            /*transpose_a=*/false,
    328                                            /*conjugate_a=*/false);
    329   TF_ASSERT_OK(result.status());
    330 
    331   xla::Array2D<float> expected({
    332       {0.5, 1.0, 1.5},
    333       {0.41666667, 0.33333333, 0.25},
    334       {0.23148148, 0.18518519, 0.13888889},
    335       {0.16835017, 0.13468013, 0.1010101},
    336   });
    337 
    338   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    339                              xla::ErrorSpec(1e-2, 1e-2));
    340 }
    341 
    342 XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
    343   xla::ComputationBuilder builder(client_, TestName());
    344 
    345   xla::ComputationDataHandle a, b;
    346   auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
    347   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
    348   auto result = TriangularSolveLeftLooking(&builder, a, b,
    349                                            /*transpose_a=*/false,
    350                                            /*conjugate_a=*/false);
    351   TF_ASSERT_OK(result.status());
    352 
    353   xla::Array2D<float> expected({
    354       {0.5, 1.0, 1.5},
    355       {0.41666667, 0.33333333, 0.25},
    356       {0.23148148, 0.18518519, 0.13888889},
    357       {0.16835017, 0.13468013, 0.1010101},
    358   });
    359 
    360   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
    361                              xla::ErrorSpec(1e-2, 1e-2));
    362 }
    363 
    364 }  // namespace
    365 }  // namespace tensorflow
    366