1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" 17 18 #include "tensorflow/compiler/xla/array2d.h" 19 #include "tensorflow/compiler/xla/array3d.h" 20 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 21 #include "tensorflow/compiler/xla/client/lib/constants.h" 22 #include "tensorflow/compiler/xla/client/lib/matrix.h" 23 #include "tensorflow/compiler/xla/client/xla_builder.h" 24 #include "tensorflow/compiler/xla/literal.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/compiler/xla/test.h" 27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 28 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 29 #include "tensorflow/compiler/xla/tests/test_macros.h" 30 #include "tensorflow/compiler/xla/xla_data.pb.h" 31 #include "tensorflow/core/lib/core/status_test_util.h" 32 33 namespace xla { 34 35 class SelfAdjointEigTest : public ClientLibraryTestBase { 36 protected: 37 void SetUp() override { 38 ClientLibraryTestBase::SetUp(); 39 batch_3d_4x4_ = Array3D<float>{ 40 { 41 {4, 6, 8, 10}, 42 {6, 45, 54, 63}, 43 {8, 54, 146, 166}, 44 {10, 63, 166, 310}, 45 }, 46 { 47 {16, 24, 8, 12}, 48 {24, 61, 82, 48}, 49 {8, 82, 100, 6}, 50 {12, 48, 6, 62}, 51 }, 52 }; 53 matrix2d_8x8_ = Array2D<float>{ 54 {14., 123., 49., 112., 115., 173., 182., 125.}, 55 {123., 14., 60., 118., 150., 130., 91., 72.}, 56 {49., 60., 138., 111., 106., 101., 115., 142.}, 57 {112., 118., 111., 142., 91., 130., 25., 61.}, 58 {115., 150., 106., 91., 116., 121., 128., 85.}, 59 {173., 130., 101., 130., 121., 70., 151., 132.}, 60 {182., 91., 115., 25., 128., 151., 66., 92.}, 61 {125., 72., 142., 61., 85., 132., 92., 156.}, 62 }; 63 low_rank_4x4_ = Array2D<float>{ 64 // x = [[1, 2, 3, 4], [1, -1, 1, -1]] 65 // matmul(x.T, x) 66 {2, 1, 4, 3}, 67 {1, 5, 5, 9}, 68 {4, 5, 10, 11}, 69 {3, 9, 11, 17}, 70 }; 71 } 72 void TearDown() override { ClientLibraryTestBase::TearDown(); } 73 74 Array3D<float> GetUnitMatrix3D(const Array3D<float>& matrix) { 75 Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0); 76 for (int i = 0; i < matrix.n1(); ++i) { 77 for (int j = 0; j < matrix.n2(); ++j) { 78 result({i, j, j}) = 1.0; 79 } 80 } 81 return result; 82 } 83 84 Array3D<float> ExtractTriangularMatrix(const Array3D<float>& matrix, 85 bool lower) { 86 Array3D<float> result(matrix); 87 for (int i = 0; i < result.n1(); ++i) { 88 for (int j = 0; j < result.n2(); ++j) { 89 if (lower) { 90 for (int k = j + 1; k < result.n3(); ++k) { 91 result({i, j, k}) = 0.0; 92 } 93 } else { 94 for (int k = 0; k < j; ++k) { 95 result({i, j, k}) = 0.0; 96 } 97 } 98 } 99 } 100 return result; 101 } 102 103 XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) { 104 Shape shape = builder->GetShape(result.v).ValueOrDie(); 105 std::vector<int64> out_dims = shape.dimensions(); 106 std::vector<int64> broadcast_dims(shape.rank() - 1); 107 std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); 108 109 broadcast_dims[shape.rank() - 2] = shape.rank() - 1; 110 auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims)); 111 return BatchDot(vw, TransposeInMinorDims(result.v), 112 PrecisionConfig::HIGHEST); 113 } 114 115 XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) { 116 Shape shape = builder->GetShape(m1).ValueOrDie(); 117 int64 size = 1; 118 for (auto d : shape.dimensions()) { 119 size *= d; 120 } 121 return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0), 122 CreateScalarAddComputation(F32, builder)) / 123 ConstantR0WithType(builder, F32, size); 124 } 125 126 Array2D<float> GenerateRandomSymmetricMatrix(int size) { 127 Array2D<float> result{size, size, 0.0}; 128 // TODO(b/128001705): This seed should not be needed but makes the test 129 // avoid inputs which trigger numerical instability. 130 result.FillRandom(10 /* stddev */, 2 /* mean */, 12346 /* seed */); 131 for (int i = 0; i < size; ++i) { 132 for (int j = 0; j < i; ++j) { 133 result({j, i}) = result({i, j}); 134 } 135 } 136 return result; 137 } 138 139 Array3D<float> batch_3d_4x4_; 140 Array2D<float> matrix2d_8x8_; 141 Array2D<float> low_rank_4x4_; 142 Array2D<int> wrong_type_4x4_; 143 }; 144 145 XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { 146 XlaBuilder builder(TestName()); 147 148 XlaOp a; 149 auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a); 150 auto result = SelfAdjointEig(a); 151 ComputeMatmulVWVt(result, &builder); 152 153 ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()}, 154 ErrorSpec(1e-3, 1e-3)); 155 } 156 157 XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) { 158 XlaBuilder builder(TestName()); 159 160 XlaOp a; 161 auto a_data = CreateR3Parameter<float>( 162 ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a); 163 auto result = SelfAdjointEig(a); 164 ComputeMatmulVWVt(result, &builder); 165 166 ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()}, 167 ErrorSpec(1e-3, 1e-3)); 168 } 169 170 XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) { 171 XlaBuilder builder(TestName()); 172 173 XlaOp a; 174 auto a_data = CreateR3Parameter<float>( 175 ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a); 176 auto result = SelfAdjointEig(a, false); 177 ComputeMatmulVWVt(result, &builder); 178 179 ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()}, 180 ErrorSpec(1e-3, 1e-3)); 181 } 182 183 XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) { 184 XlaBuilder builder(TestName()); 185 186 XlaOp a; 187 auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a); 188 auto result = SelfAdjointEig(a); 189 BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST); 190 191 ComputeAndCompareR3<float>(&builder, GetUnitMatrix3D(batch_3d_4x4_), 192 {a_data.get()}, ErrorSpec(1e-3, 1e-3)); 193 } 194 195 XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { 196 XlaBuilder builder(TestName()); 197 198 XlaOp a; 199 auto a_data = CreateR2Parameter<float>(low_rank_4x4_, 0, "a", &builder, &a); 200 auto result = SelfAdjointEig(a); 201 ComputeMatmulVWVt(result, &builder); 202 203 ComputeAndCompareR2<float>(&builder, low_rank_4x4_, {a_data.get()}, 204 ErrorSpec(1e-3, 1e-3)); 205 } 206 207 XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) { 208 XlaBuilder builder(TestName()); 209 210 // This is computed by numpy.linalg.eigh with float32. 211 std::vector<float> expected{-182.69205, -116.86245, -105.74489, -9.545369, 212 37.81711, 104.732285, 120.29153, 868.00385}; 213 214 XlaOp a; 215 auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a); 216 auto result = SelfAdjointEig(a); 217 Add(result.w, ZerosLike(result.w)); 218 219 ComputeAndCompareR1<float>(&builder, expected, {a_data.get()}, 220 ErrorSpec(1e-3, 1e-3)); 221 } 222 223 XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) { 224 XlaBuilder builder(TestName()); 225 226 float expected_vals = 1e-3; 227 228 XlaOp a; 229 auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a); 230 auto result = SelfAdjointEig(a); 231 // np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2 232 GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8), 233 BatchDot(TransposeInMinorDims(result.v), result.v), 234 &builder); 235 236 ComputeAndCompareR0<float>(&builder, expected_vals, {a_data.get()}, 237 ErrorSpec(1e-3, 1e-3)); 238 } 239 240 XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) { 241 XlaBuilder builder(TestName()); 242 243 XlaOp a; 244 auto a_data = CreateR2Parameter<int>(wrong_type_4x4_, 0, "a", &builder, &a); 245 auto result = SelfAdjointEig(a); 246 EXPECT_FALSE(result.v.valid()); 247 EXPECT_FALSE(result.w.valid()); 248 } 249 250 XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) { 251 XlaBuilder builder(TestName()); 252 int size = 8; 253 Array2D<float> a_val = GenerateRandomSymmetricMatrix(size); 254 XlaOp a; 255 auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); 256 auto result = SelfAdjointEig(a); 257 GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); 258 259 ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, 260 ErrorSpec(1e-3, 1e-3)); 261 } 262 263 XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) { 264 XlaBuilder builder(TestName()); 265 int size = 16; 266 Array2D<float> a_val = GenerateRandomSymmetricMatrix(size); 267 XlaOp a; 268 auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); 269 auto result = SelfAdjointEig(a); 270 GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); 271 272 ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, 273 ErrorSpec(1e-3, 1e-3)); 274 } 275 276 XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) { 277 XlaBuilder builder(TestName()); 278 int size = 32; 279 Array2D<float> a_val = GenerateRandomSymmetricMatrix(size); 280 XlaOp a; 281 auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); 282 auto result = SelfAdjointEig(a); 283 GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); 284 285 ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, 286 ErrorSpec(1e-3, 1e-3)); 287 } 288 289 XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) { 290 XlaBuilder builder(TestName()); 291 int size = 256; 292 Array2D<float> a_val = GenerateRandomSymmetricMatrix(size); 293 XlaOp a; 294 auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); 295 auto result = SelfAdjointEig(a); 296 GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); 297 298 ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, 299 ErrorSpec(1e-3, 1e-3)); 300 } 301 302 XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) { 303 XlaBuilder builder(TestName()); 304 int size = 512; 305 Array2D<float> a_val = GenerateRandomSymmetricMatrix(size); 306 XlaOp a; 307 auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); 308 auto result = SelfAdjointEig(a); 309 GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); 310 311 ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, 312 ErrorSpec(1e-3, 1e-3)); 313 } 314 315 } // namespace xla 316