Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include <memory>
#include <numeric>
#include <vector>
     18 
     19 #include "tensorflow/compiler/xla/array2d.h"
     20 #include "tensorflow/compiler/xla/array3d.h"
     21 #include "tensorflow/compiler/xla/client/computation.h"
     22 #include "tensorflow/compiler/xla/client/computation_builder.h"
     23 #include "tensorflow/compiler/xla/client/local_client.h"
     24 #include "tensorflow/compiler/xla/reference_util.h"
     25 #include "tensorflow/compiler/xla/statusor.h"
     26 #include "tensorflow/compiler/xla/test.h"
     27 #include "tensorflow/compiler/xla/test_helpers.h"
     28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     30 #include "tensorflow/compiler/xla/tests/test_macros.h"
     31 #include "tensorflow/core/platform/test.h"
     32 
     33 namespace xla {
     34 namespace {
     35 
     36 using ConcatTest = ClientLibraryTestBase;
     37 using ::testing::HasSubstr;
     38 
     39 // Concatenate expects at least one argument.
     40 XLA_TEST_F(ConcatTest, Concat_Nothing) {
     41   ComputationBuilder builder(client_, TestName());
     42   auto concatenated = builder.ConcatInDim({}, 0);
     43   StatusOr<Computation> computation_status = builder.Build();
     44   ASSERT_FALSE(computation_status.ok());
     45   EXPECT_THAT(computation_status.status().ToString(),
     46               HasSubstr("Concatenate expects at least one argument"));
     47 }
     48 
     49 // Concatenate with one argument works.
     50 XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
     51   ComputationBuilder builder(client_, TestName());
     52   auto a = builder.ConstantR1<float>({42.0, 64.0});
     53   auto concatenated = builder.ConcatInDim({a}, 0);
     54 
     55   std::vector<float> expected = {42, 64};
     56   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
     57 }
     58 
     59 XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
     60   ComputationBuilder builder(client_, TestName());
     61   auto a = builder.ConstantR1<float>({});
     62   auto concatenated = builder.ConcatInDim({a}, 0);
     63 
     64   std::vector<float> expected = {};
     65   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
     66 }
     67 
     68 // Show that we can't concatenate R0 with R0 because we can't name the dimension
     69 // to concatenate on.
     70 XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
     71   ComputationBuilder builder(client_, TestName());
     72   auto a = builder.ConstantR0<float>(42.0);
     73   auto b = builder.ConstantR0<float>(64.0);
     74   auto concatenated = builder.ConcatInDim({a, b}, 0);
     75   StatusOr<Computation> computation_status = builder.Build();
     76   ASSERT_FALSE(computation_status.ok());
     77   EXPECT_THAT(computation_status.status().ToString(),
     78               HasSubstr("dimension to concatenate along out of bounds: 0"));
     79 }
     80 
     81 XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
     82   ComputationBuilder builder(client_, TestName());
     83   auto a = builder.ConstantR1<float>({});
     84   auto b = builder.ConstantR1<float>({});
     85   auto concatenated = builder.ConcatInDim({a, b}, 0);
     86 
     87   std::vector<float> expected = {};
     88   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
     89 }
     90 
     91 XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
     92   ComputationBuilder builder(client_, TestName());
     93   auto a = builder.ConstantR1<float>({});
     94   auto b = builder.ConstantR1<float>({256.0});
     95   auto concatenated = builder.ConcatInDim({a, b}, 0);
     96 
     97   std::vector<float> expected = {256};
     98   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
     99 }
    100 
    101 XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
    102   ComputationBuilder builder(client_, TestName());
    103   auto a = builder.ConstantR1<float>({42.0, 64.0});
    104   auto b = builder.ConstantR1<float>({});
    105   auto concatenated = builder.ConcatInDim({a, b}, 0);
    106 
    107   std::vector<float> expected = {42, 64};
    108   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    109 }
    110 
    111 XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
    112   ComputationBuilder builder(client_, TestName());
    113   auto a = builder.ConstantR1<float>({42.0, 64.0});
    114   auto b = builder.ConstantR1<float>({256.0});
    115   auto concatenated = builder.ConcatInDim({a, b}, 0);
    116 
    117   std::vector<float> expected = {42, 64, 256};
    118   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    119 }
    120 
    121 XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
    122   std::vector<float> lhs(253);
    123   std::vector<float> rhs(7);
    124   std::vector<float> expected(253 + 7);
    125   for (int i = 0; i < 253; ++i) {
    126     expected[i] = lhs[i] = i + 1;
    127   }
    128   for (int i = 0; i < 7; ++i) {
    129     expected[253 + i] = rhs[i] = 253 + i + 1;
    130   }
    131 
    132   ComputationBuilder builder(client_, TestName());
    133   auto a = builder.ConstantR1<float>(lhs);
    134   auto b = builder.ConstantR1<float>(rhs);
    135   auto concatenated = builder.ConcatInDim({a, b}, 0);
    136 
    137   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    138 }
    139 
    140 XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
    141   for (int dim : {0, 1}) {
    142     ComputationBuilder builder(client_, TestName());
    143     auto a = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
    144     auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
    145     auto concatenated = builder.ConcatInDim({a, b}, dim);
    146 
    147     ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {},
    148                                ErrorSpec(0.0001));
    149   }
    150 }
    151 
    152 XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
    153   ComputationBuilder builder(client_, TestName());
    154   auto a_array = CreatePatternedMatrix(1, 1);
    155   auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
    156   auto a = builder.ConstantR2FromArray2D(*a_array);
    157   auto b = builder.ConstantR2FromArray2D(*b_array);
    158   auto concatenated = builder.ConcatInDim({a, b}, 0);
    159 
    160   Array2D<float> expected({
    161       {0}, {64},
    162   });
    163   ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
    164 }
    165 
    166 XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
    167   ComputationBuilder builder(client_, TestName());
    168   auto a_array = CreatePatternedMatrix(1, 1);
    169   auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
    170   auto a = builder.ConstantR2FromArray2D(*a_array);
    171   auto b = builder.ConstantR2FromArray2D(*b_array);
    172   auto concatenated = builder.ConcatInDim({a, b}, 1);
    173 
    174   Array2D<float> expected({
    175       {0, 64},
    176   });
    177   ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
    178 }
    179 
    180 XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
    181   ComputationBuilder builder(client_, TestName());
    182   auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
    183   auto a = builder.ConstantR2FromArray2D(Array2D<float>(2, 0));
    184   auto b = builder.ConstantR2FromArray2D(*b_array);
    185   auto concatenated = builder.ConcatInDim({a, b}, 1);
    186 
    187   ComputeAndCompareR2<float>(&builder, *b_array, {}, ErrorSpec(0.0001));
    188 }
    189 
    190 XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
    191   ComputationBuilder builder(client_, TestName());
    192   auto a_array = CreatePatternedMatrix(2, 3);
    193   auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
    194   auto a = builder.ConstantR2FromArray2D(*a_array);
    195   auto b = builder.ConstantR2FromArray2D(*b_array);
    196   auto concatenated = builder.ConcatInDim({a, b}, 1);
    197 
    198   Array2D<float> expected({
    199       {0, 1, 2, 64, 65, 66, 67, 68},
    200       {1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068},
    201   });
    202   ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
    203 }
    204 
    205 XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
    206   ComputationBuilder builder(client_, TestName());
    207   auto a_array = CreatePatternedMatrix(3, 2);
    208   auto a = builder.ConstantR2FromArray2D(*a_array);
    209   auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 2));
    210   auto concatenated = builder.ConcatInDim({a, b}, 0);
    211 
    212   ComputeAndCompareR2<float>(&builder, *a_array, {}, ErrorSpec(0.0001));
    213 }
    214 
    215 XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
    216   ComputationBuilder builder(client_, TestName());
    217   auto a_array = CreatePatternedMatrix(3, 2);
    218   auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
    219   auto a = builder.ConstantR2FromArray2D(*a_array);
    220   auto b = builder.ConstantR2FromArray2D(*b_array);
    221   auto concatenated = builder.ConcatInDim({a, b}, 0);
    222 
    223   Array2D<float> expected({
    224       {0, 1},
    225       {1000, 1001},
    226       {2000, 2001},
    227       {64, 65},
    228       {1064, 1065},
    229       {2064, 2065},
    230       {3064, 3065},
    231       {4064, 4065},
    232   });
    233   ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
    234 }
    235 
    236 XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
    237   ComputationBuilder builder(client_, TestName());
    238   auto a = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 2));
    239   auto b = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 1));
    240   auto concatenated = builder.ConcatInDim({a, b}, 2);
    241   ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 3), {},
    242                              ErrorSpec(0.0001));
    243 }
    244 
    245 XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
    246   ComputationBuilder builder(client_, TestName());
    247   Array3D<float> a_array({
    248       // 3x1x2
    249       {{0, 1}},
    250       {{2, 3}},
    251       {{4, 5}},
    252   });
    253   Array3D<float> b_array({
    254       // 3x1x1
    255       {{6}},
    256       {{7}},
    257       {{8}},
    258   });
    259   auto a = builder.ConstantR3FromArray3D(a_array);
    260   auto b = builder.ConstantR3FromArray3D(b_array);
    261   auto concatenated = builder.ConcatInDim({a, b}, 2);
    262 
    263   Array3D<float> expected({
    264       {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}},
    265   });
    266   ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
    267 }
    268 
    269 XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
    270   ComputationBuilder builder(client_, TestName());
    271   auto a = builder.ConstantR1<float>({42.0});
    272   auto b = builder.ConstantR1<float>({64.0});
    273   auto c = builder.ConstantR1<float>({256.0});
    274   auto concatenated = builder.ConcatInDim({a, b, c}, 0);
    275 
    276   std::vector<float> expected = {42, 64, 256};
    277   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    278 }
    279 
    280 XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
    281   ComputationBuilder builder(client_, TestName());
    282   Array3D<float> a_array({
    283       // 3x1x2
    284       {{0, 1}},
    285       {{4, 5}},
    286       {{8, 9}},
    287   });
    288   Array3D<float> b_array({
    289       // 3x1x1
    290       {{2}},
    291       {{6}},
    292       {{10}},
    293   });
    294   Array3D<float> c_array({
    295       // 3x1x1
    296       {{3}},
    297       {{7}},
    298       {{11}},
    299   });
    300   auto a = builder.ConstantR3FromArray3D(a_array);
    301   auto b = builder.ConstantR3FromArray3D(b_array);
    302   auto c = builder.ConstantR3FromArray3D(c_array);
    303   auto concatenated = builder.ConcatInDim({a, b, c}, 2);
    304 
    305   Array3D<float> expected({
    306       {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}},
    307   });
    308   ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
    309 }
    310 
    311 XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
    312   ComputationBuilder builder(client_, TestName());
    313   auto a = builder.ConstantR1<float>({42.0});
    314   auto b = builder.ConstantR1<float>({64.0});
    315   auto c = builder.ConstantR1<float>({256.0});
    316   // concatenated = (a concat b) concat c
    317   auto concatenated =
    318       builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0);
    319 
    320   std::vector<float> expected = {42, 64, 256};
    321   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    322 }
    323 
    324 XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
    325   ComputationBuilder builder(client_, TestName());
    326   auto a = builder.ConstantR1<float>({42.0});
    327   auto b = builder.ConstantR1<float>({64.0});
    328   auto c = builder.ConstantR1<float>({256.0});
    329   // concatenated = a concat (b concat c)
    330   auto concatenated =
    331       builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0);
    332 
    333   std::vector<float> expected = {42, 64, 256};
    334   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    335 }
    336 
    337 XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
    338   Array2D<float> lhs(1, 1024);
    339   Array2D<float> rhs(1, 1024);
    340   for (int i = 0; i < 1024; ++i) {
    341     lhs(0, i) = i;
    342     rhs(0, i) = i + 1024;
    343   }
    344 
    345   ComputationBuilder builder(client_, TestName());
    346   auto a = builder.ConstantR2FromArray2D<float>(lhs);
    347   auto b = builder.ConstantR2FromArray2D<float>(rhs);
    348   builder.ConcatInDim({a, b}, 0);
    349 
    350   Array2D<float> expected(2, 1024);
    351   for (int i = 0; i < 1024; ++i) {
    352     expected(0, i) = i;
    353     expected(1, i) = i + 1024;
    354   }
    355   ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
    356 }
    357 
    358 XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
    359   Array2D<float> lhs(1, 1024);
    360   Array2D<float> rhs(1, 1024);
    361   for (int i = 0; i < 1024; ++i) {
    362     lhs(0, i) = i;
    363     rhs(0, i) = i + 1024;
    364   }
    365 
    366   ComputationBuilder builder(client_, TestName());
    367   auto a = builder.ConstantR2FromArray2D<float>(lhs);
    368   auto b = builder.ConstantR2FromArray2D<float>(rhs);
    369   builder.ConcatInDim({a, b}, 1);
    370 
    371   Array2D<float> expected(1, 2048);
    372   for (int i = 0; i < 1024; ++i) {
    373     expected(0, i) = i;
    374     expected(0, i + 1024) = i + 1024;
    375   }
    376   ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
    377 }
    378 
    379 XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
    380   Array2D<float> lhs(64, 64);
    381   Array2D<float> rhs(64, 2);
    382   for (int i0 = 0; i0 < 64; ++i0) {
    383     for (int i1 = 0; i1 < 64; ++i1) {
    384       lhs(i0, i1) = (i0 << 10) | i1;
    385     }
    386     for (int i1 = 0; i1 < 2; ++i1) {
    387       rhs(i0, i1) = (i0 << 10) | (i1 + 64);
    388     }
    389   }
    390 
    391   ComputationBuilder builder(client_, TestName());
    392   auto a = builder.ConstantR2FromArray2D<float>(lhs);
    393   auto b = builder.ConstantR2FromArray2D<float>(rhs);
    394   builder.ConcatInDim({a, b}, 1);
    395 
    396   Array2D<float> expected(64, 66);
    397   for (int i0 = 0; i0 < 64; ++i0) {
    398     for (int i1 = 0; i1 < 66; ++i1) {
    399       expected(i0, i1) = (i0 << 10) | i1;
    400     }
    401   }
    402   ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
    403 }
    404 
    405 // Show that we can't concatenate with an opaques.
    406 XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
    407   ComputationBuilder builder(client_, TestName());
    408   auto opaque_shape = ShapeUtil::MakeOpaqueShape();
    409   auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
    410   auto x = builder.Parameter(0, r1f32, "x");
    411   auto y = builder.Parameter(1, opaque_shape, "y");
    412   auto concatenated = builder.ConcatInDim({x, y}, 0);
    413   StatusOr<Computation> computation_status = builder.Build();
    414   ASSERT_FALSE(computation_status.ok());
    415   EXPECT_THAT(
    416       computation_status.status().ToString(),
    417       HasSubstr("Expected non-opaque argument for operand of concatenation"));
    418 }
    419 
    420 XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
    421   ComputationBuilder builder(client_, TestName());
    422   auto p0 = builder.ConstantR1<bool>({true});
    423   auto p1 = builder.ConstantR1<bool>({false});
    424   auto p2 = builder.ConstantR1<bool>({true});
    425   auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0);
    426 
    427   bool expected[] = {true, false, true};
    428   ComputeAndCompareR1<bool>(&builder, expected, {});
    429 }
    430 
    431 XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
    432   ComputationBuilder builder(client_, TestName());
    433   auto a0 = builder.ConstantR1<int32>({1});
    434   auto a1 = builder.ConstantR1<int32>({2, 3});
    435   auto a2 = builder.ConstantR1<int32>({4, 5, 6});
    436   auto a3 = builder.ConstantR1<int32>({7, 8, 9, 10});
    437   auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0);
    438 
    439   std::vector<int32> expected(10);
    440   std::iota(expected.begin(), expected.end(), 1);
    441   ComputeAndCompareR1<int32>(&builder, expected, {});
    442 }
    443 
    444 XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
    445   ComputationBuilder builder(client_, TestName());
    446 
    447   Array3D<float> arr0(9, 17, 1);
    448   arr0.Fill(1);
    449 
    450   Array3D<float> arr1(9, 17, 256);
    451   arr1.Fill(2);
    452 
    453   Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
    454   for (int64 i = 0; i < expected.n1(); ++i) {
    455     for (int64 j = 0; j < expected.n2(); ++j) {
    456       int64 kk = 0;
    457       for (const Array3D<float>& arr : {arr0, arr1}) {
    458         for (int64 k = 0; k < arr.n3(); ++k, ++kk) {
    459           expected(i, j, kk) = arr(i, j, k);
    460         }
    461       }
    462     }
    463   }
    464 
    465   ComputationDataHandle h0;
    466   auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
    467                                      &builder, &h0);
    468   ComputationDataHandle h1;
    469   auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
    470                                      &builder, &h1);
    471 
    472   auto concatenated = builder.ConcatInDim({h0, h1}, 2);
    473 
    474   ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
    475 }
    476 
    477 // Describes a binary rank-2 concatenation test.
    478 struct R2BinarySpec {
    479   int64 lhs_dim0;
    480   int64 lhs_dim1;
    481   int64 rhs_dim0;
    482   int64 rhs_dim1;
    483   int64 concat_dimension;
    484 };
    485 
    486 // TEST_P harness for binary rank-2 concatenation.
    487 class ConcatR2BinaryTest : public ClientLibraryTestBase,
    488                            public ::testing::WithParamInterface<R2BinarySpec> {
    489 };
    490 
    491 TEST_P(ConcatR2BinaryTest, DoIt) {
    492   const R2BinarySpec& spec = GetParam();
    493   Array2D<int32> lhs(spec.lhs_dim0, spec.lhs_dim1);
    494   lhs.FillUnique();
    495   Array2D<int32> rhs(spec.rhs_dim0, spec.rhs_dim1);
    496   rhs.FillUnique(1000);
    497 
    498   ComputationBuilder builder(client_, TestName());
    499   auto a0 = builder.ConstantR2FromArray2D<int32>(lhs);
    500   auto a1 = builder.ConstantR2FromArray2D<int32>(rhs);
    501   builder.ConcatInDim({a0, a1}, spec.concat_dimension);
    502 
    503   std::unique_ptr<Array2D<int32>> expected =
    504       ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension);
    505   ComputeAndCompareR2<int32>(&builder, *expected, {});
    506 }
    507 
    508 // Regression test for b/31944287. x*y is used (at the same index) by all
    509 // operands of the concat. We should emit x*y in three incoming basic blocks of
    510 // the concat because these basic blocks are not control-equivalent.
    511 //
    512 //      x*y
    513 //    /  |   \
    514 // add1 add2 add3
    515 //    \  |   /
    516 //     concat
    517 XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
    518   auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
    519   auto x_literal = Literal::CreateR0<float>(2.f);
    520   auto y_literal = Literal::CreateR0<float>(3.f);
    521   auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
    522   auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
    523 
    524   ComputationBuilder builder(client_, TestName());
    525   auto x = builder.Parameter(0, f32_scalar, "x");
    526   auto y = builder.Parameter(1, f32_scalar, "y");
    527   auto mul = builder.Mul(x, y);
    528   auto add1 = builder.Add(mul, builder.ConstantR1<float>({1.f, 2.f}));
    529   auto add2 = builder.Add(mul, builder.ConstantR1<float>({3.f, 4.f}));
    530   auto add3 = builder.Add(mul, builder.ConstantR1<float>({5.f, 6.f}));
    531   builder.ConcatInDim({add1, add2, add3}, /*dimension=*/0);
    532 
    533   ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.},
    534                              {x_data.get(), y_data.get()}, ErrorSpec(1e-4));
    535 }
    536 
    537 // Test that the HLO optimization to replace a concat of a bradcasted scalar
    538 // produces the correct result in rank 1.
    539 XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
    540   auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
    541   auto x_literal = Literal::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
    542   auto y_literal = Literal::CreateR0<float>(1.5f);
    543   auto z_literal = Literal::CreateR0<float>(5.5f);
    544   auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
    545   auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
    546   auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
    547 
    548   ComputationBuilder builder(client_, TestName());
    549   auto x = builder.Parameter(0, x_literal->shape(), "x");
    550   auto y = builder.Parameter(1, f32_scalar, "y");
    551   auto z = builder.Parameter(2, f32_scalar, "z");
    552   auto bcast = builder.Broadcast(y, {5});
    553   auto bcast2 = builder.Broadcast(z, {3});
    554   auto concat = builder.ConcatInDim({bcast, x}, /*dimension=*/0);
    555   builder.ConcatInDim({concat, bcast2}, /*dimension=*/0);
    556 
    557   ComputeAndCompareR1<float>(
    558       &builder,
    559       {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 2.0f, 3.0f, 5.0f, 6.0f, 5.5f, 5.5f, 5.5f},
    560       {x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4));
    561 }
    562 
    563 // Test that the HLO optimization to replace a concat of a bradcasted scalar
    564 // produces the correct result in rank 3 with both high and low padding in
    565 // different dimensions.
    566 XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
    567   auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
    568   Array3D<float> x3d(3, 5, 7, 3.14f);
    569   auto x_literal = Literal::CreateR3FromArray3D<float>(x3d);
    570   auto y_literal = Literal::CreateR0<float>(1.5f);
    571   auto z_literal = Literal::CreateR0<float>(5.5f);
    572   auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
    573   auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
    574   auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
    575 
    576   ComputationBuilder builder(client_, TestName());
    577   auto x = builder.Parameter(0, x_literal->shape(), "x");
    578   auto y = builder.Parameter(1, f32_scalar, "y");
    579   auto z = builder.Parameter(2, f32_scalar, "y");
    580   auto y_bcast = builder.Broadcast(y, {1, 5, 7});
    581   auto z_bcast = builder.Broadcast(z, {4, 1, 7});
    582   auto concat = builder.ConcatInDim({y_bcast, x}, /*dimension=*/0);
    583   builder.ConcatInDim({concat, z_bcast}, /*dimension=*/1);
    584   Array3D<float> y_bcast3d(1, 5, 7, 1.5f);
    585   Array3D<float> z_bcast3d(4, 1, 7, 5.5f);
    586   auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0);
    587   auto concat1 = ReferenceUtil::Concat3D(*concat0, z_bcast3d, 1);
    588 
    589   ComputeAndCompareR3<float>(&builder, *concat1,
    590                              {x_data.get(), y_data.get(), z_data.get()},
    591                              ErrorSpec(1e-4));
    592 }
    593 
    594 INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest,
    595                         ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0},
    596                                           R2BinarySpec{1, 1, 1, 1, 1},
    597                                           R2BinarySpec{4, 3, 4, 3, 0},
    598                                           R2BinarySpec{4, 3, 4, 3, 1},
    599                                           R2BinarySpec{7, 128, 1, 128, 0},
    600                                           R2BinarySpec{8, 127, 8, 1, 1}));
    601 
    602 }  // namespace
    603 }  // namespace xla
    604