Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Tests the reduce-window XLA operation.
     17 
#include <algorithm>
#include <array>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
     44 
     45 namespace xla {
     46 namespace {
     47 
     48 #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
     49 // Tests both F32 and BF16.
     50 static std::array<bool, 2> use_bfloat16_params{false, true};
     51 #else
     52 // Only tests F32.
     53 static std::array<bool, 1> use_bfloat16_params{false};
     54 #endif
     55 
     56 class ReduceWindowTestBase : public ClientLibraryTestBase {
     57  public:
     58   ErrorSpec DefaultErrorSpec() const {
     59     if (use_bfloat16()) {
     60       return ErrorSpec(2e-1, 6e-2);
     61     } else {
     62       return ErrorSpec(1e-3, 1e-3);
     63     }
     64   }
     65 };
     66 
     67 class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
     68                          public ReduceWindowTestBase {
     69  public:
     70   ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
     71 
     72   void ReduceWindowAdd(const XlaOp& input,
     73                        absl::Span<const int64> window_dimensions,
     74                        absl::Span<const int64> window_strides,
     75                        Padding padding) {
     76     auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
     77                                           &builder_);
     78     ReduceWindow(input, init,
     79                  CreateScalarAddComputation(FloatType(), &builder_),
     80                  window_dimensions, window_strides, padding);
     81   }
     82 
     83   void ReduceWindowMax(const XlaOp& input,
     84                        absl::Span<const int64> window_dimensions,
     85                        absl::Span<const int64> window_strides,
     86                        Padding padding) {
     87     auto init =
     88         CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_);
     89     ReduceWindow(input, init,
     90                  CreateScalarMaxComputation(FloatType(), &builder_),
     91                  window_dimensions, window_strides, padding);
     92   }
     93 
     94   void ReduceWindowMin(const XlaOp& input,
     95                        absl::Span<const int64> window_dimensions,
     96                        absl::Span<const int64> window_strides,
     97                        Padding padding) {
     98     auto init =
     99         CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_);
    100     ReduceWindow(input, init,
    101                  CreateScalarMinComputation(FloatType(), &builder_),
    102                  window_dimensions, window_strides, padding);
    103   }
    104 
    105   XlaBuilder builder_;
    106 };
    107 
    108 TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
    109   const auto input = CreateConstantFromLiteral(
    110       LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
    111   const auto init_value =
    112       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
    113   TF_ASSERT_OK(builder_.first_error());
    114   ReduceWindow(input, init_value,
    115                CreateScalarAddComputation(FloatType(), &builder_),
    116                /*window_dimensions=*/{1, 2},
    117                /*window_strides=*/{1}, Padding::kValid);
    118   ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT)
    119       << builder_.first_error();
    120   ASSERT_THAT(builder_.first_error().error_message(),
    121               ::testing::HasSubstr("Want input dimensions size"));
    122 }
    123 
    124 // Regression test for b/68964348.
    125 TEST_P(ReduceWindowTest, R0ReduceWindow) {
    126   const auto input =
    127       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
    128   const auto init =
    129       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(1.0), &builder_);
    130   ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
    131                /*window_dimensions=*/{},
    132                /*window_strides=*/{}, Padding::kSame);
    133   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(43.0), {},
    134                            ErrorSpec(0.00001));
    135 }
    136 
    137 TEST_P(ReduceWindowTest, Min3In5Stride2) {
    138   const auto input = CreateConstantFromLiteral(
    139       LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
    140   ReduceWindowMin(input, {3}, {2}, Padding::kValid);
    141   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({100, 1}),
    142                            {}, ErrorSpec(0.00001));
    143 }
    144 
    145 TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
    146   const auto input = CreateConstantFromLiteral(
    147       LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
    148   ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
    149                   Padding::kSame);
    150   ComputeAndCompareLiteral(&builder_,
    151                            LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
    152                            {}, ErrorSpec(0.00001));
    153 }
    154 
    155 XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
    156   Array4D<float> input_array(1, 0, 2, 1);
    157   const auto input = CreateConstantFromArray(input_array, &builder_);
    158   Padding padding = Padding::kSame;
    159   ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
    160 
    161   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
    162                                               {1, 1, 1, 1}, padding);
    163 
    164   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
    165                            DefaultErrorSpec());
    166 }
    167 
    168 TEST_P(ReduceWindowTest, NonSquareSmall) {
    169   Array4D<float> input_array(1, 2, 2, 1);
    170   input_array.FillRandom(2.f, 2.f);
    171   const auto input = CreateConstantFromArray(input_array, &builder_);
    172 
    173   Padding padding = Padding::kSame;
    174   ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
    175 
    176   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
    177                                               {1, 1, 1, 1}, padding);
    178 
    179   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
    180                            DefaultErrorSpec());
    181 }
    182 
    183 TEST_P(ReduceWindowTest, MiddleDimsSmall) {
    184   Array4D<float> input_array(1, 3, 3, 1);
    185   input_array.FillRandom(2.f, 2.f);
    186   const auto input = CreateConstantFromArray(input_array, &builder_);
    187   Padding padding = Padding::kSame;
    188   ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding);
    189 
    190   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
    191                                               {1, 2, 2, 1}, padding);
    192 
    193   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
    194                            DefaultErrorSpec());
    195 }
    196 
    197 TEST_P(ReduceWindowTest, Along2ndMinorDim) {
    198   Array4D<float> input_array(3, 6, 7, 32);
    199   input_array.FillRandom(2.f, 2.f);
    200   const auto input = CreateConstantFromArray(input_array, &builder_);
    201 
    202   // The parameters of this reduction mimic feature norm (e.g. LRN).
    203   int lrn_diameter = 7;  // diameter = 2*radius + 1 --> must be odd
    204   Padding padding = Padding::kSame;
    205   ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
    206 
    207   auto res = ReferenceUtil::ReduceWindow4DAdd(
    208       input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
    209 
    210   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
    211                            DefaultErrorSpec());
    212 }
    213 
    214 TEST_P(ReduceWindowTest, AmongMajor2Dims) {
    215   Array4D<float> input_array(4, 4, 6, 8);
    216   input_array.FillWithMinorDimNum();
    217   const auto input_data_handle =
    218       CreateConstantFromArray(input_array, &builder_);
    219 
    220   int win_len = 3;
    221   int win_stride = 1;
    222 
    223   Padding padding = Padding::kSame;
    224   // Reduce only along the x and y dimensions, according to the win_len.
    225   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    226                   {win_stride, win_stride, 1, 1}, padding);
    227 
    228   auto result = ReferenceUtil::ReduceWindow4DAdd(
    229       input_array, 0.0f, {win_len, win_len, 1, 1},
    230       {win_stride, win_stride, 1, 1}, padding);
    231 
    232   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
    233                            DefaultErrorSpec());
    234 }
    235 
    236 TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
    237   Array4D<float> input_array(9, 12, 4, 89);
    238   input_array.FillRandom(2.f, 2.f);
    239 
    240   int win_len = 3;
    241   int win_stride = 2;
    242 
    243   const auto input_data_handle =
    244       CreateConstantFromArray(input_array, &builder_);
    245 
    246   Padding padding = Padding::kSame;
    247   // Reduce only along the x and y dimensions, according to the win_len.
    248   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    249                   {win_stride, win_stride, 1, 1}, padding);
    250 
    251   auto result = ReferenceUtil::ReduceWindow4DAdd(
    252       input_array, 0.0f, {win_len, win_len, 1, 1},
    253       {win_stride, win_stride, 1, 1}, padding);
    254 
    255   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
    256                            DefaultErrorSpec());
    257 }
    258 
    259 // Tests the super windowing logic w.r.t handling prime number of windows in a
    260 // major dimension with reduction.
    261 TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
    262   Array4D<float> input_array(15, 15, 4, 128);
    263   input_array.FillRandom(2.f, 4.f);
    264 
    265   int win_len = 3;
    266   int win_stride = 2;
    267 
    268   const auto input_data_handle =
    269       CreateConstantFromArray(input_array, &builder_);
    270 
    271   Padding padding = Padding::kSame;
    272   // Reduce only along the x and y dimensions, according to the win_len.
    273   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    274                   {win_stride, win_stride, 1, 1}, padding);
    275 
    276   auto result = ReferenceUtil::ReduceWindow4DAdd(
    277       input_array, 0.0f, {win_len, win_len, 1, 1},
    278       {win_stride, win_stride, 1, 1}, padding);
    279 
    280   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
    281                            DefaultErrorSpec());
    282 }
    283 
    284 TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
    285   Array4D<float> input_array(19, 17, 8, 256);
    286   input_array.FillWithMinorDimNum();
    287 
    288   const auto input_data_handle =
    289       CreateConstantFromArray(input_array, &builder_);
    290 
    291   Padding padding = Padding::kSame;
    292   ReduceWindowAdd(input_data_handle, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
    293 
    294   auto result = ReferenceUtil::ReduceWindow4DAdd(
    295       input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
    296 
    297   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
    298                            DefaultErrorSpec());
    299 }
    300 
    301 // Tests a reduction function that is not a simple add/min/max/etc.
    302 XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
    303   Array4D<float> input_array(1, 2, 2, 1);
    304   input_array(0, 0, 0, 0) = 1;
    305   input_array(0, 0, 1, 0) = 2;
    306   input_array(0, 1, 0, 0) = 3;
    307   input_array(0, 1, 1, 0) = 4;
    308   const auto input = CreateConstantFromArray(input_array, &builder_);
    309 
    310   Padding padding = Padding::kValid;
    311   const Shape scalar = ShapeUtil::MakeShape(FloatType(), {});
    312   auto b = builder_.CreateSubBuilder("unusual");
    313   auto lhs = Parameter(b.get(), 0, scalar, "lhs");
    314   auto rhs = Parameter(b.get(), 1, scalar, "rhs");
    315   Min(Add(lhs, rhs),
    316       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
    317   XlaComputation reduce_fn = b->BuildAndNoteError();
    318 
    319   ReduceWindow(
    320       input,
    321       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
    322       reduce_fn,
    323       /*window_dimensions=*/{1, 1, 2, 1},
    324       /*window_strides=*/{1, 1, 1, 1}, padding);
    325 
    326   const auto reduce_func = [](float arg1, float arg2) {
    327     return std::min<float>(arg1 + arg2, 8.0f);
    328   };
    329 
    330   auto expected =
    331       ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func,
    332                                            /*window=*/{1, 1, 2, 1},
    333                                            /*stride=*/{1, 1, 1, 1}, padding);
    334 
    335   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
    336                            {}, DefaultErrorSpec());
    337 }
    338 
    339 TEST_P(ReduceWindowTest, R4UnitWindow) {
    340   Array4D<float> input_array(13, 12, 8, 15);
    341   input_array.FillRandom(2.f, 2.f);
    342   Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    343       input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
    344   XlaOp input;
    345   auto input_data = CreateParameterAndTransferLiteral(
    346       0, input_literal, "parameter", &builder_, &input);
    347 
    348   Padding padding = Padding::kSame;
    349   ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding);
    350 
    351   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
    352                                               {1, 4, 1, 1}, padding);
    353 
    354   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
    355                            {input_data.get()}, DefaultErrorSpec());
    356 }
    357 
    358 XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
    359   std::vector<int64> input_dims(6, 8);
    360   auto shape = ShapeUtil::MakeShape(F32, input_dims);
    361 
    362   Literal arg_literal(shape);
    363   arg_literal.PopulateWithValue(1.0f);
    364   const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
    365 
    366   Padding padding = Padding::kValid;
    367   ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
    368 
    369   std::vector<int64> output_layout = {1, 5, 3, 2, 0, 4};
    370   std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
    371   Shape result_shape =
    372       ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
    373   Literal expected(result_shape);
    374   expected.PopulateWithValue(27.0f);
    375   ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
    376 }
    377 
    378 XLA_TEST_P(ReduceWindowTest, R6Add) {
    379   std::vector<int64> input_dims(6, 8);
    380   auto shape = ShapeUtil::MakeShape(F32, input_dims);
    381 
    382   Literal arg_literal =
    383       LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
    384 
    385   const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
    386 
    387   Padding padding = Padding::kValid;
    388   ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
    389 
    390   std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
    391   Literal expected =
    392       LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
    393 
    394   ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
    395 }
    396 
    397 XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
    398   Array4D<float> input_array(2, 1, 27, 119);
    399   input_array.FillRandom(2.0f);
    400   Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    401       input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
    402   XlaOp input;
    403   auto input_data = CreateParameterAndTransferLiteral(
    404       0, input_literal, "parameter", &builder_, &input);
    405 
    406   int win_len = 1;
    407   int stride = 8;
    408   Padding padding = Padding::kSame;
    409   ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    410 
    411   auto res = ReferenceUtil::ReduceWindow4DAdd(
    412       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    413 
    414   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
    415                            {input_data.get()}, DefaultErrorSpec());
    416 }
    417 
    418 XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
    419   Array4D<float> input_array(3, 2, 4, 64);
    420   input_array.FillRandom(2.0f);
    421   Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    422       input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
    423   XlaOp input;
    424   auto input_data = CreateParameterAndTransferLiteral(
    425       0, input_literal, "parameter", &builder_, &input);
    426 
    427   int win_len = 3;
    428   int stride = 1;
    429   Padding padding = Padding::kSame;
    430   ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    431 
    432   auto res = ReferenceUtil::ReduceWindow4DAdd(
    433       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    434 
    435   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
    436                            {input_data.get()}, DefaultErrorSpec());
    437 }
    438 
    439 XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
    440   Array4D<float> input_array(1, 3, 12, 200);
    441   input_array.FillRandom(2.0f);
    442   Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    443       input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
    444   XlaOp input;
    445   auto input_data = CreateParameterAndTransferLiteral(
    446       0, input_literal, "parameter", &builder_, &input);
    447 
    448   int win_len = 8;
    449   int stride = 5;
    450   Padding padding = Padding::kSame;
    451   ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    452 
    453   auto res = ReferenceUtil::ReduceWindow4DAdd(
    454       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    455 
    456   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
    457                            {input_data.get()}, DefaultErrorSpec());
    458 }
    459 
    460 TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
    461   Array4D<float> input_array(6, 4, 10, 130);
    462   input_array.FillRandom(2.0f);
    463 
    464   int win_len = 3;
    465   int win_stride = 2;
    466 
    467   Padding padding = Padding::kSame;
    468   const auto input_data_handle =
    469       CreateConstantFromArray(input_array, &builder_);
    470   // Reduce only along the x and y dimensions, according to the win_len.
    471   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    472                   {win_stride, win_stride, 1, 1}, padding);
    473 
    474   auto result = ReferenceUtil::ReduceWindow4DAdd(
    475       input_array, 0.0f, {win_len, win_len, 1, 1},
    476       {win_stride, win_stride, 1, 1}, padding);
    477   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
    478                            DefaultErrorSpec());
    479 }
    480 
    481 XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
    482   std::vector<float> input_vector(128 * 9, 1);
    483   const auto input = CreateConstantFromLiteral(
    484       LiteralUtil::CreateR1<float>(input_vector), &builder_);
    485   ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
    486   ComputeAndCompareLiteral(
    487       &builder_,
    488       LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
    489       DefaultErrorSpec());
    490 }
    491 
    492 XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
    493   std::vector<float> input_vector{
    494       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    495       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    496       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    497       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    498       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    499       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    500       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    501       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    502   const auto input = CreateConstantFromLiteral(
    503       LiteralUtil::CreateR1<float>(input_vector), &builder_);
    504   ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
    505   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
    506                            DefaultErrorSpec());
    507 }
    508 
    509 XLA_TEST_P(ReduceWindowTest, Add128In128) {
    510   std::vector<float> input_vector{
    511       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    512       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    513       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    514       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    515       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    516       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    517       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    518       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    519   const auto input = CreateConstantFromLiteral(
    520       LiteralUtil::CreateR1<float>(input_vector), &builder_);
    521   ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
    522   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
    523                            DefaultErrorSpec());
    524 }
    525 
    526 // Regression test for a bug that appeared in Inception (b/34784899).
    527 TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
    528   Array2D<float> input_array(14, 14, 1.0f);
    529   const auto input = CreateConstantFromArray(input_array, &builder_);
    530 
    531   int win_len = 3;
    532   int stride = 1;
    533   Padding padding = Padding::kSame;
    534   ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding);
    535 
    536   auto res = ReferenceUtil::ReduceWindow2DAdd(
    537       input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
    538 
    539   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
    540                            {}, DefaultErrorSpec());
    541 }
    542 
    543 TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
    544   Array2D<float> input_array(6, 4, 1.0f);
    545   XlaOp input = Broadcast(
    546       CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4});
    547 
    548   Padding padding = Padding::kSame;
    549   ReduceWindowAdd(input, {4, 2}, {3, 3}, padding);
    550 
    551   auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
    552                                               padding);
    553 
    554   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
    555                            {}, DefaultErrorSpec());
    556 }
    557 
    558 INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
    559                         ::testing::ValuesIn(use_bfloat16_params));
    560 
    561 enum Reducer { kAdd, kMax };
    562 
// One row of the table-driven rank-4 reduce-window tests. Rows are
// aggregate-initialized positionally, so the field order is significant. Each
// array holds one entry per dimension of the rank-4 operand.
struct R4ReduceWindowTestData {
  int64 base_bounds[4];    // Operand dimension sizes.
  int64 window_bounds[4];  // Window size per dimension.
  int64 strides[4];        // Window stride per dimension.
  int64 pad_low[4];        // Low-side padding per dimension.
  int64 pad_high[4];       // High-side padding per dimension.
  int64 layout[4];         // Layout used for the operand and expected result.

  Reducer reducer;  // Reduction applied inside each window.
};
    573 
    574 string R4ReduceWindowTestDataToString(
    575     const ::testing::TestParamInfo<
    576         ::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
    577   const auto& param = ::testing::get<0>(data.param);
    578   string str = absl::StrCat(
    579       "base_bounds_", absl::StrJoin(param.base_bounds, "x"),        //
    580       "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),  //
    581       "__strides_", absl::StrJoin(param.strides, "x"),              //
    582       "__pad_low_", absl::StrJoin(param.pad_low, "x"),              //
    583       "__pad_high_", absl::StrJoin(param.pad_high, "x"),            //
    584       "__layout_", absl::StrJoin(param.layout, "_"),                //
    585       (param.reducer == kAdd) ? "_add" : "_max");
    586   CHECK(param.reducer == kAdd || param.reducer == kMax);
    587 
    588   // Test names are not allowed to contain the '-' character.
    589   std::replace(str.begin(), str.end(), '-', 'n');
    590   if (::testing::get<1>(data.param)) {
    591     absl::StrAppend(&str, "_bfloat16");
    592   }
    593   return str;
    594 }
    595 
    596 class R4ReduceWindowTest : public ReduceWindowTestBase,
    597                            public ::testing::WithParamInterface<
    598                                ::testing::tuple<R4ReduceWindowTestData, bool>> {
    599  protected:
    600   R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
    601 
    602   void DoIt() {
    603     XlaBuilder b(TestName());
    604     const auto& param = ::testing::get<0>(GetParam());
    605 
    606     const float kInitValue = 0.0f;
    607 
    608     Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
    609                          param.base_bounds[2], param.base_bounds[3]);
    610     // Choose a prime iota length so that each window sees a unique set of
    611     // values. (Technically, the requirement is that the iota length is
    612     // relatively prime to all of the dimensions involved in the reduce-window.)
    613     input.FillRepeatedIota(0, 137);
    614     // Floating point sum reduction requires higher localized precision. We need
    615     // the following normalization in order to enable testing of kAdd on large
    616     // windows.
    617     input.Each([&](absl::Span<const int64> /*indices*/, float* value) {
    618       *value = *value / 10000000000.f;
    619     });
    620     Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    621         input, LayoutUtil::MakeLayout(param.layout));
    622     XlaOp parameter;
    623     auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
    624                                                        &b, &parameter);
    625 
    626     std::vector<std::pair<int64, int64>> padding(4);
    627     for (int i = 0; i < 4; ++i) {
    628       padding[i] = {param.pad_low[i], param.pad_high[i]};
    629     }
    630 
    631     auto init_value =
    632         CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
    633     CHECK(param.reducer == kAdd || param.reducer == kMax);
    634     auto reducer = param.reducer;
    635     auto computation = reducer == kAdd
    636                            ? CreateScalarAddComputation(FloatType(), &b)
    637                            : CreateScalarMaxComputation(FloatType(), &b);
    638     ReduceWindowWithGeneralPadding(
    639         /*operand=*/parameter,
    640         /*init_value=*/init_value,
    641         /*computation=*/computation,
    642         /*window_dimensions=*/param.window_bounds,
    643         /*window_strides=*/param.strides,
    644         /*base_dilations=*/{},
    645         /*window_dilations=*/{},
    646         /*padding=*/padding);
    647 
    648     CHECK(reducer == kAdd || reducer == kMax);
    649     auto reduce_func = reducer == kAdd
    650                            ? +[](float a, float b) { return a + b; }
    651                            : +[](float a, float b) { return std::max(a, b); };
    652     std::unique_ptr<Array4D<float>> expected =
    653         ReferenceUtil::ReduceWindow4DGeneric(
    654             /*operand=*/input,
    655             /*init=*/kInitValue,
    656             /*reduce_func=*/reduce_func,
    657             /*window=*/param.window_bounds,
    658             /*stride=*/param.strides,
    659             /*padding=*/padding);
    660     Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
    661     const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
    662         input_literal.shape().element_type(),
    663         AsInt64Slice(expected_literal.shape().dimensions()), param.layout);
    664     ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
    665                              DefaultErrorSpec(), &expected_shape_with_layout);
    666   }
    667 };
    668 
    669 TEST_P(R4ReduceWindowTest, DoIt) { DoIt(); }
    670 
// Field order: base_bounds, window_bounds, strides, pad_low, pad_high, layout,
// reducer.
    672 const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
    673     // Minimal edge case.
    674     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1},
    675                            /*window_bounds=*/{1, 1, 1, 1},
    676                            /*strides=*/{1, 1, 1, 1},
    677                            /*pad_low=*/{0, 0, 0, 0},
    678                            /*pad_high=*/{0, 0, 0, 0},
    679                            /*layout=*/{3, 2, 1, 0},
    680                            /*reducer=*/kAdd},
    681 
    682     // Arbitrary padding (not kSame or kValid).
    683     R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89},
    684                            /*window_bounds=*/{3, 3, 1, 1},
    685                            /*strides=*/{2, 2, 1, 1},
    686                            /*pad_low=*/{4, 4, 0, 0},
    687                            /*pad_high=*/{4, 4, 0, 0},
    688                            /*layout=*/{3, 2, 1, 0},
    689                            /*reducer=*/kAdd},
    690 
    691     // Zero base bound edge case.
    692     R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1},
    693                            /*window_bounds=*/{1, 1, 1, 1},
    694                            /*strides=*/{1, 1, 1, 1},
    695                            /*pad_low=*/{0, 0, 0, 0},
    696                            /*pad_high=*/{0, 0, 0, 0},
    697                            /*layout=*/{3, 2, 1, 0},
    698                            /*reducer=*/kAdd},
    699 
    700     // With max instead of add.
    701     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    702                            /*window_bounds=*/{2, 3, 1, 1},
    703                            /*strides=*/{1, 1, 1, 1},
    704                            /*pad_low=*/{0, 0, 0, 0},
    705                            /*pad_high=*/{0, 0, 0, 0},
    706                            /*layout=*/{3, 2, 1, 0},
    707                            /*reducer=*/kMax},
    708 
    709     // With stride.
    710     R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140},
    711                            /*window_bounds=*/{3, 2, 1, 1},
    712                            /*strides=*/{2, 4, 1, 1},
    713                            /*pad_low=*/{0, 0, 0, 0},
    714                            /*pad_high=*/{0, 0, 0, 0},
    715                            /*layout=*/{3, 2, 1, 0},
    716                            /*reducer=*/kAdd},
    717 
    718     // With low padding.
    719     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    720                            /*window_bounds=*/{3, 2, 1, 1},
    721                            /*strides=*/{2, 2, 1, 1},
    722                            /*pad_low=*/{3, 2, 0, 0},
    723                            /*pad_high=*/{0, 0, 0, 0},
    724                            /*layout=*/{3, 2, 1, 0},
    725                            /*reducer=*/kAdd},
    726 
    727     // With high padding.
    728     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    729                            /*window_bounds=*/{3, 2, 1, 1},
    730                            /*strides=*/{2, 2, 1, 1},
    731                            /*pad_low=*/{0, 0, 0, 0},
    732                            /*pad_high=*/{2, 3, 0, 0},
    733                            /*layout=*/{3, 2, 1, 0},
    734                            /*reducer=*/kAdd},
    735 
    736     // Window touches both sides of the padding simultaneously.
    737     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
    738                            /*window_bounds=*/{3, 3, 1, 1},
    739                            /*strides=*/{1, 1, 1, 1},
    740                            /*pad_low=*/{1, 1, 0, 0},
    741                            /*pad_high=*/{1, 1, 0, 0},
    742                            /*layout=*/{3, 2, 1, 0},
    743                            /*reducer=*/kAdd},
    744 
    745     // Window is entirely in the padding for some positions.
    746     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
    747                            /*window_bounds=*/{3, 3, 1, 1},
    748                            /*strides=*/{1, 1, 1, 1},
    749                            /*pad_low=*/{4, 4, 0, 0},
    750                            /*pad_high=*/{4, 4, 0, 0},
    751                            /*layout=*/{3, 2, 1, 0},
    752                            /*reducer=*/kAdd},
    753 
    754     // Zero base bound with padding edge case.
    755     R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4},
    756                            /*window_bounds=*/{1, 1, 1, 1},
    757                            /*strides=*/{1, 1, 1, 1},
    758                            /*pad_low=*/{0, 1, 0, 0},
    759                            /*pad_high=*/{0, 0, 0, 0},
    760                            /*layout=*/{3, 2, 1, 0},
    761                            /*reducer=*/kAdd},
    762 
    763     // With stride, low padding and high padding.
    764     R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140},
    765                            /*window_bounds=*/{3, 4, 1, 1},
    766                            /*strides=*/{3, 1, 1, 1},
    767                            /*pad_low=*/{10, 1, 0, 0},
    768                            /*pad_high=*/{2, 3, 0, 0},
    769                            /*layout=*/{3, 2, 1, 0},
    770                            /*reducer=*/kAdd},
    771 
    772     // With minor dimension == 129.
    773     R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
    774                            /*window_bounds=*/{1, 1, 1, 1},
    775                            /*strides=*/{1, 1, 1, 1},
    776                            /*pad_low=*/{0, 0, 0, 0},
    777                            /*pad_high=*/{0, 0, 0, 0},
    778                            /*layout=*/{3, 2, 1, 0},
    779                            /*reducer=*/kAdd},
    780 
    781     // With minor dims reduction and non-overlapped stride.
    782     R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
    783                            /*window_bounds=*/{1, 1, 2, 2},
    784                            /*strides=*/{1, 1, 2, 2},
    785                            /*pad_low=*/{0, 0, 0, 0},
    786                            /*pad_high=*/{0, 0, 0, 0},
    787                            /*layout=*/{3, 2, 1, 0},
    788                            /*reducer=*/kAdd},
    789 
    790     // With minor dims reduction and overlapped stride.
    791     R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
    792                            /*window_bounds=*/{1, 1, 4, 4},
    793                            /*strides=*/{1, 1, 2, 2},
    794                            /*pad_low=*/{0, 0, 0, 0},
    795                            /*pad_high=*/{1, 0, 0, 0},
    796                            /*layout=*/{3, 2, 1, 0},
    797                            /*reducer=*/kAdd},
    798 
    799     R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3},
    800                            /*window_bounds=*/{1, 64, 64, 1},
    801                            /*strides=*/{1, 64, 64, 1},
    802                            /*pad_low=*/{0, 0, 0, 0},
    803                            /*pad_high=*/{0, 0, 0, 0},
    804                            /*layout=*/{3, 0, 2, 1},
    805                            /*reducer=*/kAdd},
    806 
    807     R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64},
    808                            /*window_bounds=*/{112, 112, 1, 8},
    809                            /*strides=*/{112, 112, 1, 8},
    810                            /*pad_low=*/{0, 0, 0, 0},
    811                            /*pad_high=*/{0, 0, 0, 0},
    812                            /*layout=*/{3, 2, 1, 0},
    813                            /*reducer=*/kMax},
    814 
    815     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    816                            /*window_bounds=*/{2, 3, 4, 5},
    817                            /*strides=*/{1, 1, 1, 1},
    818                            /*pad_low=*/{0, 0, 0, 0},
    819                            /*pad_high=*/{0, 0, 0, 0},
    820                            /*layout=*/{3, 2, 1, 0},
    821                            /*reducer=*/kAdd},
    822 
    823     // With 0321 layout.
    824     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    825                            /*window_bounds=*/{2, 3, 4, 5},
    826                            /*strides=*/{1, 2, 3, 4},
    827                            /*pad_low=*/{0, 0, 0, 0},
    828                            /*pad_high=*/{0, 0, 0, 0},
    829                            /*layout=*/{0, 3, 2, 1},
    830                            /*reducer=*/kAdd},
    831 
    832     // With 0123 layout.
    833     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17},
    834                            /*window_bounds=*/{2, 3, 7, 9},
    835                            /*strides=*/{1, 2, 5, 8},
    836                            /*pad_low=*/{0, 0, 0, 0},
    837                            /*pad_high=*/{0, 0, 0, 0},
    838                            /*layout=*/{0, 1, 2, 3},
    839                            /*reducer=*/kAdd},
    840 };
    841 
    842 INSTANTIATE_TEST_CASE_P(
    843     R4ReduceWindowTestInstantiation, R4ReduceWindowTest,
    844     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues),
    845                        ::testing::ValuesIn(use_bfloat16_params)),
    846     R4ReduceWindowTestDataToString);
    847 
    848 class R4ReduceWindowLargeTest : public R4ReduceWindowTest {};
    849 
    850 XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); }
    851 
    852 // Test cases that are large/slow/failed.
    853 const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
    854     R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128},
    855                            /*window_bounds=*/{3, 3, 1, 5},
    856                            /*strides=*/{1, 1, 1, 5},
    857                            /*pad_low=*/{1, 1, 0, 0},
    858                            /*pad_high=*/{1, 1, 0, 0},
    859                            /*layout=*/{3, 2, 1, 0},
    860                            /*reducer=*/kMax},
    861 
    862     R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128},
    863                            /*window_bounds=*/{3, 3, 1, 1},
    864                            /*strides=*/{2, 2, 1, 1},
    865                            /*pad_low=*/{0, 0, 0, 0},
    866                            /*pad_high=*/{1, 1, 0, 0},
    867                            /*layout=*/{3, 2, 1, 0},
    868                            /*reducer=*/kAdd},
    869 
    870     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2},
    871                            /*window_bounds=*/{1, 1, 4, 1},
    872                            /*strides=*/{1, 1, 4, 1},
    873                            /*pad_low=*/{0, 0, 1, 0},
    874                            /*pad_high=*/{0, 0, 2, 0},
    875                            /*layout=*/{3, 2, 1, 0},
    876                            /*reducer=*/kMax},
    877 
    878     // Patterns generated by cumsum/cumprod.
    879     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
    880                            /*window_bounds=*/{1021, 1, 1, 1},
    881                            /*strides=*/{1, 1, 1, 1},
    882                            /*pad_low=*/{1020, 0, 0, 0},
    883                            /*pad_high=*/{0, 0, 0, 0},
    884                            /*layout=*/{3, 2, 1, 0},
    885                            /*reducer=*/kAdd},
    886 
    887     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
    888                            /*window_bounds=*/{1, 1, 1021, 1},
    889                            /*strides=*/{1, 1, 1, 1},
    890                            /*pad_low=*/{0, 0, 1020, 0},
    891                            /*pad_high=*/{0, 0, 0, 0},
    892                            /*layout=*/{3, 2, 1, 0},
    893                            /*reducer=*/kAdd},
    894 
    895     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
    896                            /*window_bounds=*/{1, 1, 1, 1021},
    897                            /*strides=*/{1, 1, 1, 1},
    898                            /*pad_low=*/{0, 0, 0, 1020},
    899                            /*pad_high=*/{0, 0, 0, 0},
    900                            /*layout=*/{3, 2, 1, 0},
    901                            /*reducer=*/kAdd},
    902 
    903     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
    904                            /*window_bounds=*/{1021, 1, 1, 1},
    905                            /*strides=*/{1, 1, 1, 1},
    906                            /*pad_low=*/{1021, 0, 0, 0},
    907                            /*pad_high=*/{0, 0, 0, 0},
    908                            /*layout=*/{3, 2, 1, 0},
    909                            /*reducer=*/kAdd},
    910 
    911     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 1021, 16},
    912                            /*window_bounds=*/{1, 1, 1021, 1},
    913                            /*strides=*/{1, 1, 1, 1},
    914                            /*pad_low=*/{0, 0, 1021, 0},
    915                            /*pad_high=*/{0, 0, 0, 0},
    916                            /*layout=*/{3, 2, 1, 0},
    917                            /*reducer=*/kAdd},
    918 
    919     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
    920                            /*window_bounds=*/{1, 1, 1, 1021},
    921                            /*strides=*/{1, 1, 1, 1},
    922                            /*pad_low=*/{0, 0, 0, 1021},
    923                            /*pad_high=*/{0, 0, 0, 0},
    924                            /*layout=*/{3, 2, 1, 0},
    925                            /*reducer=*/kAdd},
    926 };
    927 
    928 INSTANTIATE_TEST_CASE_P(
    929     R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest,
    930     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues),
    931                        ::testing::ValuesIn(use_bfloat16_params)),
    932     R4ReduceWindowTestDataToString);
    933 
    934 struct R3ReduceWindowTestData {
    935   int64 base_bounds[3];
    936   int64 window_bounds[3];
    937   int64 strides[3];
    938   int64 layout[3];
    939   Padding padding;
    940   Reducer reducer;
    941 } kR3TestCases[] = {
    942     {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2},
    943      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    944      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    945     {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
    946      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
    947      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    948     {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
    949      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
    950      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    951     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
    952      /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0},
    953      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    954     {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1},
    955      /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0},
    956      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    957     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
    958      /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2},
    959      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    960     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
    961      /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
    962      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    963     {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251},
    964      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    965      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
    966     {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3},
    967      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    968      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    969     {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64},
    970      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    971      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
    972     {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257},
    973      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    974      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
    975     {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
    976      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    977      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    978     {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1},
    979      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    980      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    981     {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
    982      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
    983      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    984 };
    985 
    986 string R3ReduceWindowTestDataToString(
    987     const ::testing::TestParamInfo<
    988         ::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
    989   const auto& param = ::testing::get<0>(data.param);
    990   string str = absl::StrCat(
    991       "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_",
    992       absl::StrJoin(param.window_bounds, "x"), "__strides_",
    993       absl::StrJoin(param.strides, "x"), "__padding_",
    994       param.padding == Padding::kSame ? "same" : "valid", "__layout_",
    995       param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
    996       param.reducer == kAdd ? "add" : "max");
    997   if (::testing::get<1>(data.param)) {
    998     absl::StrAppend(&str, "_bfloat16");
    999   }
   1000   return str;
   1001 }
   1002 
   1003 class R3ReduceWindowTest : public ReduceWindowTestBase,
   1004                            public ::testing::WithParamInterface<
   1005                                ::testing::tuple<R3ReduceWindowTestData, bool>> {
   1006  protected:
   1007   R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
   1008 };
   1009 
   1010 TEST_P(R3ReduceWindowTest, DoIt) {
   1011   XlaBuilder b(TestName());
   1012   const auto& param = ::testing::get<0>(GetParam());
   1013 
   1014   const float kInitValue = 0.0f;
   1015   Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
   1016                        param.base_bounds[2]);
   1017   // Choose a prime iota length so that each window sees a unique set of values.
   1018   // (Technically, the requirement is that the iota length is relatively prime
   1019   // to all of the dimensions involved in the reduce-window.)
   1020   input.FillRepeatedIota(0, 137);
   1021   Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
   1022       input, LayoutUtil::MakeLayout(param.layout));
   1023   auto reducer = param.reducer;
   1024   if (use_bfloat16()) {
   1025     input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);
   1026 
   1027     // To avoid numerical issues, force the reducer to be kMax for bf16
   1028     // inputs.
   1029     reducer = kMax;
   1030   }
   1031 
   1032   XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input");
   1033   auto init_value =
   1034       CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
   1035 
   1036   auto computation = reducer == kAdd
   1037                          ? CreateScalarAddComputation(FloatType(), &b)
   1038                          : CreateScalarMaxComputation(FloatType(), &b);
   1039 
   1040   ReduceWindow(/*operand=*/parameter,
   1041                /*init_value=*/init_value,
   1042                /*computation=*/computation,
   1043                /*window_dimensions=*/param.window_bounds,
   1044                /*window_strides=*/param.strides, /*padding=*/param.padding);
   1045 
   1046   ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec());
   1047 }
   1048 
   1049 INSTANTIATE_TEST_CASE_P(
   1050     R3ReduceWindowTestInstantiation, R3ReduceWindowTest,
   1051     ::testing::Combine(::testing::ValuesIn(kR3TestCases),
   1052                        ::testing::ValuesIn(use_bfloat16_params)),
   1053     R3ReduceWindowTestDataToString);
   1054 
   1055 struct R2ReduceWindowTestData {
   1056   int64 base_bounds[2];
   1057   int64 window_bounds[2];
   1058   int64 strides[2];
   1059   int64 pad_low[2];
   1060   int64 pad_high[2];
   1061   int64 layout[2];
   1062   Reducer reducer;
   1063 } kR2TestCases[] = {
   1064     {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
   1065      /*strides=*/{1, 2}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
   1066      /*layout=*/{0, 1},
   1067      /*reducer=*/Reducer::kAdd},
   1068     {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4},
   1069      /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2},
   1070      /*layout=*/{0, 1},
   1071      /*reducer=*/Reducer::kAdd},
   1072     {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3},
   1073      /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
   1074      /*layout=*/{0, 1},
   1075      /*reducer=*/Reducer::kAdd},
   1076     {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100},
   1077      /*strides=*/{2, 99}, /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35},
   1078      /*layout=*/{0, 1},
   1079      /*reducer=*/Reducer::kAdd},
   1080 // TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a
   1081 // ptxas bug.
   1082 #ifndef XLA_TEST_BACKEND_GPU
   1083     {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25},
   1084      /*strides=*/{5, 4}, /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11},
   1085      /*layout=*/{0, 1},
   1086      /*reducer=*/Reducer::kAdd},
   1087 #endif
   1088     {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2},
   1089      /*strides=*/{3, 3}, /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1},
   1090      /*layout=*/{0, 1},
   1091      /*reducer=*/Reducer::kAdd},
   1092     {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36},
   1093      /*strides=*/{4, 5}, /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17},
   1094      /*layout=*/{1, 0},
   1095      /*reducer=*/Reducer::kAdd},
   1096     {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93},
   1097      /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46},
   1098      /*layout=*/{1, 0},
   1099      /*reducer=*/Reducer::kAdd},
   1100     // Regression test for a bug that appeared in Inception (b/34784899).
   1101     {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3},
   1102      /*strides=*/{1, 1}, /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1},
   1103      /*layout=*/{1, 0},
   1104      /*reducer=*/Reducer::kAdd},
   1105     {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
   1106      /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
   1107      /*layout=*/{1, 0},
   1108      /*reducer=*/Reducer::kAdd},
   1109     // Regression test for a bug that appeared in Inception (b/34784899).
   1110     {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2},
   1111      /*strides=*/{2, 2}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
   1112      /*layout=*/{1, 0},
   1113      /*reducer=*/Reducer::kAdd},
   1114     // Regression test for b/73903312: bf16 lacks precision to store result of
   1115     // very large windows. Testing with a reasonable window larger than 128.
   1116     {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130},
   1117      /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0},
   1118      /*layout=*/{1, 0},
   1119      /*reducer=*/Reducer::kAdd},
   1120     {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4},
   1121      /*strides=*/{1, 64}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
   1122      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
   1123     {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4},
   1124      /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0},
   1125      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
   1126     // Regression test for b/72234705: bf16 lacks precision to store incremental
   1127     // results on very large windows. Using smaller window with minor dim 128.
   1128     {/*base_bounds=*/{8, 128}, /*window_bounds=*/{2, 128},
   1129      /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0},
   1130      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
   1131 };
   1132 
   1133 string R2ReduceWindowTestDataToString(
   1134     const ::testing::TestParamInfo<
   1135         ::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
   1136   const auto& param = ::testing::get<0>(data.param);
   1137   string str = absl::StrCat(
   1138       "base_bounds_", absl::StrJoin(param.base_bounds, "x"),        //
   1139       "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),  //
   1140       "__strides_", absl::StrJoin(param.strides, "x"),              //
   1141       "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_",
   1142       absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_",
   1143       param.layout[1],  //
   1144       "__reducer_", param.reducer == kAdd ? "add" : "max");
   1145   if (::testing::get<1>(data.param)) {
   1146     absl::StrAppend(&str, "_bfloat16");
   1147   }
   1148   return str;
   1149 }
   1150 
   1151 class R2ReduceWindowTest : public ReduceWindowTestBase,
   1152                            public ::testing::WithParamInterface<
   1153                                ::testing::tuple<R2ReduceWindowTestData, bool>> {
   1154  protected:
   1155   R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
   1156 
   1157   void DoIt() {
   1158     XlaBuilder b(TestName());
   1159     const auto& param = ::testing::get<0>(GetParam());
   1160 
   1161     const float kInitValue = 0.0f;
   1162     Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
   1163     Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
   1164         input, LayoutUtil::MakeLayout(param.layout));
   1165 
   1166     XlaOp parameter;
   1167     auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
   1168                                                        &b, &parameter);
   1169     std::vector<std::pair<int64, int64>> padding(2);
   1170     for (int i = 0; i < 2; ++i) {
   1171       padding[i] = {param.pad_low[i], param.pad_high[i]};
   1172     }
   1173     auto computation = param.reducer == kAdd
   1174                            ? CreateScalarAddComputation(FloatType(), &b)
   1175                            : CreateScalarMaxComputation(FloatType(), &b);
   1176     auto init_value =
   1177         CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
   1178     ReduceWindowWithGeneralPadding(
   1179         /*operand=*/parameter,
   1180         /*init_value=*/init_value,
   1181         /*computation=*/computation,
   1182         /*window_dimensions=*/param.window_bounds,
   1183         /*window_strides=*/param.strides,
   1184         /*base_dilations=*/{},
   1185         /*window_dilations=*/{},
   1186         /*padding=*/padding);
   1187 
   1188     auto reduce_func = param.reducer == kAdd
   1189                            ? +[](float a, float b) { return a + b; }
   1190                            : +[](float a, float b) { return std::max(a, b); };
   1191     auto expected = ReferenceUtil::ReduceWindow2DGeneric(
   1192         /*operand=*/input, /*init=*/kInitValue, /*reduce_func=*/reduce_func,
   1193         /*window=*/param.window_bounds,
   1194         /*stride=*/param.strides, /*padding=*/padding);
   1195 
   1196     ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected),
   1197                              {input_arg.get()}, DefaultErrorSpec());
   1198   }
   1199 };
   1200 
   1201 TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
   1202 
   1203 INSTANTIATE_TEST_CASE_P(
   1204     R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
   1205     ::testing::Combine(::testing::ValuesIn(kR2TestCases),
   1206                        ::testing::ValuesIn(use_bfloat16_params)),
   1207     R2ReduceWindowTestDataToString);
   1208 
   1209 struct R1ReduceWindowTestData {
   1210   int64 base_bounds[1];
   1211   int64 window_bounds[1];
   1212   int64 strides[1];
   1213   int64 pad_low[1];
   1214   int64 pad_high[1];
   1215   Reducer reducer;
   1216 } kR1TestCases[] = {
   1217     {/*base_bounds=*/{1}, /*window_bounds=*/{1},
   1218      /*strides=*/{1},
   1219      /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first},
   1220      /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second},
   1221      /*reducer=*/Reducer::kAdd},
   1222 
   1223     {/*base_bounds=*/{3}, /*window_bounds=*/{3},
   1224      /*strides=*/{1},
   1225      /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first},
   1226      /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second},
   1227      /*reducer=*/Reducer::kAdd},
   1228 
   1229     {/*base_bounds=*/{3}, /*window_bounds=*/{2},
   1230      /*strides=*/{1},
   1231      /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first},
   1232      /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second},
   1233      /*reducer=*/Reducer::kAdd},
   1234 
   1235     {/*base_bounds=*/{5}, /*window_bounds=*/{1},
   1236      /*strides=*/{1},
   1237      /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first},
   1238      /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second},
   1239      /*reducer=*/Reducer::kMax},
   1240 
   1241     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
   1242      /*strides=*/{4},
   1243      /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first},
   1244      /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second},
   1245      /*reducer=*/Reducer::kMax},
   1246 
   1247     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
   1248      /*strides=*/{3},
   1249      /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first},
   1250      /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second},
   1251      /*reducer=*/Reducer::kAdd},
   1252 
   1253     {/*base_bounds=*/{128 * 2},
   1254      /*window_bounds=*/{30},
   1255      /*strides=*/{27},
   1256      /*pad_low=*/
   1257      {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first},
   1258      /*pad_high=*/
   1259      {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second},
   1260      /*reducer=*/Reducer::kAdd},
   1261 
   1262     {/*base_bounds=*/{128 * 17},
   1263      /*window_bounds=*/{7},
   1264      /*strides=*/{64},
   1265      /*pad_low=*/
   1266      {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first},
   1267      /*pad_high=*/
   1268      {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second},
   1269      /*reducer=*/Reducer::kAdd},
   1270 
   1271     {/*base_bounds=*/{128 * 2},
   1272      /*window_bounds=*/{32},
   1273      /*strides=*/{56},
   1274      /*pad_low=*/
   1275      {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first},
   1276      /*pad_high=*/
   1277      {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second},
   1278      /*reducer=*/Reducer::kAdd},
   1279 
   1280     {/*base_bounds=*/{3}, /*window_bounds=*/{2},
   1281      /*strides=*/{1},
   1282      /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first},
   1283      /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second},
   1284      /*reducer=*/Reducer::kAdd},
   1285 
   1286     {/*base_bounds=*/{5}, /*window_bounds=*/{3},
   1287      /*strides=*/{2},
   1288      /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first},
   1289      /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second},
   1290      /*reducer=*/Reducer::kAdd},
   1291 
   1292     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
   1293      /*strides=*/{3},
   1294      /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first},
   1295      /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second},
   1296      /*reducer=*/Reducer::kAdd},
   1297 
   1298     {/*base_bounds=*/{5}, /*window_bounds=*/{5},
   1299      /*strides=*/{1},
   1300      /*pad_low=*/{0},
   1301      /*pad_high=*/{5},
   1302      /*reducer=*/Reducer::kAdd},
   1303 
   1304     {/*base_bounds=*/{5}, /*window_bounds=*/{5},
   1305      /*strides=*/{1},
   1306      /*pad_low=*/{5},
   1307      /*pad_high=*/{0},
   1308      /*reducer=*/Reducer::kAdd},
   1309 
   1310     // The pattern generated by inclusive scan (cumsum/cumprod).
   1311     {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
   1312      /*strides=*/{1},
   1313      /*pad_low=*/{4095},
   1314      /*pad_high=*/{0},
   1315      /*reducer=*/Reducer::kMax},
   1316 
   1317     // The pattern generated by exclusive scan (cumsum/cumprod).
   1318     {/*base_bounds=*/{4095}, /*window_bounds=*/{4095},
   1319      /*strides=*/{1},
   1320      /*pad_low=*/{4095},
   1321      /*pad_high=*/{0},
   1322      /*reducer=*/Reducer::kMax},
   1323 };
   1324 
   1325 string R1ReduceWindowTestDataToString(
   1326     const ::testing::TestParamInfo<
   1327         ::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
   1328   const auto& param = ::testing::get<0>(data.param);
   1329   string str =
   1330       absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"),
   1331                    "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),
   1332                    "__strides_", absl::StrJoin(param.strides, "x"),
   1333                    "__pad_low_", absl::StrJoin(param.pad_low, "x"),
   1334                    "__pad_high_", absl::StrJoin(param.pad_high, "x"),
   1335                    "__reducer_", param.reducer == kAdd ? "add" : "max");
   1336   if (::testing::get<1>(data.param)) {
   1337     absl::StrAppend(&str, "_bfloat16");
   1338   }
   1339   return str;
   1340 }
   1341 
   1342 class R1ReduceWindowTest : public ReduceWindowTestBase,
   1343                            public ::testing::WithParamInterface<
   1344                                ::testing::tuple<R1ReduceWindowTestData, bool>> {
   1345  protected:
   1346   R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
   1347 };
   1348 
   1349 TEST_P(R1ReduceWindowTest, DoIt) {
   1350   XlaBuilder b(TestName());
   1351   const auto& param = ::testing::get<0>(GetParam());
   1352   CHECK(param.reducer == kAdd || param.reducer == kMax);
   1353 
   1354   const float kInitValue = 0.0f;
   1355   std::vector<float> input_vector(param.base_bounds[0]);
   1356   std::iota(std::begin(input_vector), std::end(input_vector), 0);
   1357   Literal input_literal =
   1358       LiteralUtil::CreateR1(absl::Span<const float>(input_vector));
   1359   XlaOp parameter;
   1360   auto input_arg =
   1361       CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, &parameter);
   1362 
   1363   std::vector<std::pair<int64, int64>> padding(1);
   1364   padding[0] = {param.pad_low[0], param.pad_high[0]};
   1365 
   1366   auto computation = param.reducer == kAdd
   1367                          ? CreateScalarAddComputation(FloatType(), &b)
   1368                          : CreateScalarMaxComputation(FloatType(), &b);
   1369   auto init_value =
   1370       CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
   1371   ReduceWindowWithGeneralPadding(
   1372       /*operand=*/parameter,
   1373       /*init_value=*/init_value,
   1374       /*computation=*/computation,
   1375       /*window_dimensions=*/param.window_bounds,
   1376       /*window_strides=*/param.strides,
   1377       /*base_dilations=*/{},
   1378       /*window_dilations=*/{},
   1379       /*padding=*/padding);
   1380 
   1381   auto reduce_func = param.reducer == kAdd
   1382                          ? +[](float a, float b) { return a + b; }
   1383                          : +[](float a, float b) { return std::max(a, b); };
   1384   auto expected = ReferenceUtil::ReduceWindow1DGeneric(
   1385       /*operand=*/absl::Span<const float>(input_vector),
   1386       /*init=*/kInitValue,
   1387       /*reduce_func=*/reduce_func,
   1388       /*window=*/param.window_bounds,
   1389       /*stride=*/param.strides,
   1390       /*padding=*/padding);
   1391 
   1392   ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*expected),
   1393                            {input_arg.get()}, DefaultErrorSpec());
   1394 }
   1395 
// Instantiates R1ReduceWindowTest over the cross product of every entry in
// kR1TestCases and every flag in use_bfloat16_params. Each instantiation is
// named by R1ReduceWindowTestDataToString.
INSTANTIATE_TEST_CASE_P(
    R1ReduceWindowTestInstantiation, R1ReduceWindowTest,
    ::testing::Combine(::testing::ValuesIn(kR1TestCases),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R1ReduceWindowTestDataToString);
   1401 
// Fixture for reduce-window tests written as HLO module text. Each test passes
// its HLO string to HloTestBase::RunAndCompare. That call compares the backend
// under test with the results on the interpreter (reference) backend.
class ReduceWindowTextTest : public HloTestBase {};
   1405 
// Multiplicative reduce-window (reducer `mul`, init 1) over a 256x384 f32
// operand in the {1,0} layout. The window is 2x3, with padding 0/1 (low/high)
// on dim 0 and 1/1 on dim 1, so the result keeps the 256x384 shape. Padded
// positions hold the init value 1, so they leave each window product
// unchanged.
XLA_TEST_F(ReduceWindowTextTest, R2General256x384) {
  const string hlo_string = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[256,384]{1,0} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
}
)";
  // Products of floats can differ slightly between backends, so compare within
  // ErrorSpec{0.001}.
  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
}
   1421 }
   1422 
// Same multiplicative 2x3 reduce-window (padding 0_1x1_1) as R2General256x384.
// The difference is that the operand and result use layout {0,1} (dimension 0
// most minor) instead of {1,0}, which exercises the other minor-to-major order.
XLA_TEST_F(ReduceWindowTextTest, R2General256x384Layout01) {
  const string hlo_string = R"(
HloModule R2Window
mul {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
operand = f32[256,384]{0,1} parameter(0)
constant = f32[] constant(1)
ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
}
)";
  // Float products: compare within ErrorSpec{0.001}.
  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
}
   1438 }
   1439 
// Multiplicative reduce-window (init 1) over a 2x5 operand with a 2x1 window.
// Dim 0 gets 2 elements of high padding only (pad=0_2x0_0), so it is 4 long
// after padding. That gives 3 window positions along dim 0, hence the 3x5
// result. Windows that overlap the padding multiply by the init value 1.
XLA_TEST_F(ReduceWindowTextTest, R2General2x5) {
  const string hlo_string = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[2,5]{1,0} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
}
   1455 }
   1456 
// Degenerate rank-2 case: a 1x1 window with no padding over a 1x1 operand. The
// reduce-window therefore produces the product of the init value 1 and the
// single (negated) element. The operand is negated first, presumably so the
// reduce-window consumes a computed value rather than a raw parameter.
// NOTE(review): confirm that intent.
XLA_TEST_F(ReduceWindowTextTest, R2EffectiveScalar) {
  const string hlo_string = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[1,1]{1,0} parameter(0)
  negate = f32[1,1]{1,0} negate(operand)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
}
   1473 }
   1474 
// Rank-3 counterpart of R2EffectiveScalar: a 1x1x1 window with no padding over
// a negated 1x1x1 operand, with a multiply reducer and init value 1.
XLA_TEST_F(ReduceWindowTextTest, R3EffectiveScalar) {
  const string hlo_string = R"(
HloModule R3Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R3Window {
  operand = f32[1,1,1]{2,1,0} parameter(0)
  negate = f32[1,1,1]{2,1,0} negate(operand)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[1,1,1]{2,1,0} reduce-window(negate, constant), window={size=1x1x1 pad=0_0x0_0x0_0}, to_apply=mul
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
}
   1491 }
   1492 
// Reduce-window whose reducer simply returns its second parameter. The window
// is 1x1x1, and pad=0_0x1_0x0_0 adds one leading row on dim 1, so 1x32x64
// becomes 1x33x64. Going by the reducer's name, this looks like the
// reduce-window form of a pad.
//
// Padded positions yield the init value 0 either way. Operand positions
// reproduce the operand element only if the element is passed as the reducer's
// second argument. NOTE(review): confirm that argument order.
XLA_TEST_F(HloTestBase, ReduceWindowIdentity) {
  const string hlo_string = R"(
HloModule ReduceWindowIdentity
identity.pad_to_reduce_window {
  param0 = f32[] parameter(0)
  ROOT param1 = f32[] parameter(1)
}
ENTRY reduce-window-identity {
  operand = f32[1,32,64]{2,1,0} parameter(0)
  constant.4466 = f32[] constant(0)
  ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window
}

)";
  // No ErrorSpec (absl::nullopt): presumably an exact comparison, since no
  // arithmetic is performed.
  EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
}
   1508 }
   1509 
// s32 version of the identity-reducer pattern. An 81x8 operand gets one
// trailing row along dim 0 (pad=0_1x0_0) through a 1x1 window, giving 82x8.
// Unlike ReduceWindowIdentity, the init value is a runtime parameter
// (parameter.1), not a constant.
XLA_TEST_F(HloTestBase, ReduceWindowS32) {
  const string hlo_string = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] {
  %param0 = s32[] parameter(0)
  ROOT %param1 = s32[] parameter(1)
}

ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] {
  %parameter.0 = s32[81,8]{1,0} parameter(0)
  %parameter.1 = s32[] parameter(1)
  ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  // Integer data: compared with no ErrorSpec (absl::nullopt).
  EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
}
   1527 }
   1528 
// s64 version of the identity-reducer pattern. The 81x8 operand is extended by
// one trailing row along dim 0 (pad=0_1x0_0) to 82x8, with the init value
// supplied at runtime as parameter.1.
XLA_TEST_F(HloTestBase, ReduceWindowS64) {
  const string hlo_string = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: s64[], param1: s64[]) -> s64[] {
  %param0 = s64[] parameter(0)
  ROOT %param1 = s64[] parameter(1)
}

ENTRY %reduce-window (parameter.0: s64[81,8], parameter.1: s64[]) -> s64[82,8] {
  %parameter.0 = s64[81,8]{1,0} parameter(0)
  %parameter.1 = s64[] parameter(1)
  ROOT %reduce-window = s64[82,8]{1,0} reduce-window(s64[81,8]{1,0} %parameter.0, s64[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  // Integer data: compared with no ErrorSpec (absl::nullopt).
  EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
}
   1546 }
   1547 
// f16 version of the identity-reducer pattern. The 81x8 operand is extended by
// one trailing row along dim 0 (pad=0_1x0_0) to 82x8, with the init value
// supplied at runtime as parameter.1.
XLA_TEST_F(HloTestBase, ReduceWindowF16) {
  const string hlo_string = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] {
  %param0 = f16[] parameter(0)
  ROOT %param1 = f16[] parameter(1)
}

ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] {
  %parameter.0 = f16[81,8]{1,0} parameter(0)
  %parameter.1 = f16[] parameter(1)
  ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  // The reducer only forwards values (no arithmetic), so no ErrorSpec is
  // passed (absl::nullopt).
  EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
}
   1565 }
   1566 
   1567 }  // namespace
   1568 }  // namespace xla
   1569