Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// Tests that the XLA Slice operation (contiguous and strided) produces the expected results.
     17 
     18 #include <numeric>
     19 #include <vector>
     20 
     21 #include "absl/container/inlined_vector.h"
     22 #include "absl/strings/str_cat.h"
     23 #include "absl/strings/str_format.h"
     24 #include "absl/strings/str_join.h"
     25 #include "absl/types/span.h"
     26 #include "tensorflow/compiler/xla/array2d.h"
     27 #include "tensorflow/compiler/xla/client/local_client.h"
     28 #include "tensorflow/compiler/xla/client/xla_builder.h"
     29 #include "tensorflow/compiler/xla/reference_util.h"
     30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     32 #include "tensorflow/compiler/xla/tests/test_macros.h"
     33 #include "tensorflow/core/platform/test.h"
     34 #include "tensorflow/core/platform/types.h"
     35 
     36 namespace xla {
     37 namespace {
     38 
// Fixture for the non-parameterized slice tests; it adds nothing beyond
// ClientLibraryTestBase.
class SliceTest : public ClientLibraryTestBase {};
     40 
     41 TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
     42   Array3D<float> values(3, 3, 3);
     43   values.FillIota(0);
     44 
     45   XlaBuilder builder(TestName());
     46   auto original = ConstantR3FromArray3D<float>(&builder, values);
     47   Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1});
     48 
     49   Array3D<float> expected{
     50       {{0.0}, {3.0}, {6.0}}, {{9.0}, {12.0}, {15.0}}, {{18.0}, {21.0}, {24.0}}};
     51   ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.000001));
     52 }
     53 
     54 TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) {
     55   Array3D<float> values(3, 3, 3);
     56   values.FillIota(0);
     57 
     58   XlaBuilder builder(TestName());
     59   auto original = ConstantR3FromArray3D<float>(&builder, values);
     60   Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1});
     61 
     62   Array3D<float> expected{
     63       {{0.0, 1.0, 2.0}}, {{9.0, 10.0, 11.0}}, {{18.0, 19.0, 20.0}}};
     64   ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.000001));
     65 }
     66 
     67 TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) {
     68   Array3D<float> values(3, 3, 3);
     69   values.FillIota(0);
     70 
     71   XlaBuilder builder(TestName());
     72   auto original = ConstantR3FromArray3D<float>(&builder, values);
     73   Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1});
     74 
     75   Array3D<float> expected{
     76       {{{0.0, 1.0, 2.0}, {3.0, 4.0, 5.0}, {6.0, 7.0, 8.0}}}};
     77   ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.000001));
     78 }
     79 
     80 XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
     81   XlaBuilder builder(TestName());
     82   auto original = ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 0));
     83   Slice(original, {0, 0}, {0, 0}, {1, 1});
     84 
     85   ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {});
     86 }
     87 
     88 XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
     89   XlaBuilder builder(TestName());
     90   auto original = ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 20));
     91   Slice(original, {0, 15}, {0, 20}, {1, 1});
     92 
     93   ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 5), {});
     94 }
     95 
     96 XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
     97   XlaBuilder builder(TestName());
     98   auto original = ConstantR2FromArray2D<float>(&builder, Array2D<float>(3, 0));
     99   Slice(original, {1, 0}, {3, 0}, {1, 1});
    100 
    101   ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {});
    102 }
    103 
    104 XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
    105   Array2D<float> values(256, 256);
    106   for (int row = 0; row < 256; ++row) {
    107     for (int col = 0; col < 256; ++col) {
    108       values(row, col) = (row << 10) | col;
    109     }
    110   }
    111 
    112   XlaBuilder builder(TestName());
    113   auto original = ConstantR2FromArray2D<float>(&builder, values);
    114   Slice(original, {128, 128}, {256, 256}, {1, 1});
    115 
    116   Array2D<float> expected(128, 128);
    117   for (int row = 0; row < 128; ++row) {
    118     for (int col = 0; col < 128; ++col) {
    119       expected(row, col) = ((row + 128) << 10) | (col + 128);
    120     }
    121   }
    122   ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
    123 }
    124 
    125 // Tests: (f32[1,4096], starts={0, 3072}, limits={1, 4096}) -> f32[1,1024])
    126 TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
    127   Array2D<float> values(1, 4096);
    128   std::iota(values.data(), values.data() + 4096, 0.0);
    129 
    130   XlaBuilder builder(TestName());
    131   auto original = ConstantR2FromArray2D<float>(&builder, values);
    132   Slice(original, {0, 3072}, {1, 4096}, {1, 1});
    133 
    134   Array2D<float> expected(1, 1024);
    135   std::iota(expected.data(), expected.data() + 1024, 3072.0);
    136   ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
    137 }
    138 
    139 // Tests slice: (f32[16,4], starts={0, 0}, limits={16, 2}) -> f32[16,2]
    140 TEST_F(SliceTest, Slice_16x4_To_16x2) {
    141   Array2D<float> values(16, 4);
    142   Array2D<float> expected(16, 2);
    143   for (int row = 0; row < 16; ++row) {
    144     for (int col = 0; col < 4; ++col) {
    145       values(row, col) = (row << 10) | col;
    146       if (col < 2) {
    147         expected(row, col) = (row << 10) | col;
    148       }
    149     }
    150   }
    151   XlaBuilder builder(TestName());
    152   auto original = ConstantR2FromArray2D<float>(&builder, values);
    153   Slice(original, {0, 0}, {16, 2}, {1, 1});
    154   ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
    155 }
    156 
    157 // Tests: (f32[2, 2, 24, 256], starts = {1, 0, 8, 0}, ends = {2, 2, 16, 128}
    158 TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
    159   Array4D<float> values(2, 2, 24, 256);
    160   values.FillRandom(3.14f);
    161   auto expected = ReferenceUtil::Slice4D(
    162       values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}});
    163   XlaBuilder builder(TestName());
    164   auto original = ConstantR4FromArray4D(&builder, values);
    165   Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1});
    166   ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
    167 }
    168 
    169 TEST_F(SliceTest, SliceOfReshape) {
    170   Array2D<int> values(2 * 3 * 24, 7);
    171   values.FillIota(1);
    172   XlaBuilder builder(TestName());
    173   auto original = ConstantR2FromArray2D(&builder, values);
    174   auto reshape = Reshape(original, {24, 3, 2, 7});
    175   Slice(reshape, {0, 0, 0, 0}, {11, 3, 2, 7}, {1, 1, 1, 1});
    176   ComputeAndCompare(&builder, {});
    177 }
    178 
    179 TEST_F(SliceTest, SliceOfCollapsingReshape) {
    180   Array4D<int> values(2, 3, 5, 7);
    181   values.FillIota(1);
    182   XlaBuilder builder(TestName());
    183   auto original = ConstantR4FromArray4D(&builder, values);
    184   auto reshape = Reshape(original, {2 * 3 * 5, 7});
    185   Slice(reshape, {0, 0}, {4, 7}, {1, 1});
    186   ComputeAndCompare(&builder, {});
    187 }
    188 
    189 XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
    190   Array4D<float> values(2, 4, 6, 8);
    191   values.FillRandom(3.14f);
    192   auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}},
    193                                          /*strides=*/{{1, 1, 2, 1}});
    194   auto expected_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    195       *expected, LayoutUtil::MakeLayout({0, 1, 2, 3}));
    196   XlaBuilder builder(TestName());
    197   auto original = ConstantR4FromArray4D(&builder, values);
    198   Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
    199   ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001),
    200                            &expected_literal.shape());
    201 }
    202 
// One rank-1 slice test case, consumed by SliceR1Test::Run.
struct R1Spec {
  int64 input_dim0;    // Number of elements in the R1 input.
  int64 slice_start;   // First index included in the slice.
  int64 slice_limit;   // Exclusive upper bound of the sliced range.
  int64 slice_stride;  // Step between consecutive selected indices.
};
    209 
    210 // Parameterized test that generates R1 values, slices them according
    211 // to the R1Spec, and compares the result with a computed version.
    212 class SliceR1Test : public ClientLibraryTestBase,
    213                     public ::testing::WithParamInterface<R1Spec> {
    214  protected:
    215   template <typename NativeT>
    216   void Run(const R1Spec& spec) {
    217     // This can't be an std::vector, since you can't grab a Span of a
    218     // vector<bool>.
    219     absl::InlinedVector<NativeT, 1> input(spec.input_dim0);
    220     std::iota(input.begin(), input.end(), NativeT());
    221     auto literal = LiteralUtil::CreateR1<NativeT>(input);
    222 
    223     XlaBuilder builder(TestName());
    224     auto original = Parameter(&builder, 0, literal.shape(), "p0");
    225     Slice(original, {spec.slice_start}, {spec.slice_limit},
    226           {spec.slice_stride});
    227 
    228     // Ditto.
    229     absl::InlinedVector<NativeT, 1> expected;
    230     for (int i = spec.slice_start; i < spec.slice_limit;
    231          i += spec.slice_stride) {
    232       expected.push_back(i);
    233     }
    234 
    235     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
    236                             client_->TransferToServer(literal));
    237     ComputeAndCompareR1<NativeT>(&builder, expected, {arg.get()});
    238   }
    239 };
    240 
// A version of SliceR1Test used to label and disable 'large' tests. It gets
// its own instantiation (SliceR1TestBigSlicesInstantiation), which is only
// compiled when XLA_TEST_BACKEND_GPU is not defined.
class SliceR1LargeTest : public SliceR1Test {};
    243 
    244 string SliceR1TestDataToString(const ::testing::TestParamInfo<R1Spec>& data) {
    245   const R1Spec& spec = data.param;
    246   return absl::StrFormat("%d_%d_%d_%d", spec.input_dim0, spec.slice_start,
    247                          spec.slice_limit, spec.slice_stride);
    248 }
    249 
// One test per element type for the regular R1 instantiations.
XLA_TEST_P(SliceR1Test, DoIt_F32) { Run<float>(GetParam()); }

XLA_TEST_P(SliceR1Test, DoIt_F64) { Run<double>(GetParam()); }

XLA_TEST_P(SliceR1Test, DoIt_U32) { Run<uint32>(GetParam()); }

XLA_TEST_P(SliceR1Test, DoIt_S32) { Run<int32>(GetParam()); }

XLA_TEST_P(SliceR1Test, DoIt_U64) { Run<uint64>(GetParam()); }

XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); }

// The same numeric element types for the large-slice fixture.
XLA_TEST_P(SliceR1LargeTest, DoIt_F32) { Run<float>(GetParam()); }

XLA_TEST_P(SliceR1LargeTest, DoIt_F64) { Run<double>(GetParam()); }

XLA_TEST_P(SliceR1LargeTest, DoIt_U32) { Run<uint32>(GetParam()); }

XLA_TEST_P(SliceR1LargeTest, DoIt_S32) { Run<int32>(GetParam()); }

XLA_TEST_P(SliceR1LargeTest, DoIt_U64) { Run<uint64>(GetParam()); }

XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run<int64>(GetParam()); }

// PRED (bool) is registered only for SliceR1Test, not SliceR1LargeTest.
XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run<bool>(GetParam()); }
    275 
// Tests for R1 slice ops.
// The format for each testcase is {input size, start, limit, stride}.
// Every case here uses stride 1; strided cases live in
// SliceStridedR1TestInstantiation.
// clang-format off
INSTANTIATE_TEST_CASE_P(
    SliceR1TestInstantiation,
    SliceR1Test,
    ::testing::Values(
        // Size-10 inputs, including empty slices where start == limit.
        R1Spec{10, 0, 0, 1},
        R1Spec{10, 7, 7, 1},
        R1Spec{10, 0, 5, 1},
        R1Spec{10, 3, 5, 1},
        R1Spec{10, 0, 10, 1},
        // Inputs of about 1K elements, with bounds on and around 1024.
        R1Spec{1024, 0, 5, 1},
        R1Spec{1024, 3, 5, 1},
        R1Spec{1024 + 17, 0, 5, 1},
        R1Spec{1024 + 17, 3, 5, 1},
        R1Spec{1024 + 17, 1024, 1024 + 6, 1},
        R1Spec{1024 + 17, 1024 + 1, 1024 + 6, 1},
        R1Spec{1024, 1024 - 4, 1024, 1},
        // 4K-element inputs.
        R1Spec{4 * 1024, 7, 7 + 1024, 1},
        R1Spec{4 * 1024, 0, 4 * 1024, 1},
        R1Spec{4 * 1024, 1, 4 * 1024 - 1, 1},
        R1Spec{4 * 1024, 1024, 3 * 1024, 1},
        R1Spec{4 * 1024, 1024 + 1, 3 * 1024 - 1, 1},
        // Inputs of about 16K elements.
        R1Spec{16 * 1024, 0, 5, 1},
        R1Spec{16 * 1024, 3, 5, 1},
        R1Spec{16 * 1024 + 17, 0, 5, 1},
        R1Spec{16 * 1024 + 17, 3, 5, 1},
        R1Spec{16 * 1024 + 17, 16 * 1024, 16 * 1024 + 6, 1},
        R1Spec{16 * 1024 + 17, 16 * 1024 + 1, 16 * 1024 + 6, 1},
        R1Spec{16 * 1024, 4 * 1024 - 17, 8 * 1024 - 18, 1},
        // 64K-element inputs.
        R1Spec{64 * 1024, 0, 64 * 1024, 1},
        R1Spec{64 * 1024, 1, 64 * 1024 - 1, 1},
        R1Spec{64 * 1024, 1024, 63 * 1024, 1},
        R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1},
        R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1},
        R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1},
        R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}
    ),
    SliceR1TestDataToString
);
    317 
// TODO(b/69425338): This uses too much memory on GPU.
// Each case slices roughly 8M elements (limit - start = 8M, 8M - 2, 8M + 2)
// out of a 16M-element input.
#ifndef XLA_TEST_BACKEND_GPU
INSTANTIATE_TEST_CASE_P(
    SliceR1TestBigSlicesInstantiation,
    SliceR1LargeTest,
    ::testing::Values(
          R1Spec{
              16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1},
          R1Spec{
              16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1},
          R1Spec{
              16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}
    ),
    SliceR1TestDataToString
);
#endif
    334 
// Strided R1 slices, same {input size, start, limit, stride} format as above.
INSTANTIATE_TEST_CASE_P(
    SliceStridedR1TestInstantiation,
    SliceR1Test,
    ::testing::Values(
        // Small inputs with strides from 2 up to the whole input.
        R1Spec{10, 2, 4, 2},
        R1Spec{10, 0, 10, 2},
        R1Spec{10, 0, 10, 3},
        R1Spec{10, 0, 10, 4},
        R1Spec{10, 0, 10, 5},
        R1Spec{10, 0, 10, 10},
        // Medium inputs with odd starts, limits and strides.
        R1Spec{500, 200, 400, 7},
        R1Spec{4096, 1, 4095, 3},
        R1Spec{2047, 1024 - 24, 1024 + 160, 31},
        R1Spec{2047, 1, 2046, 3 * 128},
        R1Spec{4096, 1024 + 3, 4095, 500},
        R1Spec{8192, 0, 8192, 1024 * 3 + 400},
        // 1M-element input, full range.
        R1Spec{1024 * 1024, 0, 1024 * 1024, 2},
        R1Spec{1024 * 1024, 0, 1024 * 1024, 8},
        R1Spec{1024 * 1024, 0, 1024 * 1024, 7},
        R1Spec{1024 * 1024, 0, 1024 * 1024, 125},
        // 1M-element input, short range [3, 1015).
        R1Spec{1024 * 1024, 3, 1024 - 9, 2},
        R1Spec{1024 * 1024, 3, 1024 - 9, 8},
        R1Spec{1024 * 1024, 3, 1024 - 9, 7},
        R1Spec{1024 * 1024, 3, 1024 - 9, 125},
        // 1M-element input, about half the range.
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 2},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 8},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 7},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 125},
        // Same, with an input size that is not a power of two.
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 2},
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 8},
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 7},
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 125},
        // 16M-element input with strides just above and below 4096.
        R1Spec{16 * 1024 * 1024, 0, 16 * 1024 * 1024, 4097},
        R1Spec{16 * 1024 * 1024, 0, 16 * 1024 * 1024, 4093},
        R1Spec{16 * 1024 * 1024, 12 * 1024 + 17, 16 * 1024 * 1024 - 231, 4097},
        R1Spec{16 * 1024 * 1024, 12 * 1024 + 17, 16 * 1024 * 1024 - 231, 4093}
    ),
    SliceR1TestDataToString
);
// clang-format on
    375 
// One rank-2 slice test case, consumed by SliceR2Test.DoIt.
struct R2Spec {
  int64 input_dim0;                    // Size of input dimension 0.
  int64 input_dim1;                    // Size of input dimension 1.
  std::array<int64, 2> slice_starts;   // Per-dimension start indices.
  std::array<int64, 2> slice_limits;   // Per-dimension limit indices.
  std::array<int64, 2> slice_strides;  // Per-dimension strides.
  std::array<int64, 2> layout;  // Input layout, passed to LayoutUtil::MakeLayout.
};
    384 
// Parameterized test that generates patterned R2 values (via FillUnique),
// transfers them in the R2Spec's layout, slices them according to the R2Spec,
// and compares the results with ReferenceUtil::Slice2D.
class SliceR2Test : public ClientLibraryTestBase,
                    public ::testing::WithParamInterface<R2Spec> {};
    389 
    390 XLA_TEST_P(SliceR2Test, DoIt) {
    391   const R2Spec& spec = GetParam();
    392   Array2D<int32> input(spec.input_dim0, spec.input_dim1);
    393   input.FillUnique();
    394   auto literal = LiteralUtil::CreateR2FromArray2DWithLayout(
    395       input, LayoutUtil::MakeLayout(spec.layout));
    396 
    397   XlaBuilder builder(TestName());
    398   auto a = Parameter(&builder, 0, literal.shape(), "p0");
    399   Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
    400 
    401   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
    402                           client_->TransferToServer(literal));
    403   std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
    404       input, spec.slice_starts, spec.slice_limits, spec.slice_strides);
    405   ComputeAndCompareR2<int32>(&builder, *expected, {arg.get()});
    406 }
    407 
// Each entry is {dim0, dim1, starts, limits, strides, layout}. Most shapes are
// listed twice, once with layout {0, 1} and once with {1, 0}. The trailing
// "//" markers keep clang-format from reflowing the table.
INSTANTIATE_TEST_CASE_P(
    SliceR2TestInstantiation, SliceR2Test,
    ::testing::Values(
        R2Spec{4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}}, {{0, 1}}},              //
        R2Spec{4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}}, {{1, 0}}},              //
        R2Spec{16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}}, {{0, 1}}},             //
        R2Spec{16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}}, {{1, 0}}},             //
        R2Spec{256, 400, {{0, 300}}, {{256, 400}}, {{1, 1}}, {{1, 0}}},     //
        R2Spec{500, 400, {{111, 123}}, {{300, 257}}, {{1, 1}}, {{1, 0}}},   //
        R2Spec{500, 400, {{111, 123}}, {{300, 400}}, {{1, 1}}, {{1, 0}}},   //
        R2Spec{384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}}, {{1, 0}}},   //
        R2Spec{357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}}, {{1, 0}}},   //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}}, {{0, 1}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}}, {{1, 0}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}}, {{0, 1}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}}, {{1, 0}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}}, {{0, 1}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}}, {{1, 0}}},           //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{3, 5}}, {{1, 0}}},   //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{3, 5}}, {{0, 1}}},   //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{5, 3}}, {{1, 0}}},   //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{5, 3}}, {{0, 1}}},   //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{1, 0}}},  //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}},  //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}},  //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}},  //
        R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}},   //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}},  //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{0, 1}}},  //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{129, 101}}, {{1, 0}}},  //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{129, 101}}, {{0, 1}}},  //
        R2Spec{
            511, 1023, {{129, 257}}, {{500, 1000}}, {{129, 255}}, {{1, 0}}},  //
        R2Spec{
            511, 1023, {{129, 257}}, {{500, 1000}}, {{129, 255}}, {{0, 1}}},  //
        R2Spec{511,
               513,
               {{129, 255}},
               {{511 - 129, 513 - 140}},
               {{13, 19}},
               {{1, 0}}},  //
        R2Spec{511,
               513,
               {{129, 255}},
               {{511 - 129, 513 - 140}},
               {{13, 19}},
               {{0, 1}}}  //
        ));
    460 
// One rank-4 slice test case, consumed by SliceR4Test::Run.
struct R4Spec {
  std::array<int64, 4> input_dims;     // Sizes of the four input dimensions.
  std::array<int64, 4> input_layout;  // minor-to-major
  std::array<int64, 4> slice_starts;   // Per-dimension start indices.
  std::array<int64, 4> slice_limits;   // Per-dimension limit indices.
  std::array<int64, 4> slice_strides;  // Per-dimension strides.
};
    468 
    469 string R4SpecToString(const ::testing::TestParamInfo<R4Spec>& data) {
    470   const R4Spec& spec = data.param;
    471   return absl::StrCat("input_", absl::StrJoin(spec.input_dims, "x"),
    472                       "__layout_", absl::StrJoin(spec.input_layout, ""),
    473                       "__starts_", absl::StrJoin(spec.slice_starts, "x"),
    474                       "__limits_", absl::StrJoin(spec.slice_limits, "x"),
    475                       "__strides_", absl::StrJoin(spec.slice_strides, "x"));
    476 }
    477 
    478 class SliceR4Test : public ClientLibraryTestBase,
    479                     public ::testing::WithParamInterface<R4Spec> {
    480  protected:
    481   void Run(const R4Spec& spec) {
    482     Array4D<float> values(spec.input_dims[0], spec.input_dims[1],
    483                           spec.input_dims[2], spec.input_dims[3]);
    484     values.FillIota(3.14159);
    485     auto expected = ReferenceUtil::Slice4D(
    486         values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
    487     XlaBuilder builder(TestName());
    488     auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    489         values, LayoutUtil::MakeLayout(spec.input_layout));
    490     auto parameter = Parameter(&builder, 0, literal.shape(), "p0");
    491     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
    492                             client_->TransferToServer(literal));
    493     Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides);
    494     ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
    495   }
    496 };
    497 
    498 XLA_TEST_P(SliceR4Test, DoIt) { Run(GetParam()); }
    499 
// Cases for SliceR4Test. Each entry is
// {input_dims, input_layout (minor-to-major), slice_starts, slice_limits,
//  slice_strides}. Every entry uses layout {3, 2, 1, 0} except the fourth,
// which uses {0, 1, 2, 3}.
const R4Spec kR4SpecValues[] = {
    // Empty result: the limits equal the starts.
    R4Spec{{{2, 2, 2, 2}},
           {{3, 2, 1, 0}},
           {{0, 0, 0, 0}},
           {{0, 0, 0, 0}},
           {{1, 1, 1, 1}}},  //
    R4Spec{{{3, 3, 4, 4}},
           {{3, 2, 1, 0}},
           {{0, 0, 0, 0}},
           {{3, 3, 4, 4}},
           {{1, 1, 2, 1}}},  //
    R4Spec{{{2, 3, 16, 4}},
           {{3, 2, 1, 0}},
           {{0, 0, 0, 0}},
           {{2, 3, 16, 4}},
           {{1, 1, 3, 1}}},  //
    R4Spec{{{4, 16, 3, 2}},
           {{0, 1, 2, 3}},
           {{1, 4, 1, 0}},
           {{3, 12, 3, 2}},
           {{1, 1, 3, 2}}},  //
    R4Spec{{{2, 2, 257, 129}},
           {{3, 2, 1, 0}},
           {{1, 1, 62, 64}},
           {{2, 2, 195, 129}},
           {{1, 1, 3, 1}}},  //
    R4Spec{{{3, 5, 257, 129}},
           {{3, 2, 1, 0}},
           {{1, 2, 61, 64}},
           {{3, 5, 199, 129}},
           {{1, 1, 3, 1}}},  //
    R4Spec{{{5, 8, 257, 129}},
           {{3, 2, 1, 0}},
           {{2, 3, 60, 64}},
           {{3, 5, 200, 68}},
           {{1, 1, 1, 1}}},  //
    R4Spec{{{8, 10, 256, 130}},
           {{3, 2, 1, 0}},
           {{1, 2, 60, 127}},
           {{7, 9, 166, 129}},
           {{4, 2, 3, 1}}},  //
    R4Spec{{{2, 4, 8, 4}},
           {{3, 2, 1, 0}},
           {{1, 2, 0, 1}},
           {{2, 4, 8, 3}},
           {{1, 1, 7, 1}}},  //
    R4Spec{{{10, 21, 256, 150}},
           {{3, 2, 1, 0}},
           {{1, 2, 9, 127}},
           {{9, 16, 82, 133}},
           {{3, 5, 7, 2}}},  //
    R4Spec{{{15, 25, 256, 150}},
           {{3, 2, 1, 0}},
           {{4, 6, 19, 126}},
           {{15, 25, 89, 135}},
           {{5, 7, 7, 3}}},  //
    R4Spec{{{2, 4, 256, 150}},
           {{3, 2, 1, 0}},
           {{1, 2, 29, 125}},
           {{2, 4, 159, 145}},
           {{1, 1, 7, 7}}},  //
    R4Spec{{{2, 4, 256, 150}},
           {{3, 2, 1, 0}},
           {{1, 2, 39, 119}},
           {{2, 4, 158, 145}},
           {{1, 1, 7, 11}}},  //
    R4Spec{{{1, 1, 5, 512}},
           {{3, 2, 1, 0}},
           {{0, 0, 0, 0}},
           {{1, 1, 5, 512}},
           {{1, 1, 4, 1}}},  //
    R4Spec{{{1, 1, 513, 513}},
           {{3, 2, 1, 0}},
           {{0, 0, 0, 0}},
           {{1, 1, 513, 513}},
           {{1, 1, 512, 512}}},  //
    R4Spec{{{1, 1, 1024, 4}},
           {{3, 2, 1, 0}},
           {{0, 0, 15, 0}},
           {{1, 1, 1022, 4}},
           {{1, 1, 23, 1}}},  //
    R4Spec{{{1, 1, 1024, 4}},
           {{3, 2, 1, 0}},
           {{0, 0, 14, 0}},
           {{1, 1, 1023, 4}},
           {{1, 1, 101, 1}}},  //
    R4Spec{{{1, 1, 4, 1024}},
           {{3, 2, 1, 0}},
           {{0, 0, 1, 20}},
           {{1, 1, 4, 1023}},
           {{1, 1, 1, 129}}},  //
    R4Spec{{{5, 5, 512, 1024}},
           {{3, 2, 1, 0}},
           {{1, 1, 0, 0}},
           {{4, 4, 512, 1024}},
           {{2, 2, 2, 1}}},  //
    R4Spec{{{5, 5, 512, 1024}},
           {{3, 2, 1, 0}},
           {{1, 1, 0, 0}},
           {{4, 4, 512, 1024}},
           {{2, 1, 1, 400}}},  //
    R4Spec{{{32, 64, 128, 256}},
           {{3, 2, 1, 0}},
           {{10, 20, 30, 40}},
           {{30, 60, 100, 200}},
           {{11, 21, 31, 41}}},  //
    R4Spec{{{1, 1, 14, 2048}},
           {{3, 2, 1, 0}},
           {{0, 0, 2, 0}},
           {{1, 1, 14, 2}},
           {{1, 1, 1, 1}}},  //
};
    612 
// Runs SliceR4Test.DoIt once per entry of kR4SpecValues, naming each case
// with R4SpecToString.
INSTANTIATE_TEST_CASE_P(SliceR4TestInstantiation, SliceR4Test,
                        ::testing::ValuesIn(kR4SpecValues), R4SpecToString);
    615 
    616 }  // namespace
    617 }  // namespace xla
    618