1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 #include <vector> 18 19 #include "tensorflow/compiler/xla/array2d.h" 20 #include "tensorflow/compiler/xla/array3d.h" 21 #include "tensorflow/compiler/xla/client/computation_builder.h" 22 #include "tensorflow/compiler/xla/client/local_client.h" 23 #include "tensorflow/compiler/xla/primitive_util.h" 24 #include "tensorflow/compiler/xla/reference_util.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 27 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 28 #include "tensorflow/compiler/xla/tests/test_macros.h" 29 #include "tensorflow/compiler/xla/tests/test_utils.h" 30 #include "tensorflow/core/platform/test.h" 31 #include "tensorflow/core/platform/types.h" 32 #include "tensorflow/core/util/command_line_flags.h" 33 34 namespace xla { 35 namespace { 36 37 // TODO(b/34468543): use GUnit typed tests when we can do all tests on all 38 // backends. 39 class DotOperationTest : public ClientLibraryTestBase { 40 public: 41 ErrorSpec error_spec_{0.0001, 1e-5}; 42 43 protected: 44 template <typename Element> 45 void TestOneElementVectorDot(); 46 template <typename Element> 47 void TestVectorDot(); 48 template <typename Element> 49 void TestSquareMatrixDot(bool lhs_row_major = false, 50 bool rhs_row_major = false); 51 template <typename Element> 52 void TestNonsquareMatrixDot(bool lhs_row_major = false, 53 bool rhs_row_major = false); 54 }; 55 56 XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { 57 ComputationBuilder builder(client_, TestName()); 58 auto lhs = builder.ConstantR1<float>({}); 59 auto rhs = builder.ConstantR1<float>({}); 60 auto result = builder.Dot(lhs, rhs); 61 62 ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_); 63 } 64 65 XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) { 66 ComputationBuilder builder(client_, TestName()); 67 auto lhs = builder.ConstantR2<float>({{3.0, 4.0}}); 68 auto rhs = builder.ConstantR1<float>({3.0, 4.0}); 69 auto result = builder.Dot(lhs, rhs); 70 71 ComputeAndCompareR1<float>(&builder, {25.0}, {}, error_spec_); 72 } 73 74 template <typename Element> 75 void DotOperationTest::TestOneElementVectorDot() { 76 ComputationBuilder builder(client_, TestName()); 77 auto lhs = builder.ConstantR1<Element>({2.0}); 78 auto rhs = builder.ConstantR1<Element>({3.0}); 79 auto result = builder.Dot(lhs, rhs); 80 81 ComputeAndCompareR0<Element>(&builder, 6.0, {}, error_spec_); 82 } 83 84 XLA_TEST_F(DotOperationTest, OneElementVectorDotF32) { 85 TestOneElementVectorDot<float>(); 86 } 87 88 XLA_TEST_F(DotOperationTest, OneElementVectorDotF64) { 89 TestOneElementVectorDot<double>(); 90 } 91 92 template <typename Element> 93 void DotOperationTest::TestVectorDot() { 94 ComputationBuilder builder(client_, TestName()); 95 auto lhs = builder.ConstantR1<Element>({1.0, 2.5, 42.0}); 96 auto rhs = builder.ConstantR1<Element>({11.0, -1.0, 0.5}); 97 auto result = builder.Dot(lhs, rhs); 98 99 ComputeAndCompareR0<Element>(&builder, 29.5, {}, error_spec_); 100 } 101 102 XLA_TEST_F(DotOperationTest, VectorDotF32) { TestVectorDot<float>(); } 103 104 XLA_TEST_F(DotOperationTest, VectorDotF64) { TestVectorDot<double>(); } 105 106 namespace { 107 108 std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) { 109 return {row_major ? 1 : 0, row_major ? 0 : 1}; 110 } 111 112 } // namespace 113 114 XLA_TEST_F(DotOperationTest, Dot_0x2_2x0) { 115 ComputationBuilder builder(client_, TestName()); 116 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2)); 117 auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0)); 118 auto result = builder.Dot(lhs, rhs); 119 120 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_); 121 } 122 123 XLA_TEST_F(DotOperationTest, Dot_0x2_2x3) { 124 ComputationBuilder builder(client_, TestName()); 125 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2)); 126 auto rhs = builder.ConstantR2<float>({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}}); 127 auto result = builder.Dot(lhs, rhs); 128 129 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 3), {}, error_spec_); 130 } 131 132 XLA_TEST_F(DotOperationTest, Dot_3x2_2x0) { 133 ComputationBuilder builder(client_, TestName()); 134 auto lhs = 135 builder.ConstantR2<float>({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}}); 136 auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0)); 137 auto result = builder.Dot(lhs, rhs); 138 139 ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0), {}, error_spec_); 140 } 141 142 XLA_TEST_F(DotOperationTest, Dot_2x0_0x2) { 143 ComputationBuilder builder(client_, TestName()); 144 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0)); 145 auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2)); 146 auto result = builder.Dot(lhs, rhs); 147 148 ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 2, 0.0f), {}, 149 error_spec_); 150 } 151 152 XLA_TEST_F(DotOperationTest, FusedDot) { 153 ComputationBuilder builder(client_, TestName()); 154 auto param0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 4}), "arg0"); 155 auto param1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4, 1}), "arg1"); 156 auto exp0 = builder.Exp(param0); 157 auto result = builder.Dot(exp0, param1); 158 159 auto lhs_handle = client_ 160 ->TransferToServer(*Literal::CreateR2<float>( 161 {{1.0, 2.0, 3.0, 4.0}, {-1.0, -2.0, -3.0, -4.0}})) 162 .ConsumeValueOrDie(); 163 auto rhs_handle = client_ 164 ->TransferToServer(*Literal::CreateR2<float>( 165 {{1.0}, {2.0}, {3.0}, {4.0}})) 166 .ConsumeValueOrDie(); 167 168 ComputeAndCompareR2<float>( 169 &builder, Array2D<float>({{296.14560492846033}, {0.8611737683031964}}), 170 {lhs_handle.get(), rhs_handle.get()}, error_spec_); 171 } 172 173 template <typename Element> 174 void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, 175 bool rhs_row_major) { 176 auto lhs_handle = 177 client_ 178 ->TransferToServer(*Literal::CreateR2WithLayout<Element>( 179 {{1.0, 2.0}, {3.0, -4.0}}, 180 LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) 181 .ConsumeValueOrDie(); 182 auto rhs_handle = 183 client_ 184 ->TransferToServer(*Literal::CreateR2WithLayout<Element>( 185 {{1.0, 6.0}, {7.0, -4.0}}, 186 LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) 187 .ConsumeValueOrDie(); 188 189 ComputationBuilder builder(client_, TestName()); 190 auto prim_type = primitive_util::NativeToPrimitiveType<Element>(); 191 auto result = builder.Dot( 192 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), 193 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); 194 195 Array2D<Element> expected({{15.0, -2.0}, {-25.0, 34.0}}); 196 ComputeAndCompareR2<Element>( 197 &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); 198 } 199 200 struct DotTestParam { 201 int m; 202 int k; 203 int n; 204 bool dot_lhs_row_major; 205 bool dot_rhs_row_major; 206 bool has_addend; 207 bool addend_row_major; 208 }; 209 210 string PrintDotTestParam( 211 const ::testing::TestParamInfo<DotTestParam>& test_param) { 212 const DotTestParam& param = test_param.param; 213 if (param.has_addend) { 214 return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, 215 "_MajorToMinor", 216 param.dot_lhs_row_major ? "T" : "F", 217 param.dot_rhs_row_major ? "T" : "F", 218 param.addend_row_major ? "T" : "F"); 219 } else { 220 return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, 221 "_MajorToMinor", 222 param.dot_lhs_row_major ? "T" : "F", 223 param.dot_rhs_row_major ? "T" : "F"); 224 } 225 } 226 227 class ParametricDotTest : public DotOperationTest, 228 public ::testing::WithParamInterface<DotTestParam> {}; 229 230 XLA_TEST_P(ParametricDotTest, TestF32) { 231 DotTestParam param = GetParam(); 232 233 std::unique_ptr<Array2D<float>> dot_lhs_data = 234 MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); 235 std::unique_ptr<Literal> dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( 236 *dot_lhs_data, LayoutUtil::MakeLayout( 237 MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); 238 std::unique_ptr<GlobalData> dot_lhs_handle = 239 client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); 240 241 std::unique_ptr<Array2D<float>> dot_rhs_data = 242 MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); 243 std::unique_ptr<Literal> dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout( 244 *dot_rhs_data, LayoutUtil::MakeLayout( 245 MinorToMajorForIsRowMajor(param.dot_rhs_row_major))); 246 std::unique_ptr<GlobalData> dot_rhs_handle = 247 client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); 248 249 std::unique_ptr<Array2D<float>> addend_data; 250 std::unique_ptr<Literal> addend_lit; 251 std::unique_ptr<GlobalData> addend_handle; 252 253 if (param.has_addend) { 254 addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); 255 addend_lit = Literal::CreateR2FromArray2DWithLayout( 256 *addend_data, LayoutUtil::MakeLayout( 257 MinorToMajorForIsRowMajor(param.addend_row_major))); 258 addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); 259 } 260 261 ComputationBuilder builder(client_, TestName()); 262 auto prim_type = primitive_util::NativeToPrimitiveType<float>(); 263 auto result = builder.Dot( 264 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}), 265 "dot_lhs"), 266 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}), 267 "dot_rhs")); 268 269 if (param.has_addend) { 270 result = builder.Add( 271 result, 272 builder.Parameter( 273 2, ShapeUtil::MakeShape(prim_type, {param.m, param.n}), "addend")); 274 } 275 276 std::unique_ptr<Array2D<float>> expected; 277 if (param.has_addend) { 278 expected = ReferenceUtil::ApplyElementwise2D( 279 std::plus<float>(), 280 *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data), 281 *addend_data); 282 } else { 283 expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data); 284 } 285 286 std::vector<GlobalData*> args = {dot_lhs_handle.get(), dot_rhs_handle.get()}; 287 if (param.has_addend) { 288 args.push_back(addend_handle.get()); 289 } 290 291 ComputeAndCompareR2<float>(&builder, *expected, args, ErrorSpec(0.3, 3e-3)); 292 } 293 294 std::vector<DotTestParam> CreateDotTestParameters() { 295 std::vector<DotTestParam> params; 296 297 auto add_matrix_matrix_dot_test = [&](int m, int k, int n) { 298 for (bool lhs_row_major : {true, false}) { 299 for (bool rhs_row_major : {true, false}) { 300 params.push_back({/*m=*/m, /*k=*/k, /*n=*/n, 301 /*dot_lhs_row_major=*/lhs_row_major, 302 /*dot_rhs_row_major=*/rhs_row_major, 303 /*has_addend=*/false, /*addend_row_major=*/true}); 304 } 305 } 306 }; 307 308 auto add_matrix_vector_dot_test = [&](int k, int n) { 309 for (bool has_addend : {false, true}) { 310 params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, 311 /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, 312 /*has_addend=*/has_addend, /*addend_row_major=*/true}); 313 if (n != 1) { 314 params.push_back( 315 {/*m=*/n, /*k=*/k, /*n=*/1, 316 /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, 317 /*has_addend=*/has_addend, /*addend_row_major=*/true}); 318 } 319 } 320 }; 321 322 add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); 323 add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); 324 add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); 325 326 add_matrix_vector_dot_test(/*k=*/8, /*n=*/8); 327 add_matrix_vector_dot_test(/*k=*/130, /*n=*/8); 328 add_matrix_vector_dot_test(/*k=*/8, /*n=*/130); 329 add_matrix_vector_dot_test(/*k=*/290, /*n=*/130); 330 add_matrix_vector_dot_test(/*k=*/1, /*n=*/1); 331 add_matrix_vector_dot_test(/*k=*/1, /*n=*/16); 332 add_matrix_vector_dot_test(/*k=*/3, /*n=*/16); 333 add_matrix_vector_dot_test(/*k=*/3, /*n=*/3); 334 add_matrix_vector_dot_test(/*k=*/29, /*n=*/29); 335 add_matrix_vector_dot_test(/*k=*/8, /*n=*/2); 336 add_matrix_vector_dot_test(/*k=*/2, /*n=*/8); 337 add_matrix_vector_dot_test(/*k=*/259, /*n=*/258); 338 339 return params; 340 } 341 342 INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, 343 ::testing::ValuesIn(CreateDotTestParameters()), 344 PrintDotTestParam); 345 346 XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { 347 TestSquareMatrixDot<float>(false, false); 348 } 349 350 XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { 351 TestSquareMatrixDot<float>(false, true); 352 } 353 354 XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { 355 TestSquareMatrixDot<float>(true, false); 356 } 357 358 XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { 359 TestSquareMatrixDot<float>(true, true); 360 } 361 362 XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) { 363 TestSquareMatrixDot<complex64>(false, false); 364 } 365 366 XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) { 367 TestSquareMatrixDot<complex64>(false, true); 368 } 369 370 XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) { 371 TestSquareMatrixDot<complex64>(true, false); 372 } 373 374 XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) { 375 TestSquareMatrixDot<complex64>(true, true); 376 } 377 378 XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { 379 TestSquareMatrixDot<double>(); 380 } 381 382 template <typename Element> 383 void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, 384 bool rhs_row_major) { 385 auto lhs_handle = 386 client_ 387 ->TransferToServer(*Literal::CreateR2WithLayout<Element>( 388 {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, 389 LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) 390 .ConsumeValueOrDie(); 391 auto rhs_handle = 392 client_ 393 ->TransferToServer(*Literal::CreateR2WithLayout<Element>( 394 {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, 395 LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) 396 .ConsumeValueOrDie(); 397 398 ComputationBuilder builder(client_, TestName()); 399 auto prim_type = primitive_util::NativeToPrimitiveType<Element>(); 400 auto result = builder.Dot( 401 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), 402 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); 403 404 Array2D<Element> expected({{26.0, 0.0}, {-12.0, 10.0}}); 405 406 ComputeAndCompareR2<Element>( 407 &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); 408 } 409 410 XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) { 411 TestNonsquareMatrixDot<float>(false, false); 412 } 413 414 XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) { 415 TestNonsquareMatrixDot<float>(false, true); 416 } 417 418 XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { 419 TestNonsquareMatrixDot<float>(true, false); 420 } 421 422 XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { 423 TestNonsquareMatrixDot<float>(true, true); 424 } 425 426 XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { 427 TestNonsquareMatrixDot<double>(); 428 } 429 430 XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) { 431 TestNonsquareMatrixDot<complex64>(false, false); 432 } 433 434 XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) { 435 TestNonsquareMatrixDot<complex64>(false, true); 436 } 437 438 XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) { 439 TestNonsquareMatrixDot<complex64>(true, false); 440 } 441 442 XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) { 443 TestNonsquareMatrixDot<complex64>(true, true); 444 } 445 446 XLA_TEST_F(DotOperationTest, MatrixVectorC64) { 447 auto lhs_handle = 448 client_ 449 ->TransferToServer(*Literal::CreateR2WithLayout<complex64>( 450 {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) 451 .ConsumeValueOrDie(); 452 auto rhs_handle = 453 client_ 454 ->TransferToServer(*Literal::CreateR2WithLayout<complex64>( 455 {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, 456 LayoutUtil::MakeLayout({1, 0}))) 457 .ConsumeValueOrDie(); 458 459 ComputationBuilder builder(client_, TestName()); 460 auto prim_type = primitive_util::NativeToPrimitiveType<complex64>(); 461 auto result = builder.Dot( 462 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), 463 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); 464 465 Array2D<complex64> expected({{30.0, -2.0}}); 466 467 ComputeAndCompareR2<complex64>( 468 &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); 469 } 470 471 XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { 472 ComputationBuilder builder(client_, TestName()); 473 auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 474 auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}}); 475 auto matrix12 = builder.Dot(matrix1, matrix2); 476 auto matrix21 = builder.Dot(matrix2, matrix1); 477 builder.Add(matrix12, matrix21); 478 479 Array2D<float> expected({{42.0, 56.0}, {74.0, 96.0}}); 480 ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); 481 } 482 483 // Regression test for b/32055648. The root of the graph is a kFusion of 4 484 // bitcasts. Although bitcasts don't map to thunks, the root should still be 485 // sync-dependent on bitcasts' operands. 486 XLA_TEST_F(DotOperationTest, BatchMatMul) { 487 ComputationBuilder builder(client_, TestName()); 488 auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x"); 489 auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y"); 490 491 auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); 492 auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); 493 494 // Slice batches into individual matrices and multiply them. 495 std::vector<xla::ComputationDataHandle> out_slices; 496 for (int i = 0; i < 4; ++i) { 497 // Slice off individual matrices and reshape to 2D tensors. 498 auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); 499 x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2}); 500 auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); 501 y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2}); 502 503 auto out = builder.Dot(x_slice, y_slice); 504 out = builder.Reshape(out, {0, 1}, {1, 2, 2}); 505 out_slices.push_back(out); 506 } 507 auto out_flat = builder.ConcatInDim(out_slices, 0); 508 builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); 509 510 auto x_data = client_ 511 ->TransferToServer(*Literal::CreateR4<float>( 512 {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}}, 513 {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}})) 514 .ConsumeValueOrDie(); 515 auto y_data = client_ 516 ->TransferToServer(*Literal::CreateR4<float>( 517 {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, 518 {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}})) 519 .ConsumeValueOrDie(); 520 521 ComputeAndCompareR4<float>( 522 &builder, 523 /*expected=*/ 524 {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}}, 525 {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}}, 526 {x_data.get(), y_data.get()}, error_spec_); 527 } 528 529 XLA_TEST_F(DotOperationTest, GeneralMatMul) { 530 ComputationBuilder builder(client_, TestName()); 531 auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x"); 532 auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y"); 533 534 DotDimensionNumbers dnums; 535 dnums.add_lhs_contracting_dimensions(2); 536 dnums.add_rhs_contracting_dimensions(1); 537 dnums.add_lhs_batch_dimensions(0); 538 dnums.add_rhs_batch_dimensions(0); 539 540 auto out = builder.DotGeneral(x, y, dnums); 541 542 auto x_data = client_ 543 ->TransferToServer(*Literal::CreateR3<float>( 544 {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}})) 545 .ConsumeValueOrDie(); 546 547 auto y_data = client_ 548 ->TransferToServer(*Literal::CreateR3<float>( 549 {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}})) 550 .ConsumeValueOrDie(); 551 552 ComputeAndCompareR3<float>( 553 &builder, 554 /*expected=*/ 555 {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}, 556 {x_data.get(), y_data.get()}, error_spec_); 557 } 558 559 TEST_F(DotOperationTest, TransposeFolding) { 560 for (bool transpose_lhs : {false, true}) { 561 for (bool transpose_rhs : {false, true}) { 562 for (bool row_major : {false, true}) { 563 std::unique_ptr<Array2D<float>> lhs( 564 new Array2D<float>({{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}})); 565 std::unique_ptr<Array2D<float>> rhs( 566 new Array2D<float>({{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}})); 567 568 if (transpose_lhs) { 569 lhs = ReferenceUtil::TransposeArray2D(*lhs); 570 } 571 if (transpose_rhs) { 572 rhs = ReferenceUtil::TransposeArray2D(*rhs); 573 } 574 auto lhs_handle = 575 client_ 576 ->TransferToServer( 577 *Literal::CreateR2FromArray2DWithLayout<float>( 578 *lhs, LayoutUtil::MakeLayout( 579 MinorToMajorForIsRowMajor(row_major)))) 580 .ConsumeValueOrDie(); 581 auto rhs_handle = 582 client_ 583 ->TransferToServer( 584 *Literal::CreateR2FromArray2DWithLayout<float>( 585 *rhs, LayoutUtil::MakeLayout( 586 MinorToMajorForIsRowMajor(row_major)))) 587 .ConsumeValueOrDie(); 588 589 ComputationBuilder builder(client_, TestName()); 590 auto prim_type = primitive_util::NativeToPrimitiveType<float>(); 591 auto lhs_arg = builder.Parameter( 592 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), 593 "lhs"); 594 auto rhs_arg = builder.Parameter( 595 1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}), 596 "rhs"); 597 if (transpose_lhs) { 598 lhs_arg = builder.Transpose(lhs_arg, {1, 0}); 599 } 600 if (transpose_rhs) { 601 rhs_arg = builder.Transpose(rhs_arg, {1, 0}); 602 } 603 auto result = builder.Dot(lhs_arg, rhs_arg); 604 605 Array2D<float> expected({{26.0, 0.0}, {-12.0, 10.0}}); 606 VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " 607 << transpose_rhs << " " << row_major; 608 ComputeAndCompareR2<float>(&builder, expected, 609 {lhs_handle.get(), rhs_handle.get()}, 610 error_spec_); 611 } 612 } 613 } 614 } 615 616 TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { 617 auto prim_type = primitive_util::NativeToPrimitiveType<float>(); 618 619 std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( 620 {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); 621 622 ComputationBuilder builder(client_, TestName()); 623 auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); 624 auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), 625 "rhs_arg_0"); 626 auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), 627 "rhs_arg_1"); 628 auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}), 629 "rhs_arg_2"); 630 auto result = builder.Dot( 631 lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); 632 633 std::unique_ptr<Array2D<float>> arg_0_value_array( 634 new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}})); 635 std::unique_ptr<Array2D<float>> arg_1_value_array( 636 new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})); 637 std::unique_ptr<Array2D<float>> arg_2_value_array( 638 new Array2D<float>({{1.0, 2.0}})); 639 640 TF_ASSERT_OK_AND_ASSIGN( 641 auto arg_0_value, 642 client_->TransferToServer( 643 *Literal::CreateR2FromArray2D<float>(*arg_0_value_array))); 644 TF_ASSERT_OK_AND_ASSIGN( 645 auto arg_1_value, 646 client_->TransferToServer( 647 *Literal::CreateR2FromArray2D<float>(*arg_1_value_array))); 648 TF_ASSERT_OK_AND_ASSIGN( 649 auto arg_2_value, 650 client_->TransferToServer( 651 *Literal::CreateR2FromArray2D<float>(*arg_2_value_array))); 652 653 Array2D<float> expected({{53.0, 74.0}, {45.0, 66.0}}); 654 ComputeAndCompareR2<float>( 655 &builder, expected, 656 {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); 657 } 658 659 TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) { 660 auto prim_type = primitive_util::NativeToPrimitiveType<float>(); 661 662 std::unique_ptr<Array2D<float>> constant_rhs_array( 663 new Array2D<float>({{1.0, 2.0}, 664 {3.0, 4.0}, 665 {5.0, 6.0}, 666 {6.0, 5.0}, 667 {4.0, 3.0}, 668 {2.0, 1.0}})); 669 670 ComputationBuilder builder(client_, TestName()); 671 auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); 672 auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), 673 "lhs_arg_0"); 674 auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}), 675 "lhs_arg_1"); 676 auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}), 677 "lhs_arg_2"); 678 auto result = builder.Dot( 679 builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); 680 681 std::unique_ptr<Array2D<float>> arg_0_value_array( 682 new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}})); 683 std::unique_ptr<Array2D<float>> arg_1_value_array( 684 new Array2D<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); 685 std::unique_ptr<Array2D<float>> arg_2_value_array( 686 new Array2D<float>({{1.0}, {2.0}})); 687 688 TF_ASSERT_OK_AND_ASSIGN( 689 auto arg_0_value, 690 client_->TransferToServer( 691 *Literal::CreateR2FromArray2D<float>(*arg_0_value_array))); 692 TF_ASSERT_OK_AND_ASSIGN( 693 auto arg_1_value, 694 client_->TransferToServer( 695 *Literal::CreateR2FromArray2D<float>(*arg_1_value_array))); 696 TF_ASSERT_OK_AND_ASSIGN( 697 auto arg_2_value, 698 client_->TransferToServer( 699 *Literal::CreateR2FromArray2D<float>(*arg_2_value_array))); 700 701 Array2D<float> expected({{38.0, 36.0}, {93.0, 91.0}}); 702 ComputeAndCompareR2<float>( 703 &builder, expected, 704 {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); 705 } 706 } // namespace 707 } // namespace xla 708