Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include <algorithm>
#include <array>
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
     30 
     31 namespace xla {
     32 namespace {
     33 
     34 class BroadcastSimpleTest : public ClientLibraryTestBase {
     35  public:
     36   ComputationDataHandle BuildBinOp(HloOpcode op,
     37                                    const ComputationDataHandle& lhs,
     38                                    const ComputationDataHandle& rhs,
     39                                    ComputationBuilder* builder) {
     40     switch (op) {
     41       case HloOpcode::kMinimum: {
     42         return builder->Min(lhs, rhs);
     43       }
     44       case HloOpcode::kMaximum: {
     45         return builder->Max(lhs, rhs);
     46       }
     47       case HloOpcode::kMultiply: {
     48         return builder->Mul(lhs, rhs);
     49       }
     50       default: {
     51         // Default to Add
     52         return builder->Add(lhs, rhs);
     53       }
     54     }
     55   }
     56 
     57   std::unique_ptr<GlobalData> MakeR3Data(
     58       tensorflow::gtl::ArraySlice<int64> bounds,
     59       tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r3_shape,
     60       Array3D<float>* r3_array, float start, float end, int seed) {
     61     *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
     62     r3_array->FillRandom(start, end, seed);
     63     auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout(
     64         LayoutUtil::MakeLayout(minor_to_major));
     65     std::unique_ptr<GlobalData> r3_global_data =
     66         client_->TransferToServer(*r3_data).ConsumeValueOrDie();
     67     return r3_global_data;
     68   }
     69 
     70   std::unique_ptr<GlobalData> MakeR2Data(
     71       tensorflow::gtl::ArraySlice<int64> bounds,
     72       tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r2_shape,
     73       Array2D<float>* r2_array, float start, float end, int seed) {
     74     *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
     75     r2_array->FillRandom(start, end, seed);
     76     auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout(
     77         LayoutUtil::MakeLayout(minor_to_major));
     78     std::unique_ptr<GlobalData> r2_global_data =
     79         client_->TransferToServer(*r2_data).ConsumeValueOrDie();
     80     return r2_global_data;
     81   }
     82 
     83   float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) {
     84     switch (op) {
     85       case HloOpcode::kMinimum: {
     86         return std::min(lhs, rhs);
     87       }
     88       case HloOpcode::kMaximum: {
     89         return std::max(lhs, rhs);
     90       }
     91       case HloOpcode::kMultiply: {
     92         return lhs * rhs;
     93       }
     94       case HloOpcode::kAdd: {
     95         return lhs + rhs;
     96       }
     97       default: {
     98         // Default to Add
     99         LOG(FATAL);
    100       }
    101     }
    102   }
    103 };
    104 
    105 using ::testing::HasSubstr;
    106 
    107 XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
    108   ComputationBuilder b(client_, TestName());
    109   b.Broadcast(b.ConstantR0<float>(1.5), {});
    110   ComputeAndCompareR0<float>(&b, 1.5, {}, ErrorSpec(0.0001));
    111 }
    112 
    113 XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
    114   ComputationBuilder b(client_, TestName());
    115   b.Broadcast(b.ConstantR0<float>(2.25), {2, 3});
    116   Array2D<float> expected(2, 3, 2.25);
    117   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
    118 }
    119 
    120 XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
    121   ComputationBuilder b(client_, TestName());
    122   ComputationDataHandle src;
    123   std::unique_ptr<GlobalData> param_data =
    124       CreateR0Parameter<float>(2.25f, /*parameter_number=*/0, /*name=*/"src",
    125                                /*builder=*/&b, /*data_handle=*/&src);
    126 
    127   b.Broadcast(src, {2, 3});
    128   Array2D<float> expected(2, 3, 2.25);
    129   ComputeAndCompareR2<float>(&b, expected, {param_data.get()},
    130                              ErrorSpec(0.0001));
    131 }
    132 
    133 XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
    134   ComputationBuilder b(client_, TestName());
    135   b.Broadcast(b.ConstantR0<float>(2.25), {2, 0});
    136   Array2D<float> expected(2, 0);
    137   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
    138 }
    139 
    140 XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
    141   ComputationBuilder b(client_, TestName());
    142   b.Broadcast(b.ConstantR0<float>(2.25), {0, 2});
    143   Array2D<float> expected(0, 2);
    144   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
    145 }
    146 
    147 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
    148   ComputationBuilder b(client_, TestName());
    149   b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {2});
    150 
    151   Array2D<float> expected(2, 3);
    152   expected(0, 0) = 1;
    153   expected(0, 1) = 2;
    154   expected(0, 2) = 3;
    155   expected(1, 0) = 1;
    156   expected(1, 1) = 2;
    157   expected(1, 2) = 3;
    158   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
    159 }
    160 
    161 // Tests implicit broadcasting of PREDs.
    162 XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
    163   ComputationBuilder b(client_, TestName());
    164 
    165   Array2D<bool> x_vals(2, 1);
    166   x_vals(0, 0) = true;
    167   x_vals(1, 0) = false;
    168   Array3D<bool> y_vals(2, 2, 1);
    169   y_vals(0, 0, 0) = false;
    170   y_vals(0, 1, 0) = false;
    171   y_vals(1, 0, 0) = true;
    172   y_vals(1, 1, 0) = true;
    173 
    174   ComputationDataHandle x, y;
    175   auto x_data = CreateR2Parameter<bool>(x_vals, 0, "x", &b, &x);
    176   auto y_data = CreateR3Parameter<bool>(y_vals, 1, "y", &b, &y);
    177   b.And(x, y, /*broadcast_dimensions=*/{1, 2});
    178 
    179   Array3D<bool> expected(2, 2, 1);
    180   expected(0, 0, 0) = false;
    181   expected(0, 1, 0) = false;
    182   expected(1, 0, 0) = true;
    183   expected(1, 1, 0) = false;
    184 
    185   ComputeAndCompareR3<bool>(&b, expected, {x_data.get(), y_data.get()});
    186 }
    187 
    188 XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
    189   ComputationBuilder b(client_, TestName());
    190   b.Broadcast(b.ConstantR1<float>({}), {2});
    191 
    192   Array2D<float> expected(2, 0);
    193   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
    194 }
    195 
    196 XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
    197   ComputationBuilder b(client_, TestName());
    198   b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {0});
    199 
    200   Array2D<float> expected(0, 3);
    201   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
    202 }
    203 
    204 XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
    205   // Verify that binary op and degenerate dimension broadcast work together in
    206   // the same operation.
    207   //
    208   // The lhs shape [1, 2] is first broadcast up to [2, 1, 2] using in-dimension
    209   // broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape
    210   // [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one
    211   // dimensions.
    212   ComputationBuilder b(client_, TestName());
    213 
    214   b.Add(b.ConstantR2<float>({{1.0, 5.0}}),
    215         b.ConstantLiteral(*Literal::CreateR3<float>(
    216             {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
    217         /*broadcast_dimensions=*/{1, 2});
    218 
    219   auto expected =
    220       Literal::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
    221                                 {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
    222 
    223   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    224 }
    225 
    226 struct R3ImplicitBroadcastSpec {
    227   std::array<int64, 3> output_bounds;
    228   std::array<int64, 3> minor2major_layout;
    229   std::array<int64, 3> input_bounds;
    230   HloOpcode op;
    231 } kR3ImplicitBroadcastTestCases[] = {
    232     {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
    233     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum},
    234     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum},
    235     {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply},
    236     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
    237     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd},
    238     {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd},
    239     {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd},
    240     {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum},
    241     {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd},
    242 };
    243 
    244 class BroadcastR3ImplicitTest
    245     : public BroadcastSimpleTest,
    246       public ::testing::WithParamInterface<R3ImplicitBroadcastSpec> {};
    247 
    248 XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
    249   const R3ImplicitBroadcastSpec& spec = GetParam();
    250   ComputationBuilder builder(client_, TestName());
    251 
    252   Shape r3_shape, r3_implicit_shape;
    253   Array3D<float> r3_array(spec.output_bounds[0], spec.output_bounds[1],
    254                           spec.output_bounds[2]);
    255   Array3D<float> r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1],
    256                                    spec.input_bounds[2]);
    257 
    258   std::unique_ptr<GlobalData> r3_global_data =
    259       MakeR3Data(spec.output_bounds, spec.minor2major_layout, &r3_shape,
    260                  &r3_array, 1.0, 2.5, 56789);
    261   std::unique_ptr<GlobalData> r3_implicit_global_data =
    262       MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape,
    263                  &r3_implicit_array, 1.0, 0.2, 56789);
    264 
    265   auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input");
    266   auto r3_parameter = builder.Parameter(1, r3_shape, "input");
    267   ComputationDataHandle op =
    268       BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);
    269 
    270   Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
    271                                 spec.output_bounds[2]);
    272   auto Each = ([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
    273     float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0],
    274                                           indices[1] % spec.input_bounds[1],
    275                                           indices[2] % spec.input_bounds[2]);
    276     float r3 = r3_array(indices[0], indices[1], indices[2]);
    277     *value = ApplyOpToFloats(spec.op, r3_implicit, r3);
    278   });
    279 
    280   int n1 = expected_array.n1();
    281   int n2 = expected_array.n2();
    282   int n3 = expected_array.n3();
    283   for (int64 i = 0; i < n1; i++) {
    284     for (int64 j = 0; j < n2; j++) {
    285       for (int64 k = 0; k < n3; k++) {
    286         Each({i, j, k}, &expected_array(i, j, k));
    287       }
    288     }
    289   }
    290   auto expected = Literal::CreateR3FromArray3D(expected_array);
    291   ComputeAndCompareLiteral(
    292       &builder, *expected,
    293       {r3_implicit_global_data.get(), r3_global_data.get()},
    294       ErrorSpec(1e-7, 1e-7));
    295 }
    296 
    297 INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances,
    298                         BroadcastR3ImplicitTest,
    299                         ::testing::ValuesIn(kR3ImplicitBroadcastTestCases));
    300 
    301 // r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1:
    302 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
    303   ComputationBuilder b(client_, TestName());
    304   ComputationDataHandle r1h;
    305   ComputationDataHandle r3h;
    306 
    307   Array3D<float> r1d = {{{1}}, {{2}}};
    308   Array3D<float> r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
    309   auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h);
    310   auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h);
    311 
    312   b.Add(r3h, r1h);
    313 
    314   auto expected =
    315       Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
    316 
    317   ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
    318                            ErrorSpec(0.0001));
    319 }
    320 
    321 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
    322   ComputationBuilder b(client_, TestName());
    323   auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}}));
    324   auto r3 = b.ConstantLiteral(
    325       *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
    326   b.Add(r3, r1);
    327 
    328   auto expected =
    329       Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
    330 
    331   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    332 }
    333 
    334 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
    335   ComputationBuilder b(client_, TestName());
    336   auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}}));
    337   auto r3 = b.ConstantLiteral(
    338       *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
    339   b.Add(r3, r1);
    340 
    341   auto expected =
    342       Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
    343 
    344   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    345 }
    346 
    347 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
    348   ComputationBuilder b(client_, TestName());
    349   auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}, {3, 4}}}));
    350   auto r3 = b.ConstantLiteral(
    351       *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
    352   b.Add(r3, r1);
    353 
    354   auto expected =
    355       Literal::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
    356 
    357   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    358 }
    359 
    360 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
    361   ComputationBuilder b(client_, TestName());
    362   auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
    363   auto r3 = b.ConstantLiteral(
    364       *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
    365   b.Add(r3, r1);
    366 
    367   auto expected =
    368       Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
    369 
    370   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    371 }
    372 
    373 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
    374   ComputationBuilder b(client_, TestName());
    375   auto r1 =
    376       b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
    377   auto r3 = b.ConstantLiteral(
    378       *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
    379   b.Add(r3, r1);
    380 
    381   auto expected =
    382       Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
    383 
    384   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    385 }
    386 
    387 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
    388   ComputationBuilder b(client_, TestName());
    389   auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}}}));
    390   auto r3 = b.ConstantLiteral(
    391       *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
    392   b.Add(r3, r1);
    393 
    394   auto expected =
    395       Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
    396 
    397   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    398 }
    399 
    400 struct R2ImplicitBroadcastSpec {
    401   std::array<int64, 2> output_bounds;
    402   std::array<int64, 2> minor2major_layout;
    403   std::array<int64, 2> input_bounds1;
    404   std::array<int64, 2> input_bounds2;
    405   HloOpcode op1;
    406   HloOpcode op2;
    407 } kR2ImplicitBroadcastTestCases[] = {
    408     {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
    409     {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{1, 3}}, HloOpcode::kAdd, HloOpcode::kAdd},
    410     {{{2, 3}},
    411      {{1, 0}},
    412      {{2, 1}},
    413      {{1, 1}},
    414      HloOpcode::kAdd,
    415      HloOpcode::kMinimum},
    416     {{{2, 3}},
    417      {{1, 0}},
    418      {{1, 3}},
    419      {{1, 1}},
    420      HloOpcode::kAdd,
    421      HloOpcode::kMinimum},
    422     {{{2, 3}},
    423      {{1, 0}},
    424      {{1, 1}},
    425      {{1, 1}},
    426      HloOpcode::kAdd,
    427      HloOpcode::kMinimum},
    428     {{{2, 3}}, {{0, 1}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
    429     {{{150, 150}},
    430      {{1, 0}},
    431      {{150, 1}},
    432      {{150, 1}},
    433      HloOpcode::kAdd,
    434      HloOpcode::kAdd},
    435     {{{150, 150}},
    436      {{1, 0}},
    437      {{150, 1}},
    438      {{1, 150}},
    439      HloOpcode::kAdd,
    440      HloOpcode::kAdd},
    441     {{{150, 150}},
    442      {{1, 0}},
    443      {{150, 1}},
    444      {{1, 1}},
    445      HloOpcode::kAdd,
    446      HloOpcode::kAdd},
    447     {{{50, 150}},
    448      {{1, 0}},
    449      {{50, 1}},
    450      {{50, 1}},
    451      HloOpcode::kAdd,
    452      HloOpcode::kAdd},
    453     {{{50, 150}},
    454      {{1, 0}},
    455      {{50, 1}},
    456      {{1, 150}},
    457      HloOpcode::kAdd,
    458      HloOpcode::kAdd},
    459     {{{50, 150}},
    460      {{1, 0}},
    461      {{50, 1}},
    462      {{1, 1}},
    463      HloOpcode::kAdd,
    464      HloOpcode::kAdd},
    465     {{{150, 50}},
    466      {{1, 0}},
    467      {{150, 1}},
    468      {{150, 1}},
    469      HloOpcode::kAdd,
    470      HloOpcode::kAdd},
    471     {{{150, 50}},
    472      {{1, 0}},
    473      {{150, 1}},
    474      {{1, 50}},
    475      HloOpcode::kAdd,
    476      HloOpcode::kAdd},
    477     {{{150, 50}},
    478      {{1, 0}},
    479      {{150, 1}},
    480      {{1, 1}},
    481      HloOpcode::kAdd,
    482      HloOpcode::kAdd}};
    483 
    484 class BroadcastR2ImplicitTest
    485     : public BroadcastSimpleTest,
    486       public ::testing::WithParamInterface<R2ImplicitBroadcastSpec> {};
    487 
    488 // Test r2 op1 r2_implicit_1 op2 r2_implicit_2
    489 // where R2 is a rank-2 operand, and r2_implicit_2 are two
    490 // rank-2 operands with degenerate dimensions:
    491 XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
    492   const R2ImplicitBroadcastSpec& spec = GetParam();
    493 
    494   ComputationBuilder builder(client_, TestName());
    495 
    496   // Operands with degenerate dimensions require implicit broadcasting:
    497   Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2;
    498   Array2D<float> r2_array(spec.output_bounds[0], spec.output_bounds[1]);
    499   Array2D<float> r2_implicit_array1(spec.input_bounds1[0],
    500                                     spec.input_bounds1[1]);
    501   Array2D<float> r2_implicit_array2(spec.input_bounds2[0],
    502                                     spec.input_bounds2[1]);
    503 
    504   std::unique_ptr<GlobalData> r2_global_data =
    505       MakeR2Data(spec.output_bounds, spec.minor2major_layout, &r2_shape,
    506                  &r2_array, 1.0, 2.5, 56789);
    507   std::unique_ptr<GlobalData> r2_implicit_global_data1 =
    508       MakeR2Data(spec.input_bounds1, spec.minor2major_layout,
    509                  &r2_implicit_shape1, &r2_implicit_array1, 1.0, 0.2, 56789);
    510   std::unique_ptr<GlobalData> r2_implicit_global_data2 =
    511       MakeR2Data(spec.input_bounds2, spec.minor2major_layout,
    512                  &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789);
    513 
    514   auto r2_implicit_parameter1 =
    515       builder.Parameter(0, r2_implicit_shape1, "input0");
    516   auto r2_parameter = builder.Parameter(1, r2_shape, "input1");
    517   auto r2_implicit_parameter2 =
    518       builder.Parameter(2, r2_implicit_shape2, "input2");
    519 
    520   ComputationDataHandle op1 =
    521       BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder);
    522   ComputationDataHandle op2 =
    523       BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
    524 
    525   Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]);
    526 
    527   expected_array.Each([&](int64 i, int64 j, float* v) {
    528     float v1 = r2_implicit_array1(i % spec.input_bounds1[0],
    529                                   j % spec.input_bounds1[1]);
    530     float v2 = r2_array(i, j);
    531     float v3 = r2_implicit_array2(i % spec.input_bounds2[0],
    532                                   j % spec.input_bounds2[1]);
    533     float tmp = ApplyOpToFloats(spec.op1, v1, v2);
    534     *v = ApplyOpToFloats(spec.op2, tmp, v3);
    535   });
    536 
    537   auto expected = Literal::CreateR2FromArray2D(expected_array);
    538   ComputeAndCompareLiteral(
    539       &builder, *expected,
    540       {r2_implicit_global_data1.get(), r2_global_data.get(),
    541        r2_implicit_global_data2.get()},
    542       ErrorSpec(1e-6, 1e-6));
    543 }
    544 
    545 INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
    546                         BroadcastR2ImplicitTest,
    547                         ::testing::ValuesIn(kR2ImplicitBroadcastTestCases));
    548 
    549 XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
    550   ComputationBuilder b(client_, TestName());
    551   auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}}));
    552   auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}}));
    553   b.Add(r2, r1);
    554 
    555   auto expected = Literal::CreateR2<float>({{2, 4}, {4, 6}});
    556 
    557   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    558 }
    559 
    560 XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
    561   ComputationBuilder b(client_, TestName());
    562   auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1}, {2}}));
    563   auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}}));
    564   b.Add(r2, r1);
    565 
    566   auto expected = Literal::CreateR2<float>({{2, 3}, {5, 6}});
    567 
    568   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    569 }
    570 
    571 XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
    572   ComputationBuilder b(client_, TestName());
    573   auto r1 = b.ConstantR1<float>({10, 20});
    574   auto r3 = b.ConstantLiteral(
    575       *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
    576   b.Add(r3, r1, {0});
    577 
    578   auto expected =
    579       Literal::CreateR3<float>({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
    580 
    581   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    582 }
    583 
    584 XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
    585   ComputationBuilder b(client_, TestName());
    586   auto r1 = b.ConstantR1<float>({10, 20});
    587   auto r3 = b.ConstantLiteral(
    588       *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
    589   b.Add(r1, r3, {1});
    590 
    591   auto expected =
    592       Literal::CreateR3<float>({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
    593 
    594   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    595 }
    596 
    597 XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
    598   ComputationBuilder b(client_, TestName());
    599   auto r1 = b.ConstantR1<float>({10, 20});
    600   auto r3 = b.ConstantLiteral(
    601       *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
    602   b.Add(r1, r3, {2});
    603 
    604   auto expected =
    605       Literal::CreateR3<float>({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
    606 
    607   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    608 }
    609 
    610 XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
    611   ComputationBuilder b(client_, TestName());
    612   auto r1_0 = b.ConstantR1<float>({1000, 2000});
    613   auto r1_1 = b.ConstantR1<float>({100, 200});
    614   auto r1_2 = b.ConstantR1<float>({10, 20});
    615   auto r3 = b.ConstantLiteral(
    616       *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
    617   for (int i = 0; i < 3; ++i) {
    618     r3 = b.Add(r1_0, r3, {0});
    619     r3 = b.Add(r3, r1_1, {1});
    620     r3 = b.Add(r1_2, r3, {2});
    621   }
    622   r3 = b.Mul(r3, b.ConstantR0<float>(-2));
    623 
    624   auto expected = Literal::CreateR3<float>(
    625       {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
    626        {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
    627 
    628   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    629 }
    630 
    631 XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
    632   ComputationBuilder b(client_, TestName());
    633   auto r1_0 = b.ConstantR1<float>({1000, 2000});
    634   auto r1_1 = b.ConstantR1<float>({100, 200});
    635   auto r1_2 = b.ConstantR1<float>({10, 20});
    636   auto r0 = b.ConstantR0<float>(3);
    637   auto r3 = b.Broadcast(r0, {2, 2, 2});
    638   for (int i = 0; i < 3; ++i) {
    639     r3 = b.Add(r1_0, r3, {0});
    640     r3 = b.Add(r3, r1_1, {1});
    641     r3 = b.Add(r1_2, r3, {2});
    642   }
    643   r3 = b.Mul(r3, b.ConstantR0<float>(-1));
    644 
    645   auto expected = Literal::CreateR3<float>(
    646       {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
    647        {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
    648 
    649   ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
    650 }
    651 
    652 XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
    653   // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2])
    654   // results in a shape incompatible with the lhs [2, 3, 1].
    655   ComputationBuilder b(client_, TestName());
    656 
    657   b.Add(b.ConstantR2<float>({{1.0, 5.0}, {1.0, 5.0}}),
    658         b.ConstantLiteral(*Literal::CreateR3<float>(
    659             {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
    660         /*broadcast_dimensions=*/{1, 2});
    661 
    662   auto result_status = Execute(&b, {});
    663   EXPECT_FALSE(result_status.ok());
    664   EXPECT_THAT(result_status.status().error_message(),
    665               HasSubstr("broadcast dimension 0 mismatch"));
    666 }
    667 
    668 XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
    669   // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
    670   ComputationBuilder b(client_, TestName());
    671 
    672   b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
    673         b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
    674 
    675   auto result_status = Execute(&b, {});
    676   EXPECT_FALSE(result_status.ok());
    677   EXPECT_THAT(result_status.status().error_message(),
    678               HasSubstr("binary op BINOP_ADD with incompatible shapes"));
    679 }
    680 
    681 XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
    682   // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
    683   ComputationBuilder b(client_, TestName());
    684 
    685   b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
    686         b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
    687 
    688   auto result_status = Execute(&b, {});
    689   EXPECT_FALSE(result_status.ok());
    690   EXPECT_THAT(result_status.status().error_message(),
    691               HasSubstr("binary op BINOP_ADD with incompatible shapes"));
    692 }
    693 
    694 }  // namespace
    695 }  // namespace xla
    696