Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     16 // Tests the reduce-window XLA operation.
     18 #include <limits>
     19 #include <memory>
     21 #include "absl/memory/memory.h"
     22 #include "absl/strings/str_cat.h"
     23 #include "absl/strings/str_join.h"
     24 #include "absl/types/span.h"
     25 #include "tensorflow/compiler/xla/array2d.h"
     26 #include "tensorflow/compiler/xla/array3d.h"
     27 #include "tensorflow/compiler/xla/array4d.h"
     28 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
     29 #include "tensorflow/compiler/xla/client/local_client.h"
     30 #include "tensorflow/compiler/xla/client/padding.h"
     31 #include "tensorflow/compiler/xla/client/xla_builder.h"
     32 #include "tensorflow/compiler/xla/client/xla_computation.h"
     33 #include "tensorflow/compiler/xla/reference_util.h"
     34 #include "tensorflow/compiler/xla/shape_util.h"
     35 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     36 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     37 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     38 #include "tensorflow/compiler/xla/tests/test_macros.h"
     39 #include "tensorflow/compiler/xla/xla_data.pb.h"
     40 #include "tensorflow/core/lib/core/status.h"
     41 #include "tensorflow/core/lib/core/status_test_util.h"
     42 #include "tensorflow/core/platform/test.h"
     43 #include "tensorflow/core/platform/types.h"
     45 namespace xla {
     46 namespace {
     49 // Tests both F32 and BF16.
     50 static std::array<bool, 2> use_bfloat16_params{false, true};
     51 #else
     52 // Only tests F32.
     53 static std::array<bool, 1> use_bfloat16_params{false};
     54 #endif
     56 class ReduceWindowTestBase : public ClientLibraryTestBase {
     57  public:
     58   ErrorSpec DefaultErrorSpec() const {
     59     if (use_bfloat16()) {
     60       return ErrorSpec(2e-1, 6e-2);
     61     } else {
     62       return ErrorSpec(1e-3, 1e-3);
     63     }
     64   }
     65 };
     67 class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
     68                          public ReduceWindowTestBase {
     69  public:
     70   ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
     72   void ReduceWindowAdd(const XlaOp& input,
     73                        absl::Span<const int64> window_dimensions,
     74                        absl::Span<const int64> window_strides,
     75                        Padding padding) {
     76     auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
     77                                           &builder_);
     78     ReduceWindow(input, init,
     79                  CreateScalarAddComputation(FloatType(), &builder_),
     80                  window_dimensions, window_strides, padding);
     81   }
     83   void ReduceWindowMax(const XlaOp& input,
     84                        absl::Span<const int64> window_dimensions,
     85                        absl::Span<const int64> window_strides,
     86                        Padding padding) {
     87     auto init =
     88         CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_);
     89     ReduceWindow(input, init,
     90                  CreateScalarMaxComputation(FloatType(), &builder_),
     91                  window_dimensions, window_strides, padding);
     92   }
     94   void ReduceWindowMin(const XlaOp& input,
     95                        absl::Span<const int64> window_dimensions,
     96                        absl::Span<const int64> window_strides,
     97                        Padding padding) {
     98     auto init =
     99         CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_);
    100     ReduceWindow(input, init,
    101                  CreateScalarMinComputation(FloatType(), &builder_),
    102                  window_dimensions, window_strides, padding);
    103   }
    105   XlaBuilder builder_;
    106 };
    108 TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
    109   const auto input = CreateConstantFromLiteral(
    110       LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
    111   const auto init_value =
    112       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
    113   TF_ASSERT_OK(builder_.first_error());
    114   ReduceWindow(input, init_value,
    115                CreateScalarAddComputation(FloatType(), &builder_),
    116                /*window_dimensions=*/{1, 2},
    117                /*window_strides=*/{1}, Padding::kValid);
    118   ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT)
    119       << builder_.first_error();
    120   ASSERT_THAT(builder_.first_error().error_message(),
    121               ::testing::HasSubstr("Want input dimensions size"));
    122 }
    124 // Regression test for b/68964348.
    125 TEST_P(ReduceWindowTest, R0ReduceWindow) {
    126   const auto input =
    127       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
    128   const auto init =
    129       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(1.0), &builder_);
    130   ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
    131                /*window_dimensions=*/{},
    132                /*window_strides=*/{}, Padding::kSame);
    133   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(43.0), {},
    134                            ErrorSpec(0.00001));
    135 }
    137 TEST_P(ReduceWindowTest, Min3In5Stride2) {
    138   const auto input = CreateConstantFromLiteral(
    139       LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
    140   ReduceWindowMin(input, {3}, {2}, Padding::kValid);
    141   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({100, 1}),
    142                            {}, ErrorSpec(0.00001));
    143 }
    145 TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
    146   const auto input = CreateConstantFromLiteral(
    147       LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
    148   ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
    149                   Padding::kSame);
    150   ComputeAndCompareLiteral(&builder_,
    151                            LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
    152                            {}, ErrorSpec(0.00001));
    153 }
    155 XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
    156   Array4D<float> input_array(1, 0, 2, 1);
    157   const auto input = CreateConstantFromArray(input_array, &builder_);
    158   Padding padding = Padding::kSame;
    159   ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
    161   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
    162                                               {1, 1, 1, 1}, padding);
    164   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
    165                            DefaultErrorSpec());
    166 }
    168 TEST_P(ReduceWindowTest, NonSquareSmall) {
    169   Array4D<float> input_array(1, 2, 2, 1);
    170   input_array.FillRandom(2.f, 2.f);
    171   const auto input = CreateConstantFromArray(input_array, &builder_);
    173   Padding padding = Padding::kSame;
    174   ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
    176   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
    177                                               {1, 1, 1, 1}, padding);
    179   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
    180                            DefaultErrorSpec());
    181 }
    183 TEST_P(ReduceWindowTest, MiddleDimsSmall) {
    184   Array4D<float> input_array(1, 3, 3, 1);
    185   input_array.FillRandom(2.f, 2.f);
    186   const auto input = CreateConstantFromArray(input_array, &builder_);
    187   Padding padding = Padding::kSame;
    188   ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding);
    190   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
    191                                               {1, 2, 2, 1}, padding);
    193   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
    194                            DefaultErrorSpec());
    195 }
    197 TEST_P(ReduceWindowTest, Along2ndMinorDim) {
    198   Array4D<float> input_array(3, 6, 7, 32);
    199   input_array.FillRandom(2.f, 2.f);
    200   const auto input = CreateConstantFromArray(input_array, &builder_);
    202   // The parameters of this reduction mimic feature norm (e.g. LRN).
    203   int lrn_diameter = 7;  // diameter = 2*radius + 1 --> must be odd
    204   Padding padding = Padding::kSame;
    205   ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
    207   auto res = ReferenceUtil::ReduceWindow4DAdd(
    208       input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
    210   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
    211                            DefaultErrorSpec());
    212 }
    214 TEST_P(ReduceWindowTest, AmongMajor2Dims) {
    215   Array4D<float> input_array(4, 4, 6, 8);
    216   input_array.FillWithMinorDimNum();
    217   const auto input_data_handle =
    218       CreateConstantFromArray(input_array, &builder_);
    220   int win_len = 3;
    221   int win_stride = 1;
    223   Padding padding = Padding::kSame;
    224   // Reduce only along the x and y dimensions, according to the win_len.
    225   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    226                   {win_stride, win_stride, 1, 1}, padding);
    228   auto result = ReferenceUtil::ReduceWindow4DAdd(
    229       input_array, 0.0f, {win_len, win_len, 1, 1},
    230       {win_stride, win_stride, 1, 1}, padding);
    232   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
    233                            DefaultErrorSpec());
    234 }
    236 TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
    237   Array4D<float> input_array(9, 12, 4, 89);
    238   input_array.FillRandom(2.f, 2.f);
    240   int win_len = 3;
    241   int win_stride = 2;
    243   const auto input_data_handle =
    244       CreateConstantFromArray(input_array, &builder_);
    246   Padding padding = Padding::kSame;
    247   // Reduce only along the x and y dimensions, according to the win_len.
    248   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    249                   {win_stride, win_stride, 1, 1}, padding);
    251   auto result = ReferenceUtil::ReduceWindow4DAdd(
    252       input_array, 0.0f, {win_len, win_len, 1, 1},
    253       {win_stride, win_stride, 1, 1}, padding);
    255   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
    256                            DefaultErrorSpec());
    257 }
    259 // Tests the super windowing logic w.r.t handling prime number of windows in a
    260 // major dimension with reduction.
    261 TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
    262   Array4D<float> input_array(15, 15, 4, 128);
    263   input_array.FillRandom(2.f, 4.f);
    265   int win_len = 3;
    266   int win_stride = 2;
    268   const auto input_data_handle =
    269       CreateConstantFromArray(input_array, &builder_);
    271   Padding padding = Padding::kSame;
    272   // Reduce only along the x and y dimensions, according to the win_len.
    273   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    274                   {win_stride, win_stride, 1, 1}, padding);
    276   auto result = ReferenceUtil::ReduceWindow4DAdd(
    277       input_array, 0.0f, {win_len, win_len, 1, 1},
    278       {win_stride, win_stride, 1, 1}, padding);
    280   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
    281                            DefaultErrorSpec());
    282 }
    284 TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
    285   Array4D<float> input_array(19, 17, 8, 256);
    286   input_array.FillWithMinorDimNum();
    288   const auto input_data_handle =
    289       CreateConstantFromArray(input_array, &builder_);
    291   Padding padding = Padding::kSame;
    292   ReduceWindowAdd(input_data_handle, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
    294   auto result = ReferenceUtil::ReduceWindow4DAdd(
    295       input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
    297   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
    298                            DefaultErrorSpec());
    299 }
    301 // Tests a reduction function that is not a simple add/min/max/etc.
    302 XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
    303   Array4D<float> input_array(1, 2, 2, 1);
    304   input_array(0, 0, 0, 0) = 1;
    305   input_array(0, 0, 1, 0) = 2;
    306   input_array(0, 1, 0, 0) = 3;
    307   input_array(0, 1, 1, 0) = 4;
    308   const auto input = CreateConstantFromArray(input_array, &builder_);
    310   Padding padding = Padding::kValid;
    311   const Shape scalar = ShapeUtil::MakeShape(FloatType(), {});
    312   auto b = builder_.CreateSubBuilder("unusual");
    313   auto lhs = Parameter(b.get(), 0, scalar, "lhs");
    314   auto rhs = Parameter(b.get(), 1, scalar, "rhs");
    315   Min(Add(lhs, rhs),
    316       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
    317   XlaComputation reduce_fn = b->BuildAndNoteError();
    319   ReduceWindow(
    320       input,
    321       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
    322       reduce_fn,
    323       /*window_dimensions=*/{1, 1, 2, 1},
    324       /*window_strides=*/{1, 1, 1, 1}, padding);
    326   const auto reduce_func = [](float arg1, float arg2) {
    327     return std::min<float>(arg1 + arg2, 8.0f);
    328   };
    330   auto expected =
    331       ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func,
    332                                            /*window=*/{1, 1, 2, 1},
    333                                            /*stride=*/{1, 1, 1, 1}, padding);
    335   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
    336                            {}, DefaultErrorSpec());
    337 }
    339 TEST_P(ReduceWindowTest, R4UnitWindow) {
    340   Array4D<float> input_array(13, 12, 8, 15);
    341   input_array.FillRandom(2.f, 2.f);
    342   Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    343       input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
    344   XlaOp input;
    345   auto input_data = CreateParameterAndTransferLiteral(
    346       0, input_literal, "parameter", &builder_, &input);
    348   Padding padding = Padding::kSame;
    349   ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding);
    351   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
    352                                               {1, 4, 1, 1}, padding);
    354   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
    355                            {input_data.get()}, DefaultErrorSpec());
    356 }
    358 XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
    359   std::vector<int64> input_dims(6, 8);
    360   auto shape = ShapeUtil::MakeShape(F32, input_dims);
    362   Literal arg_literal(shape);
    363   arg_literal.PopulateWithValue(1.0f);
    364   const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
    366   Padding padding = Padding::kValid;
    367   ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
    369   std::vector<int64> output_layout = {1, 5, 3, 2, 0, 4};
    370   std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
    371   Shape result_shape =
    372       ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
    373   Literal expected(result_shape);
    374   expected.PopulateWithValue(27.0f);
    375   ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
    376 }
    378 XLA_TEST_P(ReduceWindowTest, R6Add) {
    379   std::vector<int64> input_dims(6, 8);
    380   auto shape = ShapeUtil::MakeShape(F32, input_dims);
    382   Literal arg_literal =
    383       LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
    385   const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
    387   Padding padding = Padding::kValid;
    388   ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
    390   std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
    391   Literal expected =
    392       LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
    394   ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
    395 }
    397 XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
    398   Array4D<float> input_array(2, 1, 27, 119);
    399   input_array.FillRandom(2.0f);
    400   Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    401       input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
    402   XlaOp input;
    403   auto input_data = CreateParameterAndTransferLiteral(
    404       0, input_literal, "parameter", &builder_, &input);
    406   int win_len = 1;
    407   int stride = 8;
    408   Padding padding = Padding::kSame;
    409   ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    411   auto res = ReferenceUtil::ReduceWindow4DAdd(
    412       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    414   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
    415                            {input_data.get()}, DefaultErrorSpec());
    416 }
    418 XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
    419   Array4D<float> input_array(3, 2, 4, 64);
    420   input_array.FillRandom(2.0f);
    421   Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    422       input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
    423   XlaOp input;
    424   auto input_data = CreateParameterAndTransferLiteral(
    425       0, input_literal, "parameter", &builder_, &input);
    427   int win_len = 3;
    428   int stride = 1;
    429   Padding padding = Padding::kSame;
    430   ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    432   auto res = ReferenceUtil::ReduceWindow4DAdd(
    433       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    435   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
    436                            {input_data.get()}, DefaultErrorSpec());
    437 }
    439 XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
    440   Array4D<float> input_array(1, 3, 12, 200);
    441   input_array.FillRandom(2.0f);
    442   Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    443       input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
    444   XlaOp input;
    445   auto input_data = CreateParameterAndTransferLiteral(
    446       0, input_literal, "parameter", &builder_, &input);
    448   int win_len = 8;
    449   int stride = 5;
    450   Padding padding = Padding::kSame;
    451   ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    453   auto res = ReferenceUtil::ReduceWindow4DAdd(
    454       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
    456   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
    457                            {input_data.get()}, DefaultErrorSpec());
    458 }
    460 TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
    461   Array4D<float> input_array(6, 4, 10, 130);
    462   input_array.FillRandom(2.0f);
    464   int win_len = 3;
    465   int win_stride = 2;
    467   Padding padding = Padding::kSame;
    468   const auto input_data_handle =
    469       CreateConstantFromArray(input_array, &builder_);
    470   // Reduce only along the x and y dimensions, according to the win_len.
    471   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
    472                   {win_stride, win_stride, 1, 1}, padding);
    474   auto result = ReferenceUtil::ReduceWindow4DAdd(
    475       input_array, 0.0f, {win_len, win_len, 1, 1},
    476       {win_stride, win_stride, 1, 1}, padding);
    477   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
    478                            DefaultErrorSpec());
    479 }
    481 XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
    482   std::vector<float> input_vector(128 * 9, 1);
    483   const auto input = CreateConstantFromLiteral(
    484       LiteralUtil::CreateR1<float>(input_vector), &builder_);
    485   ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
    486   ComputeAndCompareLiteral(
    487       &builder_,
    488       LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
    489       DefaultErrorSpec());
    490 }
    492 XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
    493   std::vector<float> input_vector{
    494       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    495       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    496       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    497       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    498       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    499       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    500       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    501       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    502   const auto input = CreateConstantFromLiteral(
    503       LiteralUtil::CreateR1<float>(input_vector), &builder_);
    504   ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
    505   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
    506                            DefaultErrorSpec());
    507 }
    509 XLA_TEST_P(ReduceWindowTest, Add128In128) {
    510   std::vector<float> input_vector{
    511       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    512       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    513       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    514       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    515       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    516       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    517       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    518       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    519   const auto input = CreateConstantFromLiteral(
    520       LiteralUtil::CreateR1<float>(input_vector), &builder_);
    521   ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
    522   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
    523                            DefaultErrorSpec());
    524 }
    526 // Regression test for a bug that appeared in Inception (b/34784899).
    527 TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
    528   Array2D<float> input_array(14, 14, 1.0f);
    529   const auto input = CreateConstantFromArray(input_array, &builder_);
    531   int win_len = 3;
    532   int stride = 1;
    533   Padding padding = Padding::kSame;
    534   ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding);
    536   auto res = ReferenceUtil::ReduceWindow2DAdd(
    537       input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
    539   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
    540                            {}, DefaultErrorSpec());
    541 }
    543 TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
    544   Array2D<float> input_array(6, 4, 1.0f);
    545   XlaOp input = Broadcast(
    546       CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4});
    548   Padding padding = Padding::kSame;
    549   ReduceWindowAdd(input, {4, 2}, {3, 3}, padding);
    551   auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
    552                                               padding);
    554   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
    555                            {}, DefaultErrorSpec());
    556 }
    558 INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
    559                         ::testing::ValuesIn(use_bfloat16_params));
    561 enum Reducer { kAdd, kMax };
    563 struct R4ReduceWindowTestData {
    564   int64 base_bounds[4];
    565   int64 window_bounds[4];
    566   int64 strides[4];
    567   int64 pad_low[4];
    568   int64 pad_high[4];
    569   int64 layout[4];
    571   Reducer reducer;
    572 };
    574 string R4ReduceWindowTestDataToString(
    575     const ::testing::TestParamInfo<
    576         ::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
    577   const auto& param = ::testing::get<0>(data.param);
    578   string str = absl::StrCat(
    579       "base_bounds_", absl::StrJoin(param.base_bounds, "x"),        //
    580       "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),  //
    581       "__strides_", absl::StrJoin(param.strides, "x"),              //
    582       "__pad_low_", absl::StrJoin(param.pad_low, "x"),              //
    583       "__pad_high_", absl::StrJoin(param.pad_high, "x"),            //
    584       "__layout_", absl::StrJoin(param.layout, "_"),                //
    585       (param.reducer == kAdd) ? "_add" : "_max");
    586   CHECK(param.reducer == kAdd || param.reducer == kMax);
    588   // Test names are not allowed to contain the '-' character.
    589   std::replace(str.begin(), str.end(), '-', 'n');
    590   if (::testing::get<1>(data.param)) {
    591     absl::StrAppend(&str, "_bfloat16");
    592   }
    593   return str;
    594 }
    596 class R4ReduceWindowTest : public ReduceWindowTestBase,
    597                            public ::testing::WithParamInterface<
    598                                ::testing::tuple<R4ReduceWindowTestData, bool>> {
    599  protected:
    600   R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
    602   void DoIt() {
    603     XlaBuilder b(TestName());
    604     const auto& param = ::testing::get<0>(GetParam());
    606     const float kInitValue = 0.0f;
    608     Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
    609                          param.base_bounds[2], param.base_bounds[3]);
    610     // Choose a prime iota length so that each window sees a unique set of
    611     // values. (Technically, the requirement is that the iota length is
    612     // relatively prime to all of the dimensions involved in the reduce-window.)
    613     input.FillRepeatedIota(0, 137);
    614     // Floating point sum reduction requires higher localized precision. We need
    615     // the following normalization in order to enable testing of kAdd on large
    616     // windows.
    617     input.Each([&](absl::Span<const int64> /*indices*/, float* value) {
    618       *value = *value / 10000000000.f;
    619     });
    620     Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
    621         input, LayoutUtil::MakeLayout(param.layout));
    622     XlaOp parameter;
    623     auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
    624                                                        &b, &parameter);
    626     std::vector<std::pair<int64, int64>> padding(4);
    627     for (int i = 0; i < 4; ++i) {
    628       padding[i] = {param.pad_low[i], param.pad_high[i]};
    629     }
    631     auto init_value =
    632         CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
    633     CHECK(param.reducer == kAdd || param.reducer == kMax);
    634     auto reducer = param.reducer;
    635     auto computation = reducer == kAdd
    636                            ? CreateScalarAddComputation(FloatType(), &b)
    637                            : CreateScalarMaxComputation(FloatType(), &b);
    638     ReduceWindowWithGeneralPadding(
    639         /*operand=*/parameter,
    640         /*init_value=*/init_value,
    641         /*computation=*/computation,
    642         /*window_dimensions=*/param.window_bounds,
    643         /*window_strides=*/param.strides,
    644         /*base_dilations=*/{},
    645         /*window_dilations=*/{},
    646         /*padding=*/padding);
    648     CHECK(reducer == kAdd || reducer == kMax);
    649     auto reduce_func = reducer == kAdd
    650                            ? +[](float a, float b) { return a + b; }
    651                            : +[](float a, float b) { return std::max(a, b); };
    652     std::unique_ptr<Array4D<float>> expected =
    653         ReferenceUtil::ReduceWindow4DGeneric(
    654             /*operand=*/input,
    655             /*init=*/kInitValue,
    656             /*reduce_func=*/reduce_func,
    657             /*window=*/param.window_bounds,
    658             /*stride=*/param.strides,
    659             /*padding=*/padding);
    660     Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
    661     const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
    662         input_literal.shape().element_type(),
    663         AsInt64Slice(expected_literal.shape().dimensions()), param.layout);
    664     ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
    665                              DefaultErrorSpec(), &expected_shape_with_layout);
    666   }
    667 };
    669 TEST_P(R4ReduceWindowTest, DoIt) { DoIt(); }
// Field order: base_bounds, window_bounds, strides, pad_low, pad_high, layout,
// reducer.
    672 const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
    673     // Minimal edge case.
    674     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1},
    675                            /*window_bounds=*/{1, 1, 1, 1},
    676                            /*strides=*/{1, 1, 1, 1},
    677                            /*pad_low=*/{0, 0, 0, 0},
    678                            /*pad_high=*/{0, 0, 0, 0},
    679                            /*layout=*/{3, 2, 1, 0},
    680                            /*reducer=*/kAdd},
    682     // Arbitrary padding (not kSame or kValid).
    683     R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89},
    684                            /*window_bounds=*/{3, 3, 1, 1},
    685                            /*strides=*/{2, 2, 1, 1},
    686                            /*pad_low=*/{4, 4, 0, 0},
    687                            /*pad_high=*/{4, 4, 0, 0},
    688                            /*layout=*/{3, 2, 1, 0},
    689                            /*reducer=*/kAdd},
    691     // Zero base bound edge case.
    692     R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1},
    693                            /*window_bounds=*/{1, 1, 1, 1},
    694                            /*strides=*/{1, 1, 1, 1},
    695                            /*pad_low=*/{0, 0, 0, 0},
    696                            /*pad_high=*/{0, 0, 0, 0},
    697                            /*layout=*/{3, 2, 1, 0},
    698                            /*reducer=*/kAdd},
    700     // With max instead of add.
    701     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    702                            /*window_bounds=*/{2, 3, 1, 1},
    703                            /*strides=*/{1, 1, 1, 1},
    704                            /*pad_low=*/{0, 0, 0, 0},
    705                            /*pad_high=*/{0, 0, 0, 0},
    706                            /*layout=*/{3, 2, 1, 0},
    707                            /*reducer=*/kMax},
    709     // With stride.
    710     R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140},
    711                            /*window_bounds=*/{3, 2, 1, 1},
    712                            /*strides=*/{2, 4, 1, 1},
    713                            /*pad_low=*/{0, 0, 0, 0},
    714                            /*pad_high=*/{0, 0, 0, 0},
    715                            /*layout=*/{3, 2, 1, 0},
    716                            /*reducer=*/kAdd},
    718     // With low padding.
    719     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    720                            /*window_bounds=*/{3, 2, 1, 1},
    721                            /*strides=*/{2, 2, 1, 1},
    722                            /*pad_low=*/{3, 2, 0, 0},
    723                            /*pad_high=*/{0, 0, 0, 0},
    724                            /*layout=*/{3, 2, 1, 0},
    725                            /*reducer=*/kAdd},
    727     // With high padding.
    728     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    729                            /*window_bounds=*/{3, 2, 1, 1},
    730                            /*strides=*/{2, 2, 1, 1},
    731                            /*pad_low=*/{0, 0, 0, 0},
    732                            /*pad_high=*/{2, 3, 0, 0},
    733                            /*layout=*/{3, 2, 1, 0},
    734                            /*reducer=*/kAdd},
    736     // Window touches both sides of the padding simultaneously.
    737     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
    738                            /*window_bounds=*/{3, 3, 1, 1},
    739                            /*strides=*/{1, 1, 1, 1},
    740                            /*pad_low=*/{1, 1, 0, 0},
    741                            /*pad_high=*/{1, 1, 0, 0},
    742                            /*layout=*/{3, 2, 1, 0},
    743                            /*reducer=*/kAdd},
    745     // Window is entirely in the padding for some positions.
    746     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
    747                            /*window_bounds=*/{3, 3, 1, 1},
    748                            /*strides=*/{1, 1, 1, 1},
    749                            /*pad_low=*/{4, 4, 0, 0},
    750                            /*pad_high=*/{4, 4, 0, 0},
    751                            /*layout=*/{3, 2, 1, 0},
    752                            /*reducer=*/kAdd},
    754     // Zero base bound with padding edge case.
    755     R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4},
    756                            /*window_bounds=*/{1, 1, 1, 1},
    757                            /*strides=*/{1, 1, 1, 1},
    758                            /*pad_low=*/{0, 1, 0, 0},
    759                            /*pad_high=*/{0, 0, 0, 0},
    760                            /*layout=*/{3, 2, 1, 0},
    761                            /*reducer=*/kAdd},
    763     // With stride, low padding and high padding.
    764     R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140},
    765                            /*window_bounds=*/{3, 4, 1, 1},
    766                            /*strides=*/{3, 1, 1, 1},
    767                            /*pad_low=*/{10, 1, 0, 0},
    768                            /*pad_high=*/{2, 3, 0, 0},
    769                            /*layout=*/{3, 2, 1, 0},
    770                            /*reducer=*/kAdd},
    772     // With minor dimension == 129.
    773     R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
    774                            /*window_bounds=*/{1, 1, 1, 1},
    775                            /*strides=*/{1, 1, 1, 1},
    776                            /*pad_low=*/{0, 0, 0, 0},
    777                            /*pad_high=*/{0, 0, 0, 0},
    778                            /*layout=*/{3, 2, 1, 0},
    779                            /*reducer=*/kAdd},
    781     // With minor dims reduction and non-overlapped stride.
    782     R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
    783                            /*window_bounds=*/{1, 1, 2, 2},
    784                            /*strides=*/{1, 1, 2, 2},
    785                            /*pad_low=*/{0, 0, 0, 0},
    786                            /*pad_high=*/{0, 0, 0, 0},
    787                            /*layout=*/{3, 2, 1, 0},
    788                            /*reducer=*/kAdd},
    790     // With minor dims reduction and overlapped stride.
    791     R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
    792                            /*window_bounds=*/{1, 1, 4, 4},
    793                            /*strides=*/{1, 1, 2, 2},
    794                            /*pad_low=*/{0, 0, 0, 0},
    795                            /*pad_high=*/{1, 0, 0, 0},
    796                            /*layout=*/{3, 2, 1, 0},
    797                            /*reducer=*/kAdd},
    799     R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3},
    800                            /*window_bounds=*/{1, 64, 64, 1},
    801                            /*strides=*/{1, 64, 64, 1},
    802                            /*pad_low=*/{0, 0, 0, 0},
    803                            /*pad_high=*/{0, 0, 0, 0},
    804                            /*layout=*/{3, 0, 2, 1},
    805                            /*reducer=*/kAdd},
    807     R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64},
    808                            /*window_bounds=*/{112, 112, 1, 8},
    809                            /*strides=*/{112, 112, 1, 8},
    810                            /*pad_low=*/{0, 0, 0, 0},
    811                            /*pad_high=*/{0, 0, 0, 0},
    812                            /*layout=*/{3, 2, 1, 0},
    813                            /*reducer=*/kMax},
    815     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    816                            /*window_bounds=*/{2, 3, 4, 5},
    817                            /*strides=*/{1, 1, 1, 1},
    818                            /*pad_low=*/{0, 0, 0, 0},
    819                            /*pad_high=*/{0, 0, 0, 0},
    820                            /*layout=*/{3, 2, 1, 0},
    821                            /*reducer=*/kAdd},
    823     // With 0321 layout.
    824     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
    825                            /*window_bounds=*/{2, 3, 4, 5},
    826                            /*strides=*/{1, 2, 3, 4},
    827                            /*pad_low=*/{0, 0, 0, 0},
    828                            /*pad_high=*/{0, 0, 0, 0},
    829                            /*layout=*/{0, 3, 2, 1},
    830                            /*reducer=*/kAdd},
    832     // With 0123 layout.
    833     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17},
    834                            /*window_bounds=*/{2, 3, 7, 9},
    835                            /*strides=*/{1, 2, 5, 8},
    836                            /*pad_low=*/{0, 0, 0, 0},
    837                            /*pad_high=*/{0, 0, 0, 0},
    838                            /*layout=*/{0, 1, 2, 3},
    839                            /*reducer=*/kAdd},
    840 };
    843     R4ReduceWindowTestInstantiation, R4ReduceWindowTest,
    844     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues),
    845                        ::testing::ValuesIn(use_bfloat16_params)),
    846     R4ReduceWindowTestDataToString);
    848 class R4ReduceWindowLargeTest : public R4ReduceWindowTest {};
    850 XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); }
    852 // Test cases that are large/slow/failed.
    853 const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
    854     R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128},
    855                            /*window_bounds=*/{3, 3, 1, 5},
    856                            /*strides=*/{1, 1, 1, 5},
    857                            /*pad_low=*/{1, 1, 0, 0},
    858                            /*pad_high=*/{1, 1, 0, 0},
    859                            /*layout=*/{3, 2, 1, 0},
    860                            /*reducer=*/kMax},
    862     R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128},
    863                            /*window_bounds=*/{3, 3, 1, 1},
    864                            /*strides=*/{2, 2, 1, 1},
    865                            /*pad_low=*/{0, 0, 0, 0},
    866                            /*pad_high=*/{1, 1, 0, 0},
    867                            /*layout=*/{3, 2, 1, 0},
    868                            /*reducer=*/kAdd},
    870     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2},
    871                            /*window_bounds=*/{1, 1, 4, 1},
    872                            /*strides=*/{1, 1, 4, 1},
    873                            /*pad_low=*/{0, 0, 1, 0},
    874                            /*pad_high=*/{0, 0, 2, 0},
    875                            /*layout=*/{3, 2, 1, 0},
    876                            /*reducer=*/kMax},
    878     // Patterns generated by cumsum/cumprod.
    879     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
    880                            /*window_bounds=*/{1021, 1, 1, 1},
    881                            /*strides=*/{1, 1, 1, 1},
    882                            /*pad_low=*/{1020, 0, 0, 0},
    883                            /*pad_high=*/{0, 0, 0, 0},
    884                            /*layout=*/{3, 2, 1, 0},
    885                            /*reducer=*/kAdd},
    887     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
    888                            /*window_bounds=*/{1, 1, 1021, 1},
    889                            /*strides=*/{1, 1, 1, 1},
    890                            /*pad_low=*/{0, 0, 1020, 0},
    891                            /*pad_high=*/{0, 0, 0, 0},
    892                            /*layout=*/{3, 2, 1, 0},
    893                            /*reducer=*/kAdd},
    895     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
    896                            /*window_bounds=*/{1, 1, 1, 1021},
    897                            /*strides=*/{1, 1, 1, 1},
    898                            /*pad_low=*/{0, 0, 0, 1020},
    899                            /*pad_high=*/{0, 0, 0, 0},
    900                            /*layout=*/{3, 2, 1, 0},
    901                            /*reducer=*/kAdd},
    903     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
    904                            /*window_bounds=*/{1021, 1, 1, 1},
    905                            /*strides=*/{1, 1, 1, 1},
    906                            /*pad_low=*/{1021, 0, 0, 0},
    907                            /*pad_high=*/{0, 0, 0, 0},
    908                            /*layout=*/{3, 2, 1, 0},
    909                            /*reducer=*/kAdd},
    911     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 1021, 16},
    912                            /*window_bounds=*/{1, 1, 1021, 1},
    913                            /*strides=*/{1, 1, 1, 1},
    914                            /*pad_low=*/{0, 0, 1021, 0},
    915                            /*pad_high=*/{0, 0, 0, 0},
    916                            /*layout=*/{3, 2, 1, 0},
    917                            /*reducer=*/kAdd},
    919     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
    920                            /*window_bounds=*/{1, 1, 1, 1021},
    921                            /*strides=*/{1, 1, 1, 1},
    922                            /*pad_low=*/{0, 0, 0, 1021},
    923                            /*pad_high=*/{0, 0, 0, 0},
    924                            /*layout=*/{3, 2, 1, 0},
    925                            /*reducer=*/kAdd},
    926 };
    929     R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest,
    930     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues),
    931                        ::testing::ValuesIn(use_bfloat16_params)),
    932     R4ReduceWindowTestDataToString);
    934 struct R3ReduceWindowTestData {
    935   int64 base_bounds[3];
    936   int64 window_bounds[3];
    937   int64 strides[3];
    938   int64 layout[3];
    939   Padding padding;
    940   Reducer reducer;
    941 } kR3TestCases[] = {
    942     {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2},
    943      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    944      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    945     {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
    946      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
    947      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    948     {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
    949      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
    950      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    951     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
    952      /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0},
    953      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    954     {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1},
    955      /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0},
    956      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    957     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
    958      /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2},
    959      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    960     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
    961      /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
    962      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    963     {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251},
    964      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    965      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
    966     {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3},
    967      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    968      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    969     {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64},
    970      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    971      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
    972     {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257},
    973      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    974      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
    975     {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
    976      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    977      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    978     {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1},
    979      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
    980      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    981     {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
    982      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
    983      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    984 };
    986 string R3ReduceWindowTestDataToString(
    987     const ::testing::TestParamInfo<
    988         ::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
    989   const auto& param = ::testing::get<0>(data.param);
    990   string str = absl::StrCat(
    991       "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_",
    992       absl::StrJoin(param.window_bounds, "x"), "__strides_",
    993       absl::StrJoin(param.strides, "x"), "__padding_",
    994       param.padding == Padding::kSame ? "same" : "valid", "__layout_",
    995       param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
    996       param.reducer == kAdd ? "add" : "max");
    997   if (::testing::get<1>(data.param)) {
    998     absl::StrAppend(&str, "_bfloat16");
    999   }
   1000   return str;
   1001 }
   1003 class R3ReduceWindowTest : public ReduceWindowTestBase,
   1004                            public ::testing::WithParamInterface<
   1005                                ::testing::tuple<R3ReduceWindowTestData, bool>> {
   1006  protected:
   1007   R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
   1008 };
   1010 TEST_P(R3ReduceWindowTest, DoIt) {
   1011   XlaBuilder b(TestName());
   1012   const auto& param = ::testing::get<0>(GetParam());
   1014   const float kInitValue = 0.0f;
   1015   Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
   1016                        param.base_bounds[2]);
   1017   // Choose a prime iota length so that each window sees a unique set of values.
   1018   // (Technically, the requirement is that the iota length is relatively prime
   1019   // to all of the dimensions involved in the reduce-window.)
   1020   input.FillRepeatedIota(0, 137);
   1021   Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
   1022       input, LayoutUtil::MakeLayout(param.layout));
   1023   auto reducer = param.reducer;
   1024   if (use_bfloat16()) {
   1025     input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);
   1027     // To avoid numerical issues, force the reducer to be kMax for bf16
   1028     // inputs.
   1029     reducer = kMax;
   1030   }
   1032   XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input");
   1033   auto init_value =
   1034       CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
   1036   auto computation = reducer == kAdd
   1037                          ? CreateScalarAddComputation(FloatType(), &b)
   1038                          : CreateScalarMaxComputation(FloatType(), &b);
   1040   ReduceWindow(/*operand=*/parameter,
   1041                /*init_value=*/init_value,
   1042                /*computation=*/computation,
   1043                /*window_dimensions=*/param.window_bounds,
   1044                /*window_strides=*/param.strides, /*padding=*/param.padding);
   1046   ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec());
   1047 }
   1050     R3ReduceWindowTestInstantiation, R3ReduceWindowTest,
   1051     ::testing::Combine(::testing::ValuesIn(kR3TestCases),
   1052                        ::testing::ValuesIn(use_bfloat16_params)),
   1053     R3ReduceWindowTestDataToString);
   1055 struct R2ReduceWindowTestData {
   1056   int64 base_bounds[2];
   1057   int64 window_bounds[2];
   1058   int64 strides[2];
   1059   int64 pad_low[2];
   1060   int64 pad_high[2];
   1061   int64 layout[2];
   1062   Reducer reducer;
   1063 } kR2TestCases[] = {
   1064     {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
   1065      /*strides=*/{1, 2}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
   1066      /*layout=*/{0, 1},
   1067      /*reducer=*/Reducer::kAdd},
   1068     {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4},
   1069      /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2},
   1070      /*layout=*/{0, 1},
   1071      /*reducer=*/Reducer::kAdd},
   1072     {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3},
   1073      /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
   1074      /*layout=*/{0, 1},
   1075      /*reducer=*/Reducer::kAdd},
   1076     {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100},
   1077      /*strides=*/{2, 99}, /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35},
   1078      /*layout=*/{0, 1},
   1079      /*reducer=*/Reducer::kAdd},
   1080 // TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a
   1081 // ptxas bug.
   1082 #ifndef XLA_TEST_BACKEND_GPU
   1083     {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25},
   1084      /*strides=*/{5, 4}, /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11},
   1085      /*layout=*/{0, 1},
   1086      /*reducer=*/Reducer::kAdd},
   1087 #endif
   1088     {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2},
   1089      /*strides=*/{3, 3}, /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1},
   1090      /*layout=*/{0, 1},
   1091      /*reducer=*/Reducer::kAdd},
   1092     {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36},
   1093      /*strides=*/{4, 5}, /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17},
   1094      /*layout=*/{1, 0},
   1095      /*reducer=*/Reducer::kAdd},
   1096     {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93},
   1097      /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46},
   1098      /*layout=*/{1, 0},
   1099      /*reducer=*/Reducer::kAdd},
   1100     // Regression test for a bug that appeared in Inception (b/34784899).
   1101     {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3},
   1102      /*strides=*/{1, 1}, /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1},
   1103      /*layout=*/{1, 0},
   1104      /*reducer=*/Reducer::kAdd},
   1105     {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
   1106      /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
   1107      /*layout=*/{1, 0},
   1108      /*reducer=*/Reducer::kAdd},
   1109     // Regression test for a bug that appeared in Inception (b/34784899).
   1110     {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2},
   1111      /*strides=*/{2, 2}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
   1112      /*layout=*/{1, 0},
   1113      /*reducer=*/Reducer::kAdd},
   1114     // Regression test for b/73903312: bf16 lacks precision to store result of
   1115     // very large windows. Testing with a reasonable window larger than 128.
   1116     {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130},
   1117      /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0},
   1118      /*layout=*/{1, 0},
   1119      /*reducer=*/Reducer::kAdd},
   1120     {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4},
   1121      /*strides=*/{1, 64}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
   1122      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
   1123     {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4},
   1124      /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0},
   1125      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
   1126     // Regression test for b/72234705: bf16 lacks precision to store incremental
   1127     // results on very large windows. Using smaller window with minor dim 128.
   1128     {/*base_bounds=*/{8, 128}, /*window_bounds=*/{2, 128},
   1129      /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0},
   1130      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
   1131 };
   1133 string R2ReduceWindowTestDataToString(
   1134     const ::testing::TestParamInfo<
   1135         ::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
   1136   const auto& param = ::testing::get<0>(data.param);
   1137   string str = absl::StrCat(
   1138       "base_bounds_", absl::StrJoin(param.base_bounds, "x"),        //
   1139       "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),  //
   1140       "__strides_", absl::StrJoin(param.strides, "x"),              //
   1141       "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_",
   1142       absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_",
   1143       param.layout[1],  //
   1144       "__reducer_", param.reducer == kAdd ? "add" : "max");
   1145   if (::testing::get<1>(data.param)) {
   1146     absl::StrAppend(&str, "_bfloat16");
   1147   }
   1148   return str;
   1149 }
   1151 class R2ReduceWindowTest : public ReduceWindowTestBase,
   1152                            public ::testing::WithParamInterface<
   1153                                ::testing::tuple<R2ReduceWindowTestData, bool>> {
   1154  protected:
   1155   R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
   1157   void DoIt() {
   1158     XlaBuilder b(TestName());
   1159     const auto& param = ::testing::get<0>(GetParam());
   1161     const float kInitValue = 0.0f;
   1162     Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
   1163     Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
   1164         input, LayoutUtil::MakeLayout(param.layout));
   1166     XlaOp parameter;
   1167     auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
   1168                                                        &b, &parameter);
   1169     std::vector<std::pair<int64, int64>> padding(2);
   1170     for (int i = 0; i < 2; ++i) {
   1171       padding[i] = {param.pad_low[i], param.pad_high[i]};
   1172     }
   1173     auto computation = param.reducer == kAdd
   1174                            ? CreateScalarAddComputation(FloatType(), &b)
   1175                            : CreateScalarMaxComputation(FloatType(), &b);
   1176     auto init_value =
   1177         CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
   1178     ReduceWindowWithGeneralPadding(
   1179         /*operand=*/parameter,
   1180         /*init_value=*/init_value,
   1181         /*computation=*/computation,
   1182         /*window_dimensions=*/param.window_bounds,
   1183         /*window_strides=*/param.strides,
   1184         /*base_dilations=*/{},
   1185         /*window_dilations=*/{},
   1186         /*padding=*/padding);
   1188     auto reduce_func = param.reducer == kAdd
   1189                            ? +[](float a, float b) { return a + b; }
   1190                            : +[](float a, float b) { return std::max(a, b); };
   1191     auto expected = ReferenceUtil::ReduceWindow2DGeneric(
   1192         /*operand=*/input, /*init=*/kInitValue, /*reduce_func=*/reduce_func,
   1193         /*window=*/param.window_bounds,
   1194         /*stride=*/param.strides, /*padding=*/padding);
   1196     ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected),
   1197                              {input_arg.get()}, DefaultErrorSpec());
   1198   }
   1199 };
   1201 TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
   1204     R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
   1205     ::testing::Combine(::testing::ValuesIn(kR2TestCases),
   1206                        ::testing::ValuesIn(use_bfloat16_params)),
   1207     R2ReduceWindowTestDataToString);
   1209 struct R1ReduceWindowTestData {
   1210   int64 base_bounds[1];
   1211   int64 window_bounds[1];
   1212   int64 strides[1];
   1213   int64 pad_low[1];
   1214   int64 pad_high[1];
   1215   Reducer reducer;
   1216 } kR1TestCases[] = {
   1217     {/*base_bounds=*/{1}, /*window_bounds=*/{1},
   1218      /*strides=*/{1},
   1219      /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first},
   1220      /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second},
   1221      /*reducer=*/Reducer::kAdd},
   1223     {/*base_bounds=*/{3}, /*window_bounds=*/{3},
   1224      /*strides=*/{1},
   1225      /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first},
   1226      /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second},
   1227      /*reducer=*/Reducer::kAdd},
   1229     {/*base_bounds=*/{3}, /*window_bounds=*/{2},
   1230      /*strides=*/{1},
   1231      /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first},
   1232      /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second},
   1233      /*reducer=*/Reducer::kAdd},
   1235     {/*base_bounds=*/{5}, /*window_bounds=*/{1},
   1236      /*strides=*/{1},
   1237      /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first},
   1238      /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second},
   1239      /*reducer=*/Reducer::kMax},
   1241     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
   1242      /*strides=*/{4},
   1243      /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first},
   1244      /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second},
   1245      /*reducer=*/Reducer::kMax},
   1247     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
   1248      /*strides=*/{3},
   1249      /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first},
   1250      /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second},
   1251      /*reducer=*/Reducer::kAdd},
   1253     {/*base_bounds=*/{128 * 2},
   1254      /*window_bounds=*/{30},
   1255      /*strides=*/{27},
   1256      /*pad_low=*/
   1257      {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first},
   1258      /*pad_high=*/
   1259      {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second},
   1260      /*reducer=*/Reducer::kAdd},
   1262     {/*base_bounds=*/{128 * 17},
   1263      /*window_bounds=*/{7},
   1264      /*strides=*/{64},
   1265      /*pad_low=*/
   1266      {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first},
   1267      /*pad_high=*/
   1268      {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second},
   1269      /*reducer=*/Reducer::kAdd},
   1271     {/*base_bounds=*/{128 * 2},
   1272      /*window_bounds=*/{32},
   1273      /*strides=*/{56},
   1274      /*pad_low=*/
   1275      {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first},
   1276      /*pad_high=*/
   1277      {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second},
   1278      /*reducer=*/Reducer::kAdd},
   1280     {/*base_bounds=*/{3}, /*window_bounds=*/{2},
   1281      /*strides=*/{1},
   1282      /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first},
   1283      /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second},
   1284      /*reducer=*/Reducer::kAdd},
   1286     {/*base_bounds=*/{5}, /*window_bounds=*/{3},
   1287      /*strides=*/{2},
   1288      /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first},
   1289      /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second},
   1290      /*reducer=*/Reducer::kAdd},
   1292     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
   1293      /*strides=*/{3},
   1294      /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first},
   1295      /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second},
   1296      /*reducer=*/Reducer::kAdd},
   1298     {/*base_bounds=*/{5}, /*window_bounds=*/{5},
   1299      /*strides=*/{1},
   1300      /*pad_low=*/{0},
   1301      /*pad_high=*/{5},
   1302      /*reducer=*/Reducer::kAdd},
   1304     {/*base_bounds=*/{5}, /*window_bounds=*/{5},
   1305      /*strides=*/{1},
   1306      /*pad_low=*/{5},
   1307      /*pad_high=*/{0},
   1308      /*reducer=*/Reducer::kAdd},
   1310     // The pattern generated by inclusive scan (cumsum/cumprod).
   1311     {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
   1312      /*strides=*/{1},
   1313      /*pad_low=*/{4095},
   1314      /*pad_high=*/{0},
   1315      /*reducer=*/Reducer::kMax},
   1317     // The pattern generated by exclusive scan (cumsum/cumprod).
   1318     {/*base_bounds=*/{4095}, /*window_bounds=*/{4095},
   1319      /*strides=*/{1},
   1320      /*pad_low=*/{4095},
   1321      /*pad_high=*/{0},
   1322      /*reducer=*/Reducer::kMax},
   1323 };
   1325 string R1ReduceWindowTestDataToString(
   1326     const ::testing::TestParamInfo<
   1327         ::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
   1328   const auto& param = ::testing::get<0>(data.param);
   1329   string str =
   1330       absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"),
   1331                    "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),
   1332                    "__strides_", absl::StrJoin(param.strides, "x"),
   1333                    "__pad_low_", absl::StrJoin(param.pad_low, "x"),
   1334                    "__pad_high_", absl::StrJoin(param.pad_high, "x"),
   1335                    "__reducer_", param.reducer == kAdd ? "add" : "max");
   1336   if (::testing::get<1>(data.param)) {
   1337     absl::StrAppend(&str, "_bfloat16");
   1338   }
   1339   return str;
   1340 }
   1342 class R1ReduceWindowTest : public ReduceWindowTestBase,
   1343                            public ::testing::WithParamInterface<
   1344                                ::testing::tuple<R1ReduceWindowTestData, bool>> {
   1345  protected:
   1346   R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
   1347 };
   1349 TEST_P(R1ReduceWindowTest, DoIt) {
   1350   XlaBuilder b(TestName());
   1351   const auto& param = ::testing::get<0>(GetParam());
   1352   CHECK(param.reducer == kAdd || param.reducer == kMax);
   1354   const float kInitValue = 0.0f;
   1355   std::vector<float> input_vector(param.base_bounds[0]);
   1356   std::iota(std::begin(input_vector), std::end(input_vector), 0);
   1357   Literal input_literal =
   1358       LiteralUtil::CreateR1(absl::Span<const float>(input_vector));
   1359   XlaOp parameter;
   1360   auto input_arg =
   1361       CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, &parameter);
   1363   std::vector<std::pair<int64, int64>> padding(1);
   1364   padding[0] = {param.pad_low[0], param.pad_high[0]};
   1366   auto computation = param.reducer == kAdd
   1367                          ? CreateScalarAddComputation(FloatType(), &b)
   1368                          : CreateScalarMaxComputation(FloatType(), &b);
   1369   auto init_value =
   1370       CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
   1371   ReduceWindowWithGeneralPadding(
   1372       /*operand=*/parameter,
   1373       /*init_value=*/init_value,
   1374       /*computation=*/computation,
   1375       /*window_dimensions=*/param.window_bounds,
   1376       /*window_strides=*/param.strides,
   1377       /*base_dilations=*/{},
   1378       /*window_dilations=*/{},
   1379       /*padding=*/padding);
   1381   auto reduce_func = param.reducer == kAdd
   1382                          ? +[](float a, float b) { return a + b; }
   1383                          : +[](float a, float b) { return std::max(a, b); };
   1384   auto expected = ReferenceUtil::ReduceWindow1DGeneric(
   1385       /*operand=*/absl::Span<const float>(input_vector),
   1386       /*init=*/kInitValue,
   1387       /*reduce_func=*/reduce_func,
   1388       /*window=*/param.window_bounds,
   1389       /*stride=*/param.strides,
   1390       /*padding=*/padding);
   1392   ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*expected),
   1393                            {input_arg.get()}, DefaultErrorSpec());
   1394 }
   1397     R1ReduceWindowTestInstantiation, R1ReduceWindowTest,
   1398     ::testing::Combine(::testing::ValuesIn(kR1TestCases),
   1399                        ::testing::ValuesIn(use_bfloat16_params)),
   1400     R1ReduceWindowTestDataToString);
   1402 // Test class for text-based test cases. Note that this compares with the
   1403 // results on the interpreter backend.
        // The fixture adds no members of its own on top of HloTestBase; it only
        // gives the XLA_TEST_F cases that use it a distinct suite name.
   1404 class ReduceWindowTextTest : public HloTestBase {};
   1406 XLA_TEST_F(ReduceWindowTextTest, R2General256x384) {
          // Reduce-window with a multiply reducer and init constant 1 over an
          // f32[256,384]{1,0} operand: window 2x3, padding 0_1x1_1. The declared
          // result shape equals the operand shape. The HLO text is handed to
          // RunAndCompare with ErrorSpec{0.001}.
   1407   const string hlo_string = R"(
   1408 HloModule R2Window
   1409 mul {
   1410   lhs = f32[] parameter(0)
   1411   rhs = f32[] parameter(1)
   1412   ROOT mul = f32[] multiply(lhs, rhs)
   1413 }
   1414 ENTRY R2Window {
   1415   operand = f32[256,384]{1,0} parameter(0)
   1416   constant = f32[] constant(1)
   1417   ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
   1418 }
   1419 )";
   1420   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
   1421 }
   1423 XLA_TEST_F(ReduceWindowTextTest, R2General256x384Layout01) {
          // Multiply reduce-window (init constant 1, window 2x3, padding
          // 0_1x1_1) over an f32[256,384] operand where both the operand and the
          // result are declared with the {0,1} layout. The HLO text is handed to
          // RunAndCompare with ErrorSpec{0.001}.
   1424   const string hlo_string = R"(
   1425 HloModule R2Window
   1426 mul {
   1427 lhs = f32[] parameter(0)
   1428 rhs = f32[] parameter(1)
   1429 ROOT mul = f32[] multiply(lhs, rhs)
   1430 }
   1431 ENTRY R2Window {
   1432 operand = f32[256,384]{0,1} parameter(0)
   1433 constant = f32[] constant(1)
   1434 ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
   1435 }
   1436 )";
   1437   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
   1438 }
   1440 XLA_TEST_F(ReduceWindowTextTest, R2General2x5) {
          // Multiply reduce-window (init constant 1) over a small f32[2,5]
          // operand: window 2x1 with padding 0_2 on dimension 0 and none on
          // dimension 1, so the declared result is f32[3,5]. The HLO text is
          // handed to RunAndCompare with ErrorSpec{0.001}.
   1441   const string hlo_string = R"(
   1442 HloModule R2Window
   1443 mul {
   1444   lhs = f32[] parameter(0)
   1445   rhs = f32[] parameter(1)
   1446   ROOT mul = f32[] multiply(lhs, rhs)
   1447 }
   1448 ENTRY R2Window {
   1449   operand = f32[2,5]{1,0} parameter(0)
   1450   constant = f32[] constant(1)
   1451   ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul
   1452 }
   1453 )";
   1454   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
   1455 }
   1457 XLA_TEST_F(ReduceWindowTextTest, R2EffectiveScalar) {
          // Single-element case: the f32[1,1] operand is negated and then fed to
          // a multiply reduce-window with a 1x1 window, no padding and init
          // constant 1. The HLO text is handed to RunAndCompare with
          // ErrorSpec{0.001}.
   1458   const string hlo_string = R"(
   1459 HloModule R2Window
   1460 mul {
   1461   lhs = f32[] parameter(0)
   1462   rhs = f32[] parameter(1)
   1463   ROOT mul = f32[] multiply(lhs, rhs)
   1464 }
   1465 ENTRY R2Window {
   1466   operand = f32[1,1]{1,0} parameter(0)
   1467   negate = f32[1,1]{1,0} negate(operand)
   1468   constant = f32[] constant(1)
   1469   ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul
   1470 }
   1471 )";
   1472   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
   1473 }
   1475 XLA_TEST_F(ReduceWindowTextTest, R3EffectiveScalar) {
          // Single-element rank-3 case: the f32[1,1,1] operand is negated and
          // then fed to a multiply reduce-window with a 1x1x1 window, no padding
          // and init constant 1. The HLO text is handed to RunAndCompare with
          // ErrorSpec{0.001}.
   1476   const string hlo_string = R"(
   1477 HloModule R3Window
   1478 mul {
   1479   lhs = f32[] parameter(0)
   1480   rhs = f32[] parameter(1)
   1481   ROOT mul = f32[] multiply(lhs, rhs)
   1482 }
   1483 ENTRY R3Window {
   1484   operand = f32[1,1,1]{2,1,0} parameter(0)
   1485   negate = f32[1,1,1]{2,1,0} negate(operand)
   1486   constant = f32[] constant(1)
   1487   ROOT reduce-window = f32[1,1,1]{2,1,0} reduce-window(negate, constant), window={size=1x1x1 pad=0_0x0_0x0_0}, to_apply=mul
   1488 }
   1489 )";
   1490   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
   1491 }
   1493 XLA_TEST_F(HloTestBase, ReduceWindowIdentity) {
          // The reducer ignores its first parameter and returns its second (its
          // ROOT is parameter(1)). The reduce-window uses a 1x1x1 window with
          // padding 1_0 on dimension 1 and init constant 0, so the f32[1,32,64]
          // operand yields a declared f32[1,33,64] result. RunAndCompare is
          // called with absl::nullopt in place of an ErrorSpec.
          // NOTE(review): the computation name "identity.pad_to_reduce_window"
          // suggests this HLO came from a pad-to-reduce-window rewrite - confirm.
   1494   const string hlo_string = R"(
   1495 HloModule ReduceWindowIdentity
   1496 identity.pad_to_reduce_window {
   1497   param0 = f32[] parameter(0)
   1498   ROOT param1 = f32[] parameter(1)
   1499 }
   1500 ENTRY reduce-window-identity {
   1501   operand = f32[1,32,64]{2,1,0} parameter(0)
   1502   constant.4466 = f32[] constant(0)
   1503   ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window
   1504 }
   1506 )";
   1507   EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
   1508 }
   1510 XLA_TEST_F(HloTestBase, ReduceWindowS32) {
          // s32 case. The reducer returns its second parameter (its ROOT is
          // parameter(1)). The reduce-window takes its init value from entry
          // parameter 1 and uses a 1x1 window with padding 0_1 on dimension 0,
          // so the s32[81,8] operand yields a declared s32[82,8] result.
          // RunAndCompare is called with absl::nullopt in place of an ErrorSpec.
   1511   const string hlo_string = R"(
   1512 HloModule reduce-window
   1514 %identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] {
   1515   %param0 = s32[] parameter(0)
   1516   ROOT %param1 = s32[] parameter(1)
   1517 }
   1519 ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] {
   1520   %parameter.0 = s32[81,8]{1,0} parameter(0)
   1521   %parameter.1 = s32[] parameter(1)
   1522   ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
   1523 }
   1525 )";
   1526   EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
   1527 }
   1529 XLA_TEST_F(HloTestBase, ReduceWindowS64) {
          // s64 case. The reducer returns its second parameter (its ROOT is
          // parameter(1)). The reduce-window takes its init value from entry
          // parameter 1 and uses a 1x1 window with padding 0_1 on dimension 0,
          // so the s64[81,8] operand yields a declared s64[82,8] result.
          // RunAndCompare is called with absl::nullopt in place of an ErrorSpec.
   1530   const string hlo_string = R"(
   1531 HloModule reduce-window
   1533 %identity.pad_to_reduce_window (param0: s64[], param1: s64[]) -> s64[] {
   1534   %param0 = s64[] parameter(0)
   1535   ROOT %param1 = s64[] parameter(1)
   1536 }
   1538 ENTRY %reduce-window (parameter.0: s64[81,8], parameter.1: s64[]) -> s64[82,8] {
   1539   %parameter.0 = s64[81,8]{1,0} parameter(0)
   1540   %parameter.1 = s64[] parameter(1)
   1541   ROOT %reduce-window = s64[82,8]{1,0} reduce-window(s64[81,8]{1,0} %parameter.0, s64[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
   1542 }
   1544 )";
   1545   EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
   1546 }
   1548 XLA_TEST_F(HloTestBase, ReduceWindowF16) {
          // f16 case. The reducer returns its second parameter (its ROOT is
          // parameter(1)). The reduce-window takes its init value from entry
          // parameter 1 and uses a 1x1 window with padding 0_1 on dimension 0,
          // so the f16[81,8] operand yields a declared f16[82,8] result.
          // RunAndCompare is called with absl::nullopt in place of an ErrorSpec.
   1549   const string hlo_string = R"(
   1550 HloModule reduce-window
   1552 %identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] {
   1553   %param0 = f16[] parameter(0)
   1554   ROOT %param1 = f16[] parameter(1)
   1555 }
   1557 ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] {
   1558   %parameter.0 = f16[81,8]{1,0} parameter(0)
   1559   %parameter.1 = f16[] parameter(1)
   1560   ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
   1561 }
   1563 )";
   1564   EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
   1565 }
   1567 }  // namespace
   1568 }  // namespace xla