/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace {

// TODO(b/34468543): use GUnit typed tests when we can do all tests on all
// backends.
class DotOperationTest : public ClientLibraryTestBase {
 public:
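  // Absolute and relative error tolerances applied when comparing results.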
  ErrorSpec error_spec_{0.0001, 1e-5};

 protected:
  template <typename Element>
  void TestOneElementVectorDot();
  template <typename Element>
  void TestVectorDot();
  template <typename Element>
  void TestSquareMatrixDot(bool lhs_row_major = false,
                           bool rhs_row_major = false);
  template <typename Element>
  void TestNonsquareMatrixDot(bool lhs_row_major = false,
                              bool rhs_row_major = false);
};

XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR1<float>({});
  auto rhs = builder.ConstantR1<float>({});
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR2<float>({{3.0, 4.0}});
  auto rhs = builder.ConstantR1<float>({3.0, 4.0});
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {25.0}, {}, error_spec_);
}

template <typename Element>
void DotOperationTest::TestOneElementVectorDot() {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR1<Element>({2.0});
  auto rhs = builder.ConstantR1<Element>({3.0});
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR0<Element>(&builder, 6.0, {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, OneElementVectorDotF32) {
  TestOneElementVectorDot<float>();
}

XLA_TEST_F(DotOperationTest, OneElementVectorDotF64) {
  TestOneElementVectorDot<double>();
}

template <typename Element>
void DotOperationTest::TestVectorDot() {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR1<Element>({1.0, 2.5, 42.0});
  auto rhs = builder.ConstantR1<Element>({11.0, -1.0, 0.5});
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR0<Element>(&builder, 29.5, {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, VectorDotF32) { TestVectorDot<float>(); }

XLA_TEST_F(DotOperationTest, VectorDotF64) { TestVectorDot<double>(); }

namespace {

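// Returns a 2D minor-to-major order: {1, 0} when row-major (dimension 1
// varies fastest), {0, 1} when column-major.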
std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) {
  return {row_major ? 1 : 0, row_major ? 0 : 1};
}

}  // namespace

XLA_TEST_F(DotOperationTest, Dot_0x2_2x0) {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
  auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, Dot_0x2_2x3) {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
  auto rhs = builder.ConstantR2<float>({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}});
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 3), {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, Dot_3x2_2x0) {
  ComputationBuilder builder(client_, TestName());
  auto lhs =
      builder.ConstantR2<float>({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}});
  auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0), {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, Dot_2x0_0x2) {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
  auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 2, 0.0f), {},
                             error_spec_);
}

XLA_TEST_F(DotOperationTest, FusedDot) {
  ComputationBuilder builder(client_, TestName());
  auto param0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 4}), "arg0");
  auto param1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4, 1}), "arg1");
  auto exp0 = builder.Exp(param0);
  auto result = builder.Dot(exp0, param1);

  auto lhs_handle = client_
                        ->TransferToServer(*Literal::CreateR2<float>(
                            {{1.0, 2.0, 3.0, 4.0}, {-1.0, -2.0, -3.0, -4.0}}))
                        .ConsumeValueOrDie();
  auto rhs_handle = client_
                        ->TransferToServer(*Literal::CreateR2<float>(
                            {{1.0}, {2.0}, {3.0}, {4.0}}))
                        .ConsumeValueOrDie();

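  // Expected result is dot(exp(lhs), rhs): row 0 is e^1 + 2e^2 + 3e^3 + 4e^4
  // ~= 296.1456; row 1 is e^-1 + 2e^-2 + 3e^-3 + 4e^-4 ~= 0.8612.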
  ComputeAndCompareR2<float>(
      &builder, Array2D<float>({{296.14560492846033}, {0.8611737683031964}}),
      {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}

template <typename Element>
void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
                                           bool rhs_row_major) {
  auto lhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
              {{1.0, 2.0}, {3.0, -4.0}},
              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
          .ConsumeValueOrDie();
  auto rhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
              {{1.0, 6.0}, {7.0, -4.0}},
              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
          .ConsumeValueOrDie();

  ComputationBuilder builder(client_, TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<Element>();
  auto result = builder.Dot(
      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));

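  // [[1, 2], [3, -4]] * [[1, 6], [7, -4]] = [[15, -2], [-25, 34]].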
  Array2D<Element> expected({{15.0, -2.0}, {-25.0, 34.0}});
  ComputeAndCompareR2<Element>(
      &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}

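// Parameters for an m x k x n dot test: the operand layouts, plus an optional
// m x n addend that is added elementwise to the dot result.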
struct DotTestParam {
  int m;
  int k;
  int n;
  bool dot_lhs_row_major;
  bool dot_rhs_row_major;
  bool has_addend;
  bool addend_row_major;
};

string PrintDotTestParam(
    const ::testing::TestParamInfo<DotTestParam>& test_param) {
  const DotTestParam& param = test_param.param;
  if (param.has_addend) {
    return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n,
                                       "_MajorToMinor",
                                       param.dot_lhs_row_major ? "T" : "F",
                                       param.dot_rhs_row_major ? "T" : "F",
                                       param.addend_row_major ? "T" : "F");
  } else {
    return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n,
                                       "_MajorToMinor",
                                       param.dot_lhs_row_major ? "T" : "F",
                                       param.dot_rhs_row_major ? "T" : "F");
  }
}

class ParametricDotTest : public DotOperationTest,
                          public ::testing::WithParamInterface<DotTestParam> {};

XLA_TEST_P(ParametricDotTest, TestF32) {
  DotTestParam param = GetParam();

  std::unique_ptr<Array2D<float>> dot_lhs_data =
      MakeLinspaceArray2D(0.0, 1.0, param.m, param.k);
  std::unique_ptr<Literal> dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout(
      *dot_lhs_data, LayoutUtil::MakeLayout(
                         MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
  std::unique_ptr<GlobalData> dot_lhs_handle =
      client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();

  std::unique_ptr<Array2D<float>> dot_rhs_data =
      MakeLinspaceArray2D(0.0, 1.0, param.k, param.n);
  std::unique_ptr<Literal> dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout(
      *dot_rhs_data, LayoutUtil::MakeLayout(
                         MinorToMajorForIsRowMajor(param.dot_rhs_row_major)));
  std::unique_ptr<GlobalData> dot_rhs_handle =
      client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();

  std::unique_ptr<Array2D<float>> addend_data;
  std::unique_ptr<Literal> addend_lit;
  std::unique_ptr<GlobalData> addend_handle;

  if (param.has_addend) {
    addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n);
    addend_lit = Literal::CreateR2FromArray2DWithLayout(
        *addend_data, LayoutUtil::MakeLayout(
                          MinorToMajorForIsRowMajor(param.addend_row_major)));
    addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
  }

  ComputationBuilder builder(client_, TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<float>();
  auto result = builder.Dot(
      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}),
                        "dot_lhs"),
      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}),
                        "dot_rhs"));

  if (param.has_addend) {
    result = builder.Add(
        result,
        builder.Parameter(
            2, ShapeUtil::MakeShape(prim_type, {param.m, param.n}), "addend"));
  }

  std::unique_ptr<Array2D<float>> expected;
  if (param.has_addend) {
    expected = ReferenceUtil::ApplyElementwise2D(
        std::plus<float>(),
        *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data),
        *addend_data);
  } else {
    expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data);
  }

  std::vector<GlobalData*> args = {dot_lhs_handle.get(), dot_rhs_handle.get()};
  if (param.has_addend) {
    args.push_back(addend_handle.get());
  }

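  // Linspace operands with large contracted dimensions accumulate rounding
  // error, so compare against a looser tolerance than error_spec_.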
  ComputeAndCompareR2<float>(&builder, *expected, args, ErrorSpec(0.3, 3e-3));
}

std::vector<DotTestParam> CreateDotTestParameters() {
  std::vector<DotTestParam> params;

  auto add_matrix_matrix_dot_test = [&](int m, int k, int n) {
    for (bool lhs_row_major : {true, false}) {
      for (bool rhs_row_major : {true, false}) {
        params.push_back({/*m=*/m, /*k=*/k, /*n=*/n,
                          /*dot_lhs_row_major=*/lhs_row_major,
                          /*dot_rhs_row_major=*/rhs_row_major,
                          /*has_addend=*/false, /*addend_row_major=*/true});
      }
    }
  };

  auto add_matrix_vector_dot_test = [&](int k, int n) {
    for (bool has_addend : {false, true}) {
      params.push_back({/*m=*/1, /*k=*/k, /*n=*/n,
                        /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true,
                        /*has_addend=*/has_addend, /*addend_row_major=*/true});
      if (n != 1) {
        params.push_back(
            {/*m=*/n, /*k=*/k, /*n=*/1,
             /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true,
             /*has_addend=*/has_addend, /*addend_row_major=*/true});
      }
    }
  };

  add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
  add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
  add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);

  add_matrix_vector_dot_test(/*k=*/8, /*n=*/8);
  add_matrix_vector_dot_test(/*k=*/130, /*n=*/8);
  add_matrix_vector_dot_test(/*k=*/8, /*n=*/130);
  add_matrix_vector_dot_test(/*k=*/290, /*n=*/130);
  add_matrix_vector_dot_test(/*k=*/1, /*n=*/1);
  add_matrix_vector_dot_test(/*k=*/1, /*n=*/16);
  add_matrix_vector_dot_test(/*k=*/3, /*n=*/16);
  add_matrix_vector_dot_test(/*k=*/3, /*n=*/3);
  add_matrix_vector_dot_test(/*k=*/29, /*n=*/29);
  add_matrix_vector_dot_test(/*k=*/8, /*n=*/2);
  add_matrix_vector_dot_test(/*k=*/2, /*n=*/8);
  add_matrix_vector_dot_test(/*k=*/259, /*n=*/258);

  return params;
}

INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest,
                        ::testing::ValuesIn(CreateDotTestParameters()),
                        PrintDotTestParam);

XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) {
  TestSquareMatrixDot<float>(false, false);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) {
  TestSquareMatrixDot<float>(false, true);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) {
  TestSquareMatrixDot<float>(true, false);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) {
  TestSquareMatrixDot<float>(true, true);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) {
  TestSquareMatrixDot<complex64>(false, false);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) {
  TestSquareMatrixDot<complex64>(false, true);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) {
  TestSquareMatrixDot<complex64>(true, false);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) {
  TestSquareMatrixDot<complex64>(true, true);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) {
  TestSquareMatrixDot<double>();
}

template <typename Element>
void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
                                              bool rhs_row_major) {
  auto lhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
              {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
          .ConsumeValueOrDie();
  auto rhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
              {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
          .ConsumeValueOrDie();

  ComputationBuilder builder(client_, TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<Element>();
  auto result = builder.Dot(
      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));

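  // [[1, 2, 3], [3, -4, -1]] * [[1, 6], [2, 3], [7, -4]] =
  // [[26, 0], [-12, 10]].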
  Array2D<Element> expected({{26.0, 0.0}, {-12.0, 10.0}});

  ComputeAndCompareR2<Element>(
      &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) {
  TestNonsquareMatrixDot<float>(false, false);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) {
  TestNonsquareMatrixDot<float>(false, true);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) {
  TestNonsquareMatrixDot<float>(true, false);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
  TestNonsquareMatrixDot<float>(true, true);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) {
  TestNonsquareMatrixDot<double>();
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) {
  TestNonsquareMatrixDot<complex64>(false, false);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) {
  TestNonsquareMatrixDot<complex64>(false, true);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) {
  TestNonsquareMatrixDot<complex64>(true, false);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) {
  TestNonsquareMatrixDot<complex64>(true, true);
}

XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
  auto lhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
              {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
          .ConsumeValueOrDie();
  auto rhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
              {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
              LayoutUtil::MakeLayout({1, 0})))
          .ConsumeValueOrDie();

  ComputationBuilder builder(client_, TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
  auto result = builder.Dot(
      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));

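  // [[1, 2, 3, -4]] * [[1, 1], [2, 2], [3, 3], [-4, 4]] = [[30, -2]].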
  Array2D<complex64> expected({{30.0, -2.0}});

  ComputeAndCompareR2<complex64>(
      &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}

XLA_TEST_F(DotOperationTest, ConcurrentMatMul) {
  ComputationBuilder builder(client_, TestName());
  auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}});
  auto matrix12 = builder.Dot(matrix1, matrix2);
  auto matrix21 = builder.Dot(matrix2, matrix1);
  builder.Add(matrix12, matrix21);

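  // matrix1 * matrix2 = [[19, 22], [43, 50]] and matrix2 * matrix1 =
  // [[23, 34], [31, 46]]; their sum is [[42, 56], [74, 96]].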
  Array2D<float> expected({{42.0, 56.0}, {74.0, 96.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// Regression test for b/32055648. The root of the graph is a kFusion of 4
// bitcasts. Although bitcasts don't map to thunks, the root should still be
// sync-dependent on bitcasts' operands.
XLA_TEST_F(DotOperationTest, BatchMatMul) {
  ComputationBuilder builder(client_, TestName());
  auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x");
  auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y");

  auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
  auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2});

  // Slice batches into individual matrices and multiply them.
  std::vector<xla::ComputationDataHandle> out_slices;
  for (int i = 0; i < 4; ++i) {
    // Slice off individual matrices and reshape to 2D tensors.
    auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
    x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2});
    auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
    y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2});

    auto out = builder.Dot(x_slice, y_slice);
    out = builder.Reshape(out, {0, 1}, {1, 2, 2});
    out_slices.push_back(out);
  }
  auto out_flat = builder.ConcatInDim(out_slices, 0);
  builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});

  auto x_data = client_
                    ->TransferToServer(*Literal::CreateR4<float>(
                        {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}},
                         {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}}))
                    .ConsumeValueOrDie();
  auto y_data = client_
                    ->TransferToServer(*Literal::CreateR4<float>(
                        {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}},
                         {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}}))
                    .ConsumeValueOrDie();

  ComputeAndCompareR4<float>(
      &builder,
      /*expected=*/
      {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}},
       {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}},
      {x_data.get(), y_data.get()}, error_spec_);
}

XLA_TEST_F(DotOperationTest, GeneralMatMul) {
  ComputationBuilder builder(client_, TestName());
  auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x");
  auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y");

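  // Contract dimension 2 of x with dimension 1 of y, batching over dimension
  // 0 of both: a batched matmul mapping [2,2,2] x [2,2,2] -> [2,2,2]. Each
  // batch of y below is the identity, so the output equals x.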
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);

  auto out = builder.DotGeneral(x, y, dnums);

  auto x_data = client_
                    ->TransferToServer(*Literal::CreateR3<float>(
                        {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}))
                    .ConsumeValueOrDie();

  auto y_data = client_
                    ->TransferToServer(*Literal::CreateR3<float>(
                        {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}}))
                    .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(
      &builder,
      /*expected=*/
      {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}},
      {x_data.get(), y_data.get()}, error_spec_);
}

TEST_F(DotOperationTest, TransposeFolding) {
  for (bool transpose_lhs : {false, true}) {
    for (bool transpose_rhs : {false, true}) {
      for (bool row_major : {false, true}) {
        std::unique_ptr<Array2D<float>> lhs(
            new Array2D<float>({{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}));
        std::unique_ptr<Array2D<float>> rhs(
            new Array2D<float>({{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}));

        if (transpose_lhs) {
          lhs = ReferenceUtil::TransposeArray2D(*lhs);
        }
        if (transpose_rhs) {
          rhs = ReferenceUtil::TransposeArray2D(*rhs);
        }
        auto lhs_handle =
            client_
                ->TransferToServer(
                    *Literal::CreateR2FromArray2DWithLayout<float>(
                        *lhs, LayoutUtil::MakeLayout(
                                  MinorToMajorForIsRowMajor(row_major))))
                .ConsumeValueOrDie();
        auto rhs_handle =
            client_
                ->TransferToServer(
                    *Literal::CreateR2FromArray2DWithLayout<float>(
                        *rhs, LayoutUtil::MakeLayout(
                                  MinorToMajorForIsRowMajor(row_major))))
                .ConsumeValueOrDie();

        ComputationBuilder builder(client_, TestName());
        auto prim_type = primitive_util::NativeToPrimitiveType<float>();
        auto lhs_arg = builder.Parameter(
            0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}),
            "lhs");
        auto rhs_arg = builder.Parameter(
            1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}),
            "rhs");
        if (transpose_lhs) {
          lhs_arg = builder.Transpose(lhs_arg, {1, 0});
        }
        if (transpose_rhs) {
          rhs_arg = builder.Transpose(rhs_arg, {1, 0});
        }
        auto result = builder.Dot(lhs_arg, rhs_arg);

        Array2D<float> expected({{26.0, 0.0}, {-12.0, 10.0}});
        VLOG(1) << "TestTransposeFolding " << transpose_lhs << " "
                << transpose_rhs << " " << row_major;
        ComputeAndCompareR2<float>(&builder, expected,
                                   {lhs_handle.get(), rhs_handle.get()},
                                   error_spec_);
      }
    }
  }
}

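// This test and the next build a dot whose non-constant operand is a
// concatenation; they check only the numerical result, though per the test
// names they are intended to exercise XLA's dot-of-concat simplification.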
TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) {
  auto prim_type = primitive_util::NativeToPrimitiveType<float>();

  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));

  ComputationBuilder builder(client_, TestName());
  auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
  auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}),
                                     "rhs_arg_0");
  auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}),
                                     "rhs_arg_1");
  auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}),
                                     "rhs_arg_2");
  auto result = builder.Dot(
      lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));

  std::unique_ptr<Array2D<float>> arg_0_value_array(
      new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}}));
  std::unique_ptr<Array2D<float>> arg_1_value_array(
      new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}));
  std::unique_ptr<Array2D<float>> arg_2_value_array(
      new Array2D<float>({{1.0, 2.0}}));

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_2_value_array)));

  Array2D<float> expected({{53.0, 74.0}, {45.0, 66.0}});
  ComputeAndCompareR2<float>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_);
}

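// As above, but with the constant on the RHS and the concatenation along the
// contracted dimension of the LHS.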
TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) {
  auto prim_type = primitive_util::NativeToPrimitiveType<float>();

  std::unique_ptr<Array2D<float>> constant_rhs_array(
      new Array2D<float>({{1.0, 2.0},
                          {3.0, 4.0},
                          {5.0, 6.0},
                          {6.0, 5.0},
                          {4.0, 3.0},
                          {2.0, 1.0}}));

  ComputationBuilder builder(client_, TestName());
  auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
  auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}),
                                     "lhs_arg_0");
  auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}),
                                     "lhs_arg_1");
  auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}),
                                     "lhs_arg_2");
  auto result = builder.Dot(
      builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant);

  std::unique_ptr<Array2D<float>> arg_0_value_array(
      new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}}));
  std::unique_ptr<Array2D<float>> arg_1_value_array(
      new Array2D<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
  std::unique_ptr<Array2D<float>> arg_2_value_array(
      new Array2D<float>({{1.0}, {2.0}}));

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_2_value_array)));

  Array2D<float> expected({{38.0, 36.0}, {93.0, 91.0}});
  ComputeAndCompareR2<float>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_);
}

}  // namespace
}  // namespace xla