Home | History | Annotate | Download | only in lib
      1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
     17 
     18 #include "tensorflow/compiler/xla/array2d.h"
     19 #include "tensorflow/compiler/xla/array3d.h"
     20 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
     21 #include "tensorflow/compiler/xla/client/lib/constants.h"
     22 #include "tensorflow/compiler/xla/client/lib/matrix.h"
     23 #include "tensorflow/compiler/xla/client/xla_builder.h"
     24 #include "tensorflow/compiler/xla/literal.h"
     25 #include "tensorflow/compiler/xla/statusor.h"
     26 #include "tensorflow/compiler/xla/test.h"
     27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     29 #include "tensorflow/compiler/xla/tests/test_macros.h"
     30 #include "tensorflow/compiler/xla/xla_data.pb.h"
     31 #include "tensorflow/core/lib/core/status_test_util.h"
     32 
     33 namespace xla {
     34 
     35 class SelfAdjointEigTest : public ClientLibraryTestBase {
     36  protected:
     37   void SetUp() override {
     38     ClientLibraryTestBase::SetUp();
     39     batch_3d_4x4_ = Array3D<float>{
     40         {
     41             {4, 6, 8, 10},
     42             {6, 45, 54, 63},
     43             {8, 54, 146, 166},
     44             {10, 63, 166, 310},
     45         },
     46         {
     47             {16, 24, 8, 12},
     48             {24, 61, 82, 48},
     49             {8, 82, 100, 6},
     50             {12, 48, 6, 62},
     51         },
     52     };
     53     matrix2d_8x8_ = Array2D<float>{
     54         {14., 123., 49., 112., 115., 173., 182., 125.},
     55         {123., 14., 60., 118., 150., 130., 91., 72.},
     56         {49., 60., 138., 111., 106., 101., 115., 142.},
     57         {112., 118., 111., 142., 91., 130., 25., 61.},
     58         {115., 150., 106., 91., 116., 121., 128., 85.},
     59         {173., 130., 101., 130., 121., 70., 151., 132.},
     60         {182., 91., 115., 25., 128., 151., 66., 92.},
     61         {125., 72., 142., 61., 85., 132., 92., 156.},
     62     };
     63     low_rank_4x4_ = Array2D<float>{
     64         // x = [[1, 2, 3, 4], [1, -1, 1, -1]]
     65         // matmul(x.T, x)
     66         {2, 1, 4, 3},
     67         {1, 5, 5, 9},
     68         {4, 5, 10, 11},
     69         {3, 9, 11, 17},
     70     };
     71   }
     72   void TearDown() override { ClientLibraryTestBase::TearDown(); }
     73 
     74   Array3D<float> GetUnitMatrix3D(const Array3D<float>& matrix) {
     75     Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0);
     76     for (int i = 0; i < matrix.n1(); ++i) {
     77       for (int j = 0; j < matrix.n2(); ++j) {
     78         result({i, j, j}) = 1.0;
     79       }
     80     }
     81     return result;
     82   }
     83 
     84   Array3D<float> ExtractTriangularMatrix(const Array3D<float>& matrix,
     85                                          bool lower) {
     86     Array3D<float> result(matrix);
     87     for (int i = 0; i < result.n1(); ++i) {
     88       for (int j = 0; j < result.n2(); ++j) {
     89         if (lower) {
     90           for (int k = j + 1; k < result.n3(); ++k) {
     91             result({i, j, k}) = 0.0;
     92           }
     93         } else {
     94           for (int k = 0; k < j; ++k) {
     95             result({i, j, k}) = 0.0;
     96           }
     97         }
     98       }
     99     }
    100     return result;
    101   }
    102 
    103   XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) {
    104     Shape shape = builder->GetShape(result.v).ValueOrDie();
    105     std::vector<int64> out_dims = shape.dimensions();
    106     std::vector<int64> broadcast_dims(shape.rank() - 1);
    107     std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
    108 
    109     broadcast_dims[shape.rank() - 2] = shape.rank() - 1;
    110     auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims));
    111     return BatchDot(vw, TransposeInMinorDims(result.v),
    112                     PrecisionConfig::HIGHEST);
    113   }
    114 
    115   XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
    116     Shape shape = builder->GetShape(m1).ValueOrDie();
    117     int64 size = 1;
    118     for (auto d : shape.dimensions()) {
    119       size *= d;
    120     }
    121     return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0),
    122                      CreateScalarAddComputation(F32, builder)) /
    123            ConstantR0WithType(builder, F32, size);
    124   }
    125 
    126   Array2D<float> GenerateRandomSymmetricMatrix(int size) {
    127     Array2D<float> result{size, size, 0.0};
    128     // TODO(b/128001705): This seed should not be needed but makes the test
    129     // avoid inputs which trigger numerical instability.
    130     result.FillRandom(10 /* stddev */, 2 /* mean */, 12346 /* seed */);
    131     for (int i = 0; i < size; ++i) {
    132       for (int j = 0; j < i; ++j) {
    133         result({j, i}) = result({i, j});
    134       }
    135     }
    136     return result;
    137   }
    138 
    139   Array3D<float> batch_3d_4x4_;
    140   Array2D<float> matrix2d_8x8_;
    141   Array2D<float> low_rank_4x4_;
    142   Array2D<int> wrong_type_4x4_;
    143 };
    144 
    145 XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
    146   XlaBuilder builder(TestName());
    147 
    148   XlaOp a;
    149   auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a);
    150   auto result = SelfAdjointEig(a);
    151   ComputeMatmulVWVt(result, &builder);
    152 
    153   ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
    154                              ErrorSpec(1e-3, 1e-3));
    155 }
    156 
    157 XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) {
    158   XlaBuilder builder(TestName());
    159 
    160   XlaOp a;
    161   auto a_data = CreateR3Parameter<float>(
    162       ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a);
    163   auto result = SelfAdjointEig(a);
    164   ComputeMatmulVWVt(result, &builder);
    165 
    166   ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
    167                              ErrorSpec(1e-3, 1e-3));
    168 }
    169 
    170 XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) {
    171   XlaBuilder builder(TestName());
    172 
    173   XlaOp a;
    174   auto a_data = CreateR3Parameter<float>(
    175       ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a);
    176   auto result = SelfAdjointEig(a, false);
    177   ComputeMatmulVWVt(result, &builder);
    178 
    179   ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
    180                              ErrorSpec(1e-3, 1e-3));
    181 }
    182 
    183 XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) {
    184   XlaBuilder builder(TestName());
    185 
    186   XlaOp a;
    187   auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a);
    188   auto result = SelfAdjointEig(a);
    189   BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST);
    190 
    191   ComputeAndCompareR3<float>(&builder, GetUnitMatrix3D(batch_3d_4x4_),
    192                              {a_data.get()}, ErrorSpec(1e-3, 1e-3));
    193 }
    194 
    195 XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) {
    196   XlaBuilder builder(TestName());
    197 
    198   XlaOp a;
    199   auto a_data = CreateR2Parameter<float>(low_rank_4x4_, 0, "a", &builder, &a);
    200   auto result = SelfAdjointEig(a);
    201   ComputeMatmulVWVt(result, &builder);
    202 
    203   ComputeAndCompareR2<float>(&builder, low_rank_4x4_, {a_data.get()},
    204                              ErrorSpec(1e-3, 1e-3));
    205 }
    206 
    207 XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) {
    208   XlaBuilder builder(TestName());
    209 
    210   // This is computed by numpy.linalg.eigh with float32.
    211   std::vector<float> expected{-182.69205, -116.86245, -105.74489, -9.545369,
    212                               37.81711,   104.732285, 120.29153,  868.00385};
    213 
    214   XlaOp a;
    215   auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
    216   auto result = SelfAdjointEig(a);
    217   Add(result.w, ZerosLike(result.w));
    218 
    219   ComputeAndCompareR1<float>(&builder, expected, {a_data.get()},
    220                              ErrorSpec(1e-3, 1e-3));
    221 }
    222 
    223 XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) {
    224   XlaBuilder builder(TestName());
    225 
    226   float expected_vals = 1e-3;
    227 
    228   XlaOp a;
    229   auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
    230   auto result = SelfAdjointEig(a);
    231   // np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2
    232   GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8),
    233                           BatchDot(TransposeInMinorDims(result.v), result.v),
    234                           &builder);
    235 
    236   ComputeAndCompareR0<float>(&builder, expected_vals, {a_data.get()},
    237                              ErrorSpec(1e-3, 1e-3));
    238 }
    239 
    240 XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) {
    241   XlaBuilder builder(TestName());
    242 
    243   XlaOp a;
    244   auto a_data = CreateR2Parameter<int>(wrong_type_4x4_, 0, "a", &builder, &a);
    245   auto result = SelfAdjointEig(a);
    246   EXPECT_FALSE(result.v.valid());
    247   EXPECT_FALSE(result.w.valid());
    248 }
    249 
    250 XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) {
    251   XlaBuilder builder(TestName());
    252   int size = 8;
    253   Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
    254   XlaOp a;
    255   auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
    256   auto result = SelfAdjointEig(a);
    257   GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
    258 
    259   ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
    260                              ErrorSpec(1e-3, 1e-3));
    261 }
    262 
    263 XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) {
    264   XlaBuilder builder(TestName());
    265   int size = 16;
    266   Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
    267   XlaOp a;
    268   auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
    269   auto result = SelfAdjointEig(a);
    270   GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
    271 
    272   ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
    273                              ErrorSpec(1e-3, 1e-3));
    274 }
    275 
    276 XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) {
    277   XlaBuilder builder(TestName());
    278   int size = 32;
    279   Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
    280   XlaOp a;
    281   auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
    282   auto result = SelfAdjointEig(a);
    283   GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
    284 
    285   ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
    286                              ErrorSpec(1e-3, 1e-3));
    287 }
    288 
    289 XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) {
    290   XlaBuilder builder(TestName());
    291   int size = 256;
    292   Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
    293   XlaOp a;
    294   auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
    295   auto result = SelfAdjointEig(a);
    296   GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
    297 
    298   ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
    299                              ErrorSpec(1e-3, 1e-3));
    300 }
    301 
    302 XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) {
    303   XlaBuilder builder(TestName());
    304   int size = 512;
    305   Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
    306   XlaOp a;
    307   auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
    308   auto result = SelfAdjointEig(a);
    309   GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
    310 
    311   ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
    312                              ErrorSpec(1e-3, 1e-3));
    313 }
    314 
    315 }  // namespace xla
    316