Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Tests the reduce-window XLA operation.
     17 
     18 #include <limits>
     19 #include <memory>
     20 
     21 #include "tensorflow/compiler/xla/array2d.h"
     22 #include "tensorflow/compiler/xla/array3d.h"
     23 #include "tensorflow/compiler/xla/array4d.h"
     24 #include "tensorflow/compiler/xla/client/computation_builder.h"
     25 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
     26 #include "tensorflow/compiler/xla/client/local_client.h"
     27 #include "tensorflow/compiler/xla/client/padding.h"
     28 #include "tensorflow/compiler/xla/reference_util.h"
     29 #include "tensorflow/compiler/xla/shape_util.h"
     30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     32 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     33 #include "tensorflow/compiler/xla/tests/test_macros.h"
     34 #include "tensorflow/compiler/xla/xla_data.pb.h"
     35 #include "tensorflow/core/lib/core/status.h"
     36 #include "tensorflow/core/lib/core/status_test_util.h"
     37 #include "tensorflow/core/lib/gtl/array_slice.h"
     38 #include "tensorflow/core/platform/test.h"
     39 #include "tensorflow/core/platform/types.h"
     40 
     41 namespace xla {
     42 namespace {
     43 
     44 #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
     45 // Tests both F32 and BF16.
     46 static std::array<bool, 2> use_bfloat16_params{false, true};
     47 #else
     48 // Only tests F32.
     49 static std::array<bool, 1> use_bfloat16_params{false};
     50 #endif
     51 
     52 class ReduceWindowTestBase : public ClientLibraryTestBase {
     53  public:
     54   ErrorSpec DefaultErrorSpec() const {
     55     if (use_bfloat16()) {
     56       return ErrorSpec(1e-1, 5e-2);
     57     } else {
     58       return ErrorSpec(1e-3, 1e-3);
     59     }
     60   }
     61 };
     62 
     63 class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
     64                          public ReduceWindowTestBase {
     65  public:
     66   ReduceWindowTest() : builder_(client_, TestName()) {
     67     set_use_bfloat16(GetParam());
     68   }
     69 
     70   void ReduceWindowAdd(const ComputationDataHandle& input,
     71                        tensorflow::gtl::ArraySlice<int64> window_dimensions,
     72                        tensorflow::gtl::ArraySlice<int64> window_strides,
     73                        Padding padding) {
     74     auto init =
     75         CreateConstantFromLiteral(*Literal::CreateR0<float>(0.0f), &builder_);
     76     builder_.ReduceWindow(input, init,
     77                           CreateScalarAddComputation(FloatType(), &builder_),
     78                           window_dimensions, window_strides, padding);
     79   }
     80 
     81   void ReduceWindowMax(const ComputationDataHandle& input,
     82                        tensorflow::gtl::ArraySlice<int64> window_dimensions,
     83                        tensorflow::gtl::ArraySlice<int64> window_strides,
     84                        Padding padding) {
     85     auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_);
     86     builder_.ReduceWindow(input, init, CreateScalarMax(), window_dimensions,
     87                           window_strides, padding);
     88   }
     89 
     90   void ReduceWindowMin(const ComputationDataHandle& input,
     91                        tensorflow::gtl::ArraySlice<int64> window_dimensions,
     92                        tensorflow::gtl::ArraySlice<int64> window_strides,
     93                        Padding padding) {
     94     auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_);
     95     builder_.ReduceWindow(input, init,
     96                           CreateScalarMinComputation(FloatType(), &builder_),
     97                           window_dimensions, window_strides, padding);
     98   }
     99 
    100   ComputationBuilder builder_;
    101 };
    102 
    103 TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
    104   const auto input = CreateConstantFromLiteral(
    105       *Literal::CreateR1<float>({1, 1, 1, 1}), &builder_);
    106   const auto init_value =
    107       CreateConstantFromLiteral(*Literal::CreateR0<float>(0), &builder_);
    108   TF_ASSERT_OK(builder_.first_error());
    109   builder_.ReduceWindow(input, init_value,
    110                         CreateScalarAddComputation(FloatType(), &builder_),
    111                         /*window_dimensions=*/{1, 2},
    112                         /*window_strides=*/{1}, Padding::kValid);
    113   ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT)
    114       << builder_.first_error();
    115   ASSERT_THAT(builder_.first_error().error_message(),
    116               ::testing::HasSubstr("Want input dimensions size"));
    117 }
    118 
    119 // Regression test for b/68964348.
    120 TEST_P(ReduceWindowTest, R0ReduceWindow) {
    121   const auto input =
    122       CreateConstantFromLiteral(*Literal::CreateR0<float>(42.0), &builder_);
    123   const auto init =
    124       CreateConstantFromLiteral(*Literal::CreateR0<float>(1.0), &builder_);
    125   builder_.ReduceWindow(input, init,
    126                         CreateScalarAddComputation(FloatType(), &builder_),
    127                         /*window_dimensions=*/{},
    128                         /*window_strides=*/{}, Padding::kSame);
    129   ComputeAndCompareLiteral(&builder_, *Literal::CreateR0<float>(43.0), {},
    130                            ErrorSpec(0.00001));
    131 }
    132 
    133 TEST_P(ReduceWindowTest, Min3In5Stride2) {
    134   const auto input = CreateConstantFromLiteral(
    135       *Literal::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
    136   ReduceWindowMin(input, {3}, {2}, Padding::kValid);
    137   ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({100, 1}), {},
    138                            ErrorSpec(0.00001));
    139 }
    140 
    141 TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
    142   const auto input = CreateConstantFromLiteral(
    143       *Literal::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
    144   ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
    145                   Padding::kSame);
    146   ComputeAndCompareLiteral(&builder_,
    147                            *Literal::CreateR1<float>({1000, 100, 10, 1, 1}), {},
    148                            ErrorSpec(0.00001));
    149 }
    150 
    151 XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
    152   Array4D<float> input_array(1, 0, 2, 1);
    153   const auto input = CreateConstantFromArray(input_array, &builder_);
    154   Padding padding = Padding::kSame;
    155   ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
    156 
    157   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
    158                                               {1, 1, 1, 1}, padding);
    159 
    160   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
    161                            DefaultErrorSpec());
    162 }
    163 
    164 TEST_P(ReduceWindowTest, NonSquareSmall) {
    165   Array4D<float> input_array(1, 2, 2, 1);
    166   input_array.FillRandom(2.f, 2.f);
    167   const auto input = CreateConstantFromArray(input_array, &builder_);
    168 
    169   Padding padding = Padding::kSame;
    170   ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
    171 
    172   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
    173                                               {1, 1, 1, 1}, padding);
    174 
    175   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
    176                            DefaultErrorSpec());
    177 }
    178 
    179 TEST_P(ReduceWindowTest, MiddleDimsSmall) {
    180   Array4D<float> input_array(1, 3, 3, 1);
    181   input_array.FillRandom(2.f, 2.f);
    182   const auto input = CreateConstantFromArray(input_array, &builder_);
    183   Padding padding = Padding::kSame;
    184   ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding);
    185 
    186   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
    187                                               {1, 2, 2, 1}, padding);
    188 
    189   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
    190                            DefaultErrorSpec());
    191 }
    192 
    193 TEST_P(ReduceWindowTest, Along2ndMinorDim) {
    194   Array4D<float> input_array(3, 6, 7, 32);
    195   input_array.FillRandom(2.f, 2.f);
    196   const auto input = CreateConstantFromArray(input_array, &builder_);
    197 
    198   // The parameters of this reduction mimic feature norm (e.g. LRN).
    199   int lrn_diameter = 7;  // diameter = 2*radius + 1 --> must be odd
    200   Padding padding = Padding::kSame;
    201   ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
    202 
    203   auto res = ReferenceUtil::ReduceWindow4DAdd(
    204       input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
    205 
    206   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
    207                            DefaultErrorSpec());
    208 }
    209 
    210 TEST_P(ReduceWindowTest, AmongMajor2Dims) {
    211   Array4D<float> input_array(4, 4, 6, 8);
    212   input_array.FillWithMinorDimNum();
    213   const auto input_data_handle =
    214       CreateConstantFromArray(input_array, &builder_);
    215 
    216   int win_len = 3;
    217   int win_stride = 1;
    218 
    219   Padding padding = Padding::kSame;
    220   // Reduce only along the x and y dimensions, according to the win_len.
    221   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    222                   {win_stride, win_stride, 1, 1}, padding);
    223 
    224   auto result = ReferenceUtil::ReduceWindow4DAdd(
    225       input_array, 0.0f, {win_len, win_len, 1, 1},
    226       {win_stride, win_stride, 1, 1}, padding);
    227 
    228   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
    229                            DefaultErrorSpec());
    230 }
    231 
    232 TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
    233   Array4D<float> input_array(9, 12, 4, 89);
    234   input_array.FillRandom(2.f, 2.f);
    235 
    236   int win_len = 3;
    237   int win_stride = 2;
    238 
    239   const auto input_data_handle =
    240       CreateConstantFromArray(input_array, &builder_);
    241 
    242   Padding padding = Padding::kSame;
    243   // Reduce only along the x and y dimensions, according to the win_len.
    244   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    245                   {win_stride, win_stride, 1, 1}, padding);
    246 
    247   auto result = ReferenceUtil::ReduceWindow4DAdd(
    248       input_array, 0.0f, {win_len, win_len, 1, 1},
    249       {win_stride, win_stride, 1, 1}, padding);
    250 
    251   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
    252                            DefaultErrorSpec());
    253 }
    254 
    255 // Tests a reduction function that is not a simple add/min/max/etc.
    256 XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
    257   Array4D<float> input_array(1, 2, 2, 1);
    258   input_array(0, 0, 0, 0) = 1;
    259   input_array(0, 0, 1, 0) = 2;
    260   input_array(0, 1, 0, 0) = 3;
    261   input_array(0, 1, 1, 0) = 4;
    262   const auto input = CreateConstantFromArray(input_array, &builder_);
    263 
    264   Padding padding = Padding::kValid;
    265   const Shape scalar = ShapeUtil::MakeShape(FloatType(), {});
    266   auto b = builder_.CreateSubBuilder("unusual");
    267   auto lhs = b->Parameter(0, scalar, "lhs");
    268   auto rhs = b->Parameter(1, scalar, "rhs");
    269   b->Min(b->Add(lhs, rhs),
    270          CreateConstantFromLiteral(*Literal::CreateR0<float>(8.0f), b.get()));
    271   Computation reduce_fn = b->BuildAndNoteError();
    272 
    273   builder_.ReduceWindow(
    274       input,
    275       CreateConstantFromLiteral(*Literal::CreateR0<float>(0.0f), &builder_),
    276       reduce_fn,
    277       /*window_dimensions=*/{1, 1, 2, 1},
    278       /*window_strides=*/{1, 1, 1, 1}, padding);
    279 
    280   const auto reduce_func = [](float arg1, float arg2) {
    281     return std::min<float>(arg1 + arg2, 8.0f);
    282   };
    283 
    284   auto expected =
    285       ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func,
    286                                            /*window=*/{1, 1, 2, 1},
    287                                            /*stride=*/{1, 1, 1, 1}, padding);
    288 
    289   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {},
    290                            DefaultErrorSpec());
    291 }
    292 
    293 TEST_P(ReduceWindowTest, R4UnitWindow) {
    294   Array4D<float> input_array(13, 12, 8, 15);
    295   input_array.FillRandom(2.f, 2.f);
    296   std::unique_ptr<Literal> input_literal =
    297       Literal::CreateR4FromArray4DWithLayout(
    298           input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
    299   ComputationDataHandle input;
    300   auto input_data = CreateParameterAndTransferLiteral(
    301       0, *input_literal, "parameter", &builder_, &input);
    302 
    303   Padding padding = Padding::kSame;
    304   ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding);
    305 
    306   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
    307                                               {1, 4, 1, 1}, padding);
    308 
    309   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
    310                            {input_data.get()}, DefaultErrorSpec());
    311 }
    312 
    313 XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
    314   std::vector<int64> input_dims(6, 8);
    315   auto shape = ShapeUtil::MakeShape(F32, input_dims);
    316 
    317   std::unique_ptr<Literal> arg_literal = Literal::CreateFromShape(shape);
    318   auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> float {
    319     return 1.0f;
    320   };
    321   TF_EXPECT_OK(arg_literal->Populate<float>(generator));
    322 
    323   const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
    324 
    325   Padding padding = Padding::kValid;
    326   ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
    327 
    328   std::vector<int64> output_layout = {1, 5, 3, 2, 0, 4};
    329   std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
    330   Shape result_shape =
    331       ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
    332   std::unique_ptr<Literal> expected = Literal::CreateFromShape(result_shape);
    333   auto out_generator =
    334       [&](tensorflow::gtl::ArraySlice<int64> indexes) -> float {
    335     return 27.0f;
    336   };
    337   TF_EXPECT_OK(expected->Populate<float>(out_generator));
    338 
    339   ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
    340 }
    341 
    342 XLA_TEST_P(ReduceWindowTest, R6Add) {
    343   std::vector<int64> input_dims(6, 8);
    344   auto shape = ShapeUtil::MakeShape(F32, input_dims);
    345 
    346   std::unique_ptr<Literal> arg_literal =
    347       Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
    348 
    349   const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
    350 
    351   Padding padding = Padding::kValid;
    352   ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
    353 
    354   std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
    355   std::unique_ptr<Literal> expected =
    356       Literal::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
    357 
    358   ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
    359 }
    360 
    361 XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
    362   Array4D<float> input_array(2, 1, 27, 119);
    363   input_array.FillRandom(2.0f);
    364   std::unique_ptr<Literal> input_literal =
    365       Literal::CreateR4FromArray4DWithLayout(
    366           input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
    367   ComputationDataHandle input;
    368   auto input_data = CreateParameterAndTransferLiteral(
    369       0, *input_literal, "parameter", &builder_, &input);
    370 
    371   int win_len = 1;
    372   int stride = 8;
    373   Padding padding = Padding::kSame;
    374   ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    375 
    376   auto res = ReferenceUtil::ReduceWindow4DAdd(
    377       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    378 
    379   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
    380                            {input_data.get()}, DefaultErrorSpec());
    381 }
    382 
    383 XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
    384   Array4D<float> input_array(3, 2, 4, 64);
    385   input_array.FillRandom(2.0f);
    386   std::unique_ptr<Literal> input_literal =
    387       Literal::CreateR4FromArray4DWithLayout(
    388           input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
    389   ComputationDataHandle input;
    390   auto input_data = CreateParameterAndTransferLiteral(
    391       0, *input_literal, "parameter", &builder_, &input);
    392 
    393   int win_len = 3;
    394   int stride = 1;
    395   Padding padding = Padding::kSame;
    396   ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    397 
    398   auto res = ReferenceUtil::ReduceWindow4DAdd(
    399       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    400 
    401   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
    402                            {input_data.get()}, DefaultErrorSpec());
    403 }
    404 
    405 XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
    406   Array4D<float> input_array(1, 3, 12, 200);
    407   input_array.FillRandom(2.0f);
    408   std::unique_ptr<Literal> input_literal =
    409       Literal::CreateR4FromArray4DWithLayout(
    410           input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
    411   ComputationDataHandle input;
    412   auto input_data = CreateParameterAndTransferLiteral(
    413       0, *input_literal, "parameter", &builder_, &input);
    414 
    415   int win_len = 8;
    416   int stride = 5;
    417   Padding padding = Padding::kSame;
    418   ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    419 
    420   auto res = ReferenceUtil::ReduceWindow4DAdd(
    421       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    422 
    423   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
    424                            {input_data.get()}, DefaultErrorSpec());
    425 }
    426 
    427 TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
    428   Array4D<float> input_array(6, 4, 10, 130);
    429   input_array.FillRandom(2.0f);
    430 
    431   int win_len = 3;
    432   int win_stride = 2;
    433 
    434   Padding padding = Padding::kSame;
    435   const auto input_data_handle =
    436       CreateConstantFromArray(input_array, &builder_);
    437   // Reduce only along the x and y dimensions, according to the win_len.
    438   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    439                   {win_stride, win_stride, 1, 1}, padding);
    440 
    441   auto result = ReferenceUtil::ReduceWindow4DAdd(
    442       input_array, 0.0f, {win_len, win_len, 1, 1},
    443       {win_stride, win_stride, 1, 1}, padding);
    444   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
    445                            DefaultErrorSpec());
    446 }
    447 
    448 XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
    449   std::vector<float> input_vector(128 * 9, 1);
    450   const auto input = CreateConstantFromLiteral(
    451       *Literal::CreateR1<float>(input_vector), &builder_);
    452   ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
    453   ComputeAndCompareLiteral(
    454       &builder_,
    455       *Literal::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
    456       DefaultErrorSpec());
    457 }
    458 
    459 XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
    460   std::vector<float> input_vector{
    461       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    462       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    463       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    464       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    465       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    466       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    467       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    468       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    469   const auto input = CreateConstantFromLiteral(
    470       *Literal::CreateR1<float>(input_vector), &builder_);
    471   ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
    472   ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({1088}), {},
    473                            DefaultErrorSpec());
    474 }
    475 
    476 XLA_TEST_P(ReduceWindowTest, Add128In128) {
    477   std::vector<float> input_vector{
    478       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    479       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    480       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    481       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    482       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    483       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    484       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    485       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    486   const auto input = CreateConstantFromLiteral(
    487       *Literal::CreateR1<float>(input_vector), &builder_);
    488   ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
    489   ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({1088}), {},
    490                            DefaultErrorSpec());
    491 }
    492 
    493 // Regression test for a bug that appeared in Inception (b/34784899).
    494 TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
    495   Array2D<float> input_array(14, 14, 1.0f);
    496   const auto input = CreateConstantFromArray(input_array, &builder_);
    497 
    498   int win_len = 3;
    499   int stride = 1;
    500   Padding padding = Padding::kSame;
    501   ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding);
    502 
    503   auto res = ReferenceUtil::ReduceWindow2DAdd(
    504       input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
    505 
    506   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray<float>(*res),
    507                            {}, DefaultErrorSpec());
    508 }
    509 
    510 TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
    511   Array2D<float> input_array(6, 4, 1.0f);
    512   ComputationDataHandle input = builder_.Broadcast(
    513       CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4});
    514 
    515   Padding padding = Padding::kSame;
    516   ReduceWindowAdd(input, {4, 2}, {3, 3}, padding);
    517 
    518   auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
    519                                               padding);
    520 
    521   ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray<float>(*res),
    522                            {}, DefaultErrorSpec());
    523 }
    524 
// Instantiates every ReduceWindowTest case once per entry of
// use_bfloat16_params; the entry becomes the fixture's bool parameter.
INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
                        ::testing::ValuesIn(use_bfloat16_params));
    527 
    528 enum Reducer { kAdd, kMax };
    529 
    530 struct R4ReduceWindowTestData {
    531   int64 base_bounds[4];
    532   int64 window_bounds[4];
    533   int64 strides[4];
    534   int64 pad_low[4];
    535   int64 pad_high[4];
    536   int64 layout[4];
    537 
    538   Reducer reducer;
    539 };
    540 
    541 string R4ReduceWindowTestDataToString(
    542     const ::testing::TestParamInfo<
    543         ::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
    544   const auto& param = ::testing::get<0>(data.param);
    545   string str = tensorflow::strings::StrCat(
    546       "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),  //
    547       "__window_bounds_",
    548       tensorflow::str_util::Join(param.window_bounds, "x"),            //
    549       "__strides_", tensorflow::str_util::Join(param.strides, "x"),    //
    550       "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"),    //
    551       "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"),  //
    552       "__layout_", tensorflow::str_util::Join(param.layout, "_"),      //
    553       (param.reducer == kAdd) ? "_add" : "_max");
    554   CHECK(param.reducer == kAdd || param.reducer == kMax);
    555 
    556   // Test names are not allowed to contain the '-' character.
    557   std::replace(str.begin(), str.end(), '-', 'n');
    558   if (::testing::get<1>(data.param)) {
    559     str = tensorflow::strings::StrCat(str, "_bfloat16");
    560   }
    561   return str;
    562 }
    563 
    564 class R4ReduceWindowTest : public ReduceWindowTestBase,
    565                            public ::testing::WithParamInterface<
    566                                ::testing::tuple<R4ReduceWindowTestData, bool>> {
    567  protected:
    568   R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
    569 
    570   void DoIt() {
    571     ComputationBuilder b(client_, TestName());
    572     const auto& param = ::testing::get<0>(GetParam());
    573 
    574     const float kInitValue = 0.0f;
    575 
    576     Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
    577                          param.base_bounds[2], param.base_bounds[3]);
    578     input.FillIota(1);
    579     std::unique_ptr<Literal> input_literal =
    580         Literal::CreateR4FromArray4DWithLayout(
    581             input, LayoutUtil::MakeLayout(param.layout));
    582     ComputationDataHandle parameter;
    583     auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
    584                                                        &b, &parameter);
    585 
    586     std::vector<std::pair<int64, int64>> padding(4);
    587     for (int i = 0; i < 4; ++i) {
    588       padding[i] = {param.pad_low[i], param.pad_high[i]};
    589     }
    590 
    591     auto init_value =
    592         CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
    593     CHECK(param.reducer == kAdd || param.reducer == kMax);
    594     auto computation = param.reducer == kAdd
    595                            ? CreateScalarAddComputation(FloatType(), &b)
    596                            : CreateScalarMaxComputation(FloatType(), &b);
    597     b.ReduceWindowWithGeneralPadding(
    598         /*operand=*/parameter,
    599         /*init_value=*/init_value,
    600         /*computation=*/computation,
    601         /*window_dimensions=*/param.window_bounds,
    602         /*window_strides=*/param.strides,
    603         /*padding=*/padding);
    604 
    605     CHECK(param.reducer == kAdd || param.reducer == kMax);
    606     auto reduce_func = param.reducer == kAdd
    607                            ? +[](float a, float b) { return a + b; }
    608                            : +[](float a, float b) { return std::max(a, b); };
    609     std::unique_ptr<Array4D<float>> expected =
    610         ReferenceUtil::ReduceWindow4DGeneric(
    611             /*operand=*/input,
    612             /*init=*/kInitValue,
    613             /*reduce_func=*/reduce_func,
    614             /*window=*/param.window_bounds,
    615             /*stride=*/param.strides,
    616             /*padding=*/padding);
    617     std::unique_ptr<Literal> expected_literal =
    618         Literal::CreateFromArray(*expected);
    619     const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
    620         input_literal->shape().element_type(),
    621         AsInt64Slice(expected_literal->shape().dimensions()), param.layout);
    622     ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()},
    623                              DefaultErrorSpec(), &expected_shape_with_layout);
    624   }
    625 };
    626 
// Single parameterized test case; all of the work happens in the fixture's
// DoIt() using the current test parameter.
TEST_P(R4ReduceWindowTest, DoIt) { DoIt(); }
    628 
// Field order of each entry: base_bounds, window_bounds, strides, pad_low,
// pad_high, layout, reducer.
    630 const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
    631     // Minimal edge case.
    632     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1},
    633                            /*window_bounds=*/{1, 1, 1, 1},
    634                            /*strides=*/{1, 1, 1, 1},
    635                            /*pad_low=*/{0, 0, 0, 0},
    636                            /*pad_high=*/{0, 0, 0, 0},
    637                            /*layout=*/{3, 2, 1, 0},
    638                            /*reducer=*/kAdd},
    639 
    640     // Arbitrary padding (not kSame or kValid).
    641     R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89},
    642                            /*window_bounds=*/{3, 3, 1, 1},
    643                            /*strides=*/{2, 2, 1, 1},
    644                            /*pad_low=*/{4, 4, 0, 0},
    645                            /*pad_high=*/{4, 4, 0, 0},
    646                            /*layout=*/{3, 2, 1, 0},
    647                            /*reducer=*/kAdd},
    648 
    649     // Zero base bound edge case.
    650     R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1},
    651                            /*window_bounds=*/{1, 1, 1, 1},
    652                            /*strides=*/{1, 1, 1, 1},
    653                            /*pad_low=*/{0, 0, 0, 0},
    654                            /*pad_high=*/{0, 0, 0, 0},
    655                            /*layout=*/{3, 2, 1, 0},
    656                            /*reducer=*/kAdd},
    657 
    658     // With non-1x1 window.
    659     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    660                            /*window_bounds=*/{2, 3, 1, 1},
    661                            /*strides=*/{1, 1, 1, 1},
    662                            /*pad_low=*/{0, 0, 0, 0},
    663                            /*pad_high=*/{0, 0, 0, 0},
    664                            /*layout=*/{3, 2, 1, 0},
    665                            /*reducer=*/kAdd},
    666 
    667     // With max instead of add.
    668     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    669                            /*window_bounds=*/{2, 3, 1, 1},
    670                            /*strides=*/{1, 1, 1, 1},
    671                            /*pad_low=*/{0, 0, 0, 0},
    672                            /*pad_high=*/{0, 0, 0, 0},
    673                            /*layout=*/{3, 2, 1, 0},
    674                            /*reducer=*/kMax},
    675 
    676     // With stride.
    677     R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140},
    678                            /*window_bounds=*/{3, 2, 1, 1},
    679                            /*strides=*/{2, 4, 1, 1},
    680                            /*pad_low=*/{0, 0, 0, 0},
    681                            /*pad_high=*/{0, 0, 0, 0},
    682                            /*layout=*/{3, 2, 1, 0},
    683                            /*reducer=*/kAdd},
    684 
    685     // With low padding.
    686     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    687                            /*window_bounds=*/{3, 2, 1, 1},
    688                            /*strides=*/{2, 2, 1, 1},
    689                            /*pad_low=*/{3, 2, 0, 0},
    690                            /*pad_high=*/{0, 0, 0, 0},
    691                            /*layout=*/{3, 2, 1, 0},
    692                            /*reducer=*/kAdd},
    693 
    694     // With high padding.
    695     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    696                            /*window_bounds=*/{3, 2, 1, 1},
    697                            /*strides=*/{2, 2, 1, 1},
    698                            /*pad_low=*/{0, 0, 0, 0},
    699                            /*pad_high=*/{2, 3, 0, 0},
    700                            /*layout=*/{3, 2, 1, 0},
    701                            /*reducer=*/kAdd},
    702 
    703     // Window touches both sides of the padding simultaneously.
    704     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
    705                            /*window_bounds=*/{3, 3, 1, 1},
    706                            /*strides=*/{1, 1, 1, 1},
    707                            /*pad_low=*/{1, 1, 0, 0},
    708                            /*pad_high=*/{1, 1, 0, 0},
    709                            /*layout=*/{3, 2, 1, 0},
    710                            /*reducer=*/kAdd},
    711 
    712     // Window is entirely in the padding for some positions.
    713     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
    714                            /*window_bounds=*/{3, 3, 1, 1},
    715                            /*strides=*/{1, 1, 1, 1},
    716                            /*pad_low=*/{4, 4, 0, 0},
    717                            /*pad_high=*/{4, 4, 0, 0},
    718                            /*layout=*/{3, 2, 1, 0},
    719                            /*reducer=*/kAdd},
    720 
    721     // Zero base bound with padding edge case.
    722     R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4},
    723                            /*window_bounds=*/{1, 1, 1, 1},
    724                            /*strides=*/{1, 1, 1, 1},
    725                            /*pad_low=*/{0, 1, 0, 0},
    726                            /*pad_high=*/{0, 0, 0, 0},
    727                            /*layout=*/{3, 2, 1, 0},
    728                            /*reducer=*/kAdd},
    729 
    730     // With stride, low padding and high padding.
    731     R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140},
    732                            /*window_bounds=*/{3, 4, 1, 1},
    733                            /*strides=*/{3, 1, 1, 1},
    734                            /*pad_low=*/{10, 1, 0, 0},
    735                            /*pad_high=*/{2, 3, 0, 0},
    736                            /*layout=*/{3, 2, 1, 0},
    737                            /*reducer=*/kAdd},
    738 
    739     // With second minor dimension == 9.
    740     R4ReduceWindowTestData{/*base_bounds=*/{2, 3, 9, 127},
    741                            /*window_bounds=*/{1, 1, 1, 1},
    742                            /*strides=*/{1, 1, 1, 1},
    743                            /*pad_low=*/{0, 0, 0, 0},
    744                            /*pad_high=*/{0, 0, 0, 0},
    745                            /*layout=*/{3, 2, 1, 0},
    746                            /*reducer=*/kAdd},
    747 
    748     // With minor dimension == 129.
    749     R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
    750                            /*window_bounds=*/{1, 1, 1, 1},
    751                            /*strides=*/{1, 1, 1, 1},
    752                            /*pad_low=*/{0, 0, 0, 0},
    753                            /*pad_high=*/{0, 0, 0, 0},
    754                            /*layout=*/{3, 2, 1, 0},
    755                            /*reducer=*/kAdd},
    756 
    757     // With minor dims reduction and non-overlapped stride.
    758     R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
    759                            /*window_bounds=*/{1, 1, 2, 2},
    760                            /*strides=*/{1, 1, 2, 2},
    761                            /*pad_low=*/{0, 0, 0, 0},
    762                            /*pad_high=*/{0, 0, 0, 0},
    763                            /*layout=*/{3, 2, 1, 0},
    764                            /*reducer=*/kAdd},
    765 
    766     // With minor dims reduction and overlapped stride.
    767     R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
    768                            /*window_bounds=*/{1, 1, 4, 4},
    769                            /*strides=*/{1, 1, 2, 2},
    770                            /*pad_low=*/{0, 0, 0, 0},
    771                            /*pad_high=*/{1, 0, 0, 0},
    772                            /*layout=*/{3, 2, 1, 0},
    773                            /*reducer=*/kAdd},
    774 };
    775 
    776 INSTANTIATE_TEST_CASE_P(
    777     R4ReduceWindowTestInstantiation, R4ReduceWindowTest,
    778     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues),
    779                        ::testing::ValuesIn(use_bfloat16_params)),
    780     R4ReduceWindowTestDataToString);
    781 
    782 class R4ReduceWindowLargeTest : public R4ReduceWindowTest {};
    783 
    784 XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); }
    785 
    786 // Test cases that are large/slow/failed.
    787 const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
    788     R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128},
    789                            /*window_bounds=*/{3, 3, 1, 5},
    790                            /*strides=*/{1, 1, 1, 5},
    791                            /*pad_low=*/{1, 1, 0, 0},
    792                            /*pad_high=*/{1, 1, 0, 0},
    793                            /*layout=*/{3, 2, 1, 0},
    794                            /*reducer=*/kMax},
    795 
    796     R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128},
    797                            /*window_bounds=*/{3, 3, 1, 1},
    798                            /*strides=*/{2, 2, 1, 1},
    799                            /*pad_low=*/{0, 0, 0, 0},
    800                            /*pad_high=*/{1, 1, 0, 0},
    801                            /*layout=*/{3, 2, 1, 0},
    802                            /*reducer=*/kAdd},
    803 
    804     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2},
    805                            /*window_bounds=*/{1, 1, 4, 1},
    806                            /*strides=*/{1, 1, 4, 1},
    807                            /*pad_low=*/{0, 0, 1, 0},
    808                            /*pad_high=*/{0, 0, 2, 0},
    809                            /*layout=*/{3, 2, 1, 0},
    810                            /*reducer=*/kMax},
    811 };
    812 
    813 INSTANTIATE_TEST_CASE_P(
    814     R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest,
    815     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues),
    816                        ::testing::ValuesIn(use_bfloat16_params)),
    817     R4ReduceWindowTestDataToString);
    818 
    819 class R4ReduceWindowAnyDimsTest : public R4ReduceWindowTest {};
    820 
    821 // TODO(b/72234705): Fix the test cases failed on CPU and GPU.
    822 XLA_TEST_P(R4ReduceWindowAnyDimsTest,
    823            DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
    824   DoIt();
    825 }
    826 
    827 const R4ReduceWindowTestData kR4ReduceWindowAnyDimsTestValues[] = {
    828     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    829                            /*window_bounds=*/{2, 3, 4, 5},
    830                            /*strides=*/{1, 1, 1, 1},
    831                            /*pad_low=*/{0, 0, 0, 0},
    832                            /*pad_high=*/{0, 0, 0, 0},
    833                            /*layout=*/{3, 2, 1, 0},
    834                            /*reducer=*/kAdd},
    835     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    836                            /*window_bounds=*/{2, 3, 1, 1},
    837                            /*strides=*/{1, 1, 1, 1},
    838                            /*pad_low=*/{0, 0, 0, 0},
    839                            /*pad_high=*/{0, 0, 0, 0},
    840                            /*layout=*/{3, 2, 1, 0},
    841                            /*reducer=*/kMax},
    842     // With 0321 layout.
    843     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    844                            /*window_bounds=*/{2, 3, 4, 5},
    845                            /*strides=*/{1, 2, 3, 4},
    846                            /*pad_low=*/{0, 0, 0, 0},
    847                            /*pad_high=*/{0, 0, 0, 0},
    848                            /*layout=*/{0, 3, 2, 1},
    849                            /*reducer=*/kAdd},
    850 
    851     // With 0123 layout.
    852     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 23},
    853                            /*window_bounds=*/{2, 3, 7, 9},
    854                            /*strides=*/{1, 2, 5, 8},
    855                            /*pad_low=*/{0, 0, 0, 0},
    856                            /*pad_high=*/{0, 0, 0, 0},
    857                            /*layout=*/{0, 1, 2, 3},
    858                            /*reducer=*/kAdd},
    859 };
    860 
    861 INSTANTIATE_TEST_CASE_P(
    862     R4ReduceWindowAnyDimsTestInstantiation, R4ReduceWindowAnyDimsTest,
    863     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowAnyDimsTestValues),
    864                        ::testing::ValuesIn(use_bfloat16_params)),
    865     R4ReduceWindowTestDataToString);
    866 
    867 struct R3ReduceWindowTestData {
    868   int64 base_bounds[3];
    869   int64 window_bounds[3];
    870   int64 strides[3];
    871   int64 layout[3];
    872   Padding padding;
    873   Reducer reducer;
    874 } kR3TestCases[] = {
    875     {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2},
    876      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    877      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    878     {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
    879      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
    880      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    881     {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
    882      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
    883      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    884     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
    885      /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0},
    886      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    887     {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1},
    888      /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0},
    889      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    890     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
    891      /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2},
    892      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    893     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
    894      /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
    895      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    896 };
    897 
    898 string R3ReduceWindowTestDataToString(
    899     const ::testing::TestParamInfo<
    900         ::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
    901   const auto& param = ::testing::get<0>(data.param);
    902   string str = tensorflow::strings::StrCat(
    903       "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),
    904       "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"),
    905       "__strides_", tensorflow::str_util::Join(param.strides, "x"),
    906       "__padding_", param.padding == Padding::kSame ? "same" : "valid",
    907       "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2],
    908       "__reducer_", param.reducer == kAdd ? "add" : "max");
    909   if (::testing::get<1>(data.param)) {
    910     str = tensorflow::strings::StrCat(str, "_bfloat16");
    911   }
    912   return str;
    913 }
    914 
    915 class R3ReduceWindowTest : public ReduceWindowTestBase,
    916                            public ::testing::WithParamInterface<
    917                                ::testing::tuple<R3ReduceWindowTestData, bool>> {
    918  protected:
    919   R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
    920 };
    921 
    922 TEST_P(R3ReduceWindowTest, Add) {
    923   ComputationBuilder b(client_, TestName());
    924   const auto& param = ::testing::get<0>(GetParam());
    925   CHECK(param.reducer == kAdd);
    926 
    927   const float kInitValue = 0.0f;
    928   Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
    929                        param.base_bounds[2], 1.0f);
    930   std::unique_ptr<Literal> input_literal =
    931       Literal::CreateR3FromArray3DWithLayout(
    932           input, LayoutUtil::MakeLayout(param.layout));
    933 
    934   ComputationDataHandle parameter;
    935   auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
    936                                                      &b, &parameter);
    937   auto init_value =
    938       CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
    939   b.ReduceWindow(/*operand=*/parameter,
    940                  /*init_value=*/init_value,
    941                  /*computation=*/CreateScalarAddComputation(FloatType(), &b),
    942                  /*window_dimensions=*/param.window_bounds,
    943                  /*window_strides=*/param.strides, /*padding=*/param.padding);
    944 
    945   auto expected = ReferenceUtil::ReduceWindow3DAdd(
    946       /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
    947       /*stride=*/param.strides, /*padding=*/param.padding);
    948 
    949   ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
    950                            {input_arg.get()}, DefaultErrorSpec());
    951 }
    952 
    953 INSTANTIATE_TEST_CASE_P(
    954     R3ReduceWindowTestInstantiation, R3ReduceWindowTest,
    955     ::testing::Combine(::testing::ValuesIn(kR3TestCases),
    956                        ::testing::ValuesIn(use_bfloat16_params)),
    957     R3ReduceWindowTestDataToString);
    958 
    959 struct R2ReduceWindowTestData {
    960   int64 base_bounds[2];
    961   int64 window_bounds[2];
    962   int64 strides[2];
    963   int64 layout[2];
    964   Padding padding;
    965   Reducer reducer;
    966 } kR2TestCases[] = {
    967     {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
    968      /*strides=*/{1, 2}, /*layout=*/{0, 1},
    969      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    970     {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4},
    971      /*strides=*/{1, 1}, /*layout=*/{0, 1},
    972      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    973     {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3},
    974      /*strides=*/{1, 1}, /*layout=*/{0, 1},
    975      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    976     {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100},
    977      /*strides=*/{2, 99}, /*layout=*/{0, 1},
    978      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    979     {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25},
    980      /*strides=*/{5, 4}, /*layout=*/{0, 1},
    981      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    982     {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2},
    983      /*strides=*/{3, 3}, /*layout=*/{0, 1},
    984      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    985     {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36},
    986      /*strides=*/{4, 5}, /*layout=*/{1, 0},
    987      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    988     {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93},
    989      /*strides=*/{1, 1}, /*layout=*/{1, 0},
    990      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    991     // Regression test for a bug that appeared in Inception (b/34784899).
    992     {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3},
    993      /*strides=*/{1, 1}, /*layout=*/{1, 0},
    994      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    995     // Regression test for a bug that appeared in Inception (b/34784899).
    996     {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2},
    997      /*strides=*/{2, 2}, /*layout=*/{1, 0},
    998      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    999     {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
   1000      /*strides=*/{1, 1}, /*layout=*/{1, 0},
   1001      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
   1002 };
   1003 
   1004 string R2ReduceWindowTestDataToString(
   1005     const ::testing::TestParamInfo<
   1006         ::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
   1007   const auto& param = ::testing::get<0>(data.param);
   1008   string str = tensorflow::strings::StrCat(
   1009       "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),  //
   1010       "__window_bounds_",
   1011       tensorflow::str_util::Join(param.window_bounds, "x"),              //
   1012       "__strides_", tensorflow::str_util::Join(param.strides, "x"),      //
   1013       "__padding_", param.padding == Padding::kSame ? "same" : "valid",  //
   1014       "__layout_", param.layout[0], "_", param.layout[1],                //
   1015       "__reducer_", param.reducer == kAdd ? "add" : "max");
   1016   if (::testing::get<1>(data.param)) {
   1017     str = tensorflow::strings::StrCat(str, "_bfloat16");
   1018   }
   1019   return str;
   1020 }
   1021 
   1022 class R2ReduceWindowTest : public ReduceWindowTestBase,
   1023                            public ::testing::WithParamInterface<
   1024                                ::testing::tuple<R2ReduceWindowTestData, bool>> {
   1025  protected:
   1026   R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
   1027 
   1028   void DoIt() {
   1029     ComputationBuilder b(client_, TestName());
   1030     const auto& param = ::testing::get<0>(GetParam());
   1031     CHECK(param.reducer == kAdd);
   1032 
   1033     const float kInitValue = 0.0f;
   1034     Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
   1035     std::unique_ptr<Literal> input_literal =
   1036         Literal::CreateR2FromArray2DWithLayout(
   1037             input, LayoutUtil::MakeLayout(param.layout));
   1038 
   1039     ComputationDataHandle parameter;
   1040     auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
   1041                                                        &b, &parameter);
   1042     auto init_value =
   1043         CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
   1044     b.ReduceWindow(/*operand=*/parameter,
   1045                    /*init_value=*/init_value,
   1046                    /*computation=*/CreateScalarAddComputation(FloatType(), &b),
   1047                    /*window_dimensions=*/param.window_bounds,
   1048                    /*window_strides=*/param.strides, /*padding=*/param.padding);
   1049 
   1050     auto expected = ReferenceUtil::ReduceWindow2DAdd(
   1051         /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
   1052         /*stride=*/param.strides, /*padding=*/param.padding);
   1053 
   1054     ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
   1055                              {input_arg.get()}, DefaultErrorSpec());
   1056   }
   1057 };
   1058 
   1059 TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
   1060 
   1061 INSTANTIATE_TEST_CASE_P(
   1062     R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
   1063     ::testing::Combine(::testing::ValuesIn(kR2TestCases),
   1064                        ::testing::ValuesIn(use_bfloat16_params)),
   1065     R2ReduceWindowTestDataToString);
   1066 
   1067 class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {};
   1068 
   1069 // TODO(b/72234705): Fix the test cases failed on CPU and GPU.
   1070 XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test,
   1071            DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
   1072   DoIt();
   1073 }
   1074 
   1075 const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = {
   1076     {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128},
   1077      /*strides=*/{1, 1}, /*layout=*/{1, 0},
   1078      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
   1079 };
   1080 
   1081 INSTANTIATE_TEST_CASE_P(
   1082     R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test,
   1083     ::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test),
   1084                        ::testing::ValuesIn(use_bfloat16_params)),
   1085     R2ReduceWindowTestDataToString);
   1086 
   1087 struct R1ReduceWindowTestData {
   1088   int64 base_bounds[1];
   1089   int64 window_bounds[1];
   1090   int64 strides[1];
   1091   int64 pad_low[1];
   1092   int64 pad_high[1];
   1093   Reducer reducer;
   1094 } kR1TestCases[] = {
   1095     {/*base_bounds=*/{1}, /*window_bounds=*/{1},
   1096      /*strides=*/{1},
   1097      /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first},
   1098      /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second},
   1099      /*reducer=*/Reducer::kAdd},
   1100 
   1101     {/*base_bounds=*/{3}, /*window_bounds=*/{3},
   1102      /*strides=*/{1},
   1103      /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first},
   1104      /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second},
   1105      /*reducer=*/Reducer::kAdd},
   1106 
   1107     {/*base_bounds=*/{3}, /*window_bounds=*/{2},
   1108      /*strides=*/{1},
   1109      /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first},
   1110      /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second},
   1111      /*reducer=*/Reducer::kAdd},
   1112 
   1113     {/*base_bounds=*/{5}, /*window_bounds=*/{1},
   1114      /*strides=*/{1},
   1115      /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first},
   1116      /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second},
   1117      /*reducer=*/Reducer::kMax},
   1118 
   1119     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
   1120      /*strides=*/{4},
   1121      /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first},
   1122      /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second},
   1123      /*reducer=*/Reducer::kMax},
   1124 
   1125     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
   1126      /*strides=*/{3},
   1127      /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first},
   1128      /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second},
   1129      /*reducer=*/Reducer::kAdd},
   1130 
   1131     {/*base_bounds=*/{128 * 2},
   1132      /*window_bounds=*/{30},
   1133      /*strides=*/{27},
   1134      /*pad_low=*/
   1135      {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first},
   1136      /*pad_high=*/
   1137      {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second},
   1138      /*reducer=*/Reducer::kAdd},
   1139 
   1140     {/*base_bounds=*/{128 * 17},
   1141      /*window_bounds=*/{7},
   1142      /*strides=*/{64},
   1143      /*pad_low=*/
   1144      {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first},
   1145      /*pad_high=*/
   1146      {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second},
   1147      /*reducer=*/Reducer::kAdd},
   1148 
   1149     {/*base_bounds=*/{128 * 2},
   1150      /*window_bounds=*/{32},
   1151      /*strides=*/{56},
   1152      /*pad_low=*/
   1153      {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first},
   1154      /*pad_high=*/
   1155      {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second},
   1156      /*reducer=*/Reducer::kAdd},
   1157 
   1158     {/*base_bounds=*/{3}, /*window_bounds=*/{2},
   1159      /*strides=*/{1},
   1160      /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first},
   1161      /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second},
   1162      /*reducer=*/Reducer::kAdd},
   1163 
   1164     {/*base_bounds=*/{5}, /*window_bounds=*/{3},
   1165      /*strides=*/{2},
   1166      /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first},
   1167      /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second},
   1168      /*reducer=*/Reducer::kAdd},
   1169 
   1170     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
   1171      /*strides=*/{3},
   1172      /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first},
   1173      /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second},
   1174      /*reducer=*/Reducer::kAdd},
   1175 
   1176     {/*base_bounds=*/{5}, /*window_bounds=*/{5},
   1177      /*strides=*/{1},
   1178      /*pad_low=*/{0},
   1179      /*pad_high=*/{5},
   1180      /*reducer=*/Reducer::kAdd},
   1181 
   1182     {/*base_bounds=*/{5}, /*window_bounds=*/{5},
   1183      /*strides=*/{1},
   1184      /*pad_low=*/{5},
   1185      /*pad_high=*/{0},
   1186      /*reducer=*/Reducer::kAdd},
   1187 };
   1188 
   1189 string R1ReduceWindowTestDataToString(
   1190     const ::testing::TestParamInfo<
   1191         ::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
   1192   const auto& param = ::testing::get<0>(data.param);
   1193   string str = tensorflow::strings::StrCat(
   1194       "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),
   1195       "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"),
   1196       "__strides_", tensorflow::str_util::Join(param.strides, "x"),
   1197       "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"),
   1198       "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"),
   1199       "__reducer_", param.reducer == kAdd ? "add" : "max");
   1200   if (::testing::get<1>(data.param)) {
   1201     str = tensorflow::strings::StrCat(str, "_bfloat16");
   1202   }
   1203   return str;
   1204 }
   1205 
   1206 class R1ReduceWindowTest : public ReduceWindowTestBase,
   1207                            public ::testing::WithParamInterface<
   1208                                ::testing::tuple<R1ReduceWindowTestData, bool>> {
   1209  protected:
   1210   R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
   1211 };
   1212 
   1213 TEST_P(R1ReduceWindowTest, DoIt) {
   1214   ComputationBuilder b(client_, TestName());
   1215   const auto& param = ::testing::get<0>(GetParam());
   1216   CHECK(param.reducer == kAdd || param.reducer == kMax);
   1217 
   1218   const float kInitValue = 0.0f;
   1219   std::vector<float> input_vector(param.base_bounds[0]);
   1220   std::iota(std::begin(input_vector), std::end(input_vector), 0);
   1221   std::unique_ptr<Literal> input_literal =
   1222       Literal::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector));
   1223   ComputationDataHandle parameter;
   1224   auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
   1225                                                      &b, &parameter);
   1226 
   1227   std::vector<std::pair<int64, int64>> padding(1);
   1228   padding[0] = {param.pad_low[0], param.pad_high[0]};
   1229 
   1230   auto computation = param.reducer == kAdd
   1231                          ? CreateScalarAddComputation(FloatType(), &b)
   1232                          : CreateScalarMaxComputation(FloatType(), &b);
   1233   auto init_value =
   1234       CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
   1235   b.ReduceWindowWithGeneralPadding(
   1236       /*operand=*/parameter,
   1237       /*init_value=*/init_value,
   1238       /*computation=*/computation,
   1239       /*window_dimensions=*/param.window_bounds,
   1240       /*window_strides=*/param.strides, /*padding=*/padding);
   1241 
   1242   auto reduce_func = param.reducer == kAdd
   1243                          ? +[](float a, float b) { return a + b; }
   1244                          : +[](float a, float b) { return std::max(a, b); };
   1245   auto expected = ReferenceUtil::ReduceWindow1DGeneric(
   1246       /*operand=*/tensorflow::gtl::ArraySlice<float>(input_vector),
   1247       /*init=*/kInitValue,
   1248       /*reduce_func=*/reduce_func,
   1249       /*window=*/param.window_bounds,
   1250       /*stride=*/param.strides,
   1251       /*padding=*/padding);
   1252 
   1253   ComputeAndCompareLiteral(&b, *Literal::CreateR1<float>(*expected),
   1254                            {input_arg.get()}, DefaultErrorSpec());
   1255 }
   1256 
   1257 INSTANTIATE_TEST_CASE_P(
   1258     R1ReduceWindowTestInstantiation, R1ReduceWindowTest,
   1259     ::testing::Combine(::testing::ValuesIn(kR1TestCases),
   1260                        ::testing::ValuesIn(use_bfloat16_params)),
   1261     R1ReduceWindowTestDataToString);
   1262 
   1263 // Test class for text-based test cases. Note that this compares with the
   1264 // results on the interpreter backend.
   1265 class ReduceWindowTextTest : public HloTestBase {};
   1266 
   1267 TEST_F(ReduceWindowTextTest, R2General256x384) {
   1268   const string& hlo_string = R"(
   1269 HloModule R2Window
   1270 mul {
   1271   lhs = f32[] parameter(0)
   1272   rhs = f32[] parameter(1)
   1273   ROOT mul = f32[] multiply(lhs, rhs)
   1274 }
   1275 ENTRY R2Window {
   1276   operand = f32[256,384]{1,0} parameter(0)
   1277   constant = f32[] constant(1)
   1278   ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
   1279 }
   1280 )";
   1281   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
   1282 }
   1283 
   1284 TEST_F(ReduceWindowTextTest, R2General256x384Layout01) {
   1285   const string& hlo_string = R"(
   1286 HloModule R2Window
   1287 mul {
   1288 lhs = f32[] parameter(0)
   1289 rhs = f32[] parameter(1)
   1290 ROOT mul = f32[] multiply(lhs, rhs)
   1291 }
   1292 ENTRY R2Window {
   1293 operand = f32[256,384]{0,1} parameter(0)
   1294 constant = f32[] constant(1)
   1295 ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
   1296 }
   1297 )";
   1298   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
   1299 }
   1300 
   1301 TEST_F(ReduceWindowTextTest, R2General2x5) {
   1302   const string& hlo_string = R"(
   1303 HloModule R2Window
   1304 mul {
   1305   lhs = f32[] parameter(0)
   1306   rhs = f32[] parameter(1)
   1307   ROOT mul = f32[] multiply(lhs, rhs)
   1308 }
   1309 ENTRY R2Window {
   1310   operand = f32[2,5]{1,0} parameter(0)
   1311   constant = f32[] constant(1)
   1312   ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul
   1313 }
   1314 )";
   1315   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
   1316 }
   1317 
   1318 }  // namespace
   1319 }  // namespace xla
   1320