1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" 17 18 #include <memory> 19 #include <numeric> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/array2d.h" 23 #include "tensorflow/compiler/xla/client/computation_builder.h" 24 #include "tensorflow/compiler/xla/literal_util.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/compiler/xla/test.h" 27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 28 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 29 #include "tensorflow/compiler/xla/tests/test_macros.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/core/lib/core/status_test_util.h" 32 33 namespace tensorflow { 34 namespace { 35 36 using TriangularSolveTest = xla::ClientLibraryTestBase; 37 using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase; 38 using complex64 = xla::complex64; 39 40 xla::Array2D<float> AValsLower() { 41 return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}}; 42 } 43 44 xla::Array2D<float> AValsUpper() { 45 return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}}; 46 } 47 48 xla::Array2D<float> BValsRight() { 49 return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; 50 } 51 52 xla::Array2D<float> BValsLeft() { 53 return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; 54 } 55 56 xla::Array2D<complex64> AValsLowerComplex() { 57 return {{2, 0, 0, 0}, 58 {complex64(3, 1), 6, 0, 0}, 59 {4, complex64(7, 2), 9, 0}, 60 {5, 8, complex64(10, 3), 11}}; 61 } 62 63 xla::Array2D<complex64> AValsUpperComplex() { 64 return {{2, 3, complex64(4, 3), 5}, 65 {0, 6, complex64(7, 2), 8}, 66 {0, 0, complex64(9, 1), 10}, 67 {0, 0, 0, 11}}; 68 } 69 70 xla::Array2D<complex64> BValsRightComplex() { 71 return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; 72 } 73 74 xla::Array2D<complex64> BValsLeftComplex() { 75 return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; 76 } 77 78 xla::Array2D<float> AValsFull() { 79 return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; 80 } 81 82 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { 83 xla::ComputationBuilder builder(client_, TestName()); 84 85 xla::ComputationDataHandle a, b; 86 auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); 87 auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); 88 auto result = TriangularSolve(&builder, a, b, 89 /*left_side=*/false, /*lower=*/true, 90 /*transpose_a=*/true, /*conjugate_a=*/false, 91 /*block_size=*/2); 92 TF_ASSERT_OK(result.status()); 93 94 xla::Array2D<float> expected({ 95 {0.5, 0.08333334, 0.04629629, 0.03367003}, 96 {2.5, -0.25, -0.1388889, -0.1010101}, 97 {4.5, -0.58333331, -0.32407406, -0.23569024}, 98 }); 99 100 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 101 xla::ErrorSpec(1e-2, 1e-2)); 102 } 103 104 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { 105 xla::ComputationBuilder builder(client_, TestName()); 106 107 xla::ComputationDataHandle a, b; 108 auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); 109 auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); 110 auto result = TriangularSolve(&builder, a, b, 111 /*left_side=*/false, /*lower=*/true, 112 /*transpose_a=*/false, /*conjugate_a=*/false, 113 /*block_size=*/2); 114 TF_ASSERT_OK(result.status()); 115 116 xla::Array2D<float> expected({ 117 {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, 118 {0.64393939, 0.06565657, -0.03030303, 0.72727273}, 119 {1.4520202, 0.2003367, 0.01010101, 1.09090909}, 120 }); 121 122 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 123 xla::ErrorSpec(1e-2, 1e-2)); 124 } 125 126 XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { 127 xla::ComputationBuilder builder(client_, TestName()); 128 129 xla::ComputationDataHandle a, b; 130 auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); 131 auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); 132 auto result = TriangularSolve(&builder, a, b, 133 /*left_side=*/false, /*lower=*/false, 134 /*transpose_a=*/true, /*conjugate_a=*/false, 135 /*block_size=*/2); 136 TF_ASSERT_OK(result.status()); 137 138 xla::Array2D<float> expected({ 139 {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, 140 {0.64393939, 0.06565657, -0.03030303, 0.72727273}, 141 {1.4520202, 0.2003367, 0.01010101, 1.09090909}, 142 }); 143 144 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 145 xla::ErrorSpec(1e-2, 1e-2)); 146 } 147 148 XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { 149 xla::ComputationBuilder builder(client_, TestName()); 150 151 xla::ComputationDataHandle a, b; 152 auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); 153 auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); 154 auto result = TriangularSolve(&builder, a, b, 155 /*left_side=*/false, /*lower=*/false, 156 /*transpose_a=*/false, /*conjugate_a=*/false, 157 /*block_size=*/2); 158 TF_ASSERT_OK(result.status()); 159 160 xla::Array2D<float> expected({ 161 {0.5, 0.08333334, 0.04629629, 0.03367003}, 162 {2.5, -0.25, -0.1388889, -0.1010101}, 163 {4.5, -0.58333331, -0.32407406, -0.23569024}, 164 }); 165 166 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 167 xla::ErrorSpec(1e-2, 1e-2)); 168 } 169 170 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { 171 xla::ComputationBuilder builder(client_, TestName()); 172 173 xla::ComputationDataHandle a, b; 174 auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); 175 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 176 auto result = TriangularSolve(&builder, a, b, 177 /*left_side=*/true, /*lower=*/true, 178 /*transpose_a=*/true, /*conjugate_a=*/false, 179 /*block_size=*/2); 180 TF_ASSERT_OK(result.status()); 181 182 xla::Array2D<float> expected({ 183 {-0.89646465, -0.69444444, -0.49242424}, 184 {-0.27441077, -0.24074074, -0.20707071}, 185 {-0.23232323, -0.22222222, -0.21212121}, 186 {0.90909091, 1., 1.09090909}, 187 }); 188 189 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 190 xla::ErrorSpec(1e-2, 1e-2)); 191 } 192 193 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { 194 xla::ComputationBuilder builder(client_, TestName()); 195 196 xla::ComputationDataHandle a, b; 197 auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); 198 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 199 auto result = TriangularSolve(&builder, a, b, 200 /*left_side=*/true, /*lower=*/true, 201 /*transpose_a=*/false, /*conjugate_a=*/false, 202 /*block_size=*/2); 203 TF_ASSERT_OK(result.status()); 204 205 xla::Array2D<float> expected({ 206 {0.5, 1.0, 1.5}, 207 {0.41666667, 0.33333333, 0.25}, 208 {0.23148148, 0.18518519, 0.13888889}, 209 {0.16835017, 0.13468013, 0.1010101}, 210 }); 211 212 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 213 xla::ErrorSpec(1e-2, 1e-2)); 214 } 215 216 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { 217 xla::ComputationBuilder builder(client_, TestName()); 218 219 xla::ComputationDataHandle a, b; 220 auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); 221 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 222 auto result = TriangularSolve(&builder, a, b, 223 /*left_side=*/true, /*lower=*/false, 224 /*transpose_a=*/true, /*conjugate_a=*/false, 225 /*block_size=*/2); 226 TF_ASSERT_OK(result.status()); 227 228 xla::Array2D<float> expected({ 229 {0.5, 1.0, 1.5}, 230 {0.41666667, 0.33333333, 0.25}, 231 {0.23148148, 0.18518519, 0.13888889}, 232 {0.16835017, 0.13468013, 0.1010101}, 233 }); 234 235 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 236 xla::ErrorSpec(1e-2, 1e-2)); 237 } 238 239 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { 240 xla::ComputationBuilder builder(client_, TestName()); 241 242 xla::ComputationDataHandle a, b; 243 auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); 244 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 245 auto result = TriangularSolve(&builder, a, b, 246 /*left_side=*/true, /*lower=*/false, 247 /*transpose_a=*/false, /*conjugate_a=*/false, 248 /*block_size=*/2); 249 TF_ASSERT_OK(result.status()); 250 251 xla::Array2D<float> expected({ 252 {-0.89646465, -0.69444444, -0.49242424}, 253 {-0.27441077, -0.24074074, -0.20707071}, 254 {-0.23232323, -0.22222222, -0.21212121}, 255 {0.90909091, 1., 1.09090909}, 256 }); 257 258 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 259 xla::ErrorSpec(1e-2, 1e-2)); 260 } 261 262 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { 263 xla::ComputationBuilder builder(client_, TestName()); 264 265 xla::ComputationDataHandle a, b; 266 auto a_data = 267 CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a); 268 auto b_data = 269 CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b); 270 auto result = TriangularSolve(&builder, a, b, 271 /*left_side=*/false, /*lower=*/true, 272 /*transpose_a=*/true, /*conjugate_a=*/true, 273 /*block_size=*/2); 274 TF_ASSERT_OK(result.status()); 275 276 xla::Array2D<complex64> expected({ 277 {0.5, complex64(0.08333333, 0.08333333), 278 complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, 279 {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), 280 complex64(0.08670034, -0.02104377)}, 281 {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296), 282 complex64(0.11026936, -0.03114478)}, 283 }); 284 285 ComputeAndCompareR2<complex64>(&builder, expected, 286 {a_data.get(), b_data.get()}, 287 xla::ErrorSpec(1e-2, 1e-2)); 288 } 289 290 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { 291 xla::ComputationBuilder builder(client_, TestName()); 292 293 xla::ComputationDataHandle a, b; 294 auto a_data = 295 CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a); 296 auto b_data = 297 CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b); 298 auto result = TriangularSolve(&builder, a, b, 299 /*left_side=*/true, /*lower=*/false, 300 /*transpose_a=*/true, /*conjugate_a=*/false, 301 /*block_size=*/2); 302 TF_ASSERT_OK(result.status()); 303 304 xla::Array2D<complex64> expected({ 305 {0.5, 1., 1.5}, 306 {0.41666667, 0.33333333, 0.25}, 307 {complex64(0.20020325, -2.81504065e-01), 308 complex64(0.13821138, -4.22764228e-01), 309 complex64(0.07621951, -5.64024390e-01)}, 310 {complex64(0.19678492, 2.55912786e-01), 311 complex64(0.17738359, 3.84331116e-01), 312 complex64(0.15798226, 5.12749446e-01)}, 313 }); 314 315 ComputeAndCompareR2<complex64>(&builder, expected, 316 {a_data.get(), b_data.get()}, 317 xla::ErrorSpec(1e-2, 1e-2)); 318 } 319 320 XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { 321 xla::ComputationBuilder builder(client_, TestName()); 322 323 xla::ComputationDataHandle a, b; 324 auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); 325 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 326 auto result = TriangularSolveLeftLooking(&builder, a, b, 327 /*transpose_a=*/false, 328 /*conjugate_a=*/false); 329 TF_ASSERT_OK(result.status()); 330 331 xla::Array2D<float> expected({ 332 {0.5, 1.0, 1.5}, 333 {0.41666667, 0.33333333, 0.25}, 334 {0.23148148, 0.18518519, 0.13888889}, 335 {0.16835017, 0.13468013, 0.1010101}, 336 }); 337 338 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 339 xla::ErrorSpec(1e-2, 1e-2)); 340 } 341 342 XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { 343 xla::ComputationBuilder builder(client_, TestName()); 344 345 xla::ComputationDataHandle a, b; 346 auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a); 347 auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); 348 auto result = TriangularSolveLeftLooking(&builder, a, b, 349 /*transpose_a=*/false, 350 /*conjugate_a=*/false); 351 TF_ASSERT_OK(result.status()); 352 353 xla::Array2D<float> expected({ 354 {0.5, 1.0, 1.5}, 355 {0.41666667, 0.33333333, 0.25}, 356 {0.23148148, 0.18518519, 0.13888889}, 357 {0.16835017, 0.13468013, 0.1010101}, 358 }); 359 360 ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, 361 xla::ErrorSpec(1e-2, 1e-2)); 362 } 363 364 } // namespace 365 } // namespace tensorflow 366