1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Tests the reduce-window XLA operation. 17 18 #include <limits> 19 #include <memory> 20 21 #include "absl/memory/memory.h" 22 #include "absl/strings/str_cat.h" 23 #include "absl/strings/str_join.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/array2d.h" 26 #include "tensorflow/compiler/xla/array3d.h" 27 #include "tensorflow/compiler/xla/array4d.h" 28 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 29 #include "tensorflow/compiler/xla/client/local_client.h" 30 #include "tensorflow/compiler/xla/client/padding.h" 31 #include "tensorflow/compiler/xla/client/xla_builder.h" 32 #include "tensorflow/compiler/xla/client/xla_computation.h" 33 #include "tensorflow/compiler/xla/reference_util.h" 34 #include "tensorflow/compiler/xla/shape_util.h" 35 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 36 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 37 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 38 #include "tensorflow/compiler/xla/tests/test_macros.h" 39 #include "tensorflow/compiler/xla/xla_data.pb.h" 40 #include "tensorflow/core/lib/core/status.h" 41 #include "tensorflow/core/lib/core/status_test_util.h" 42 #include "tensorflow/core/platform/test.h" 43 #include "tensorflow/core/platform/types.h" 44 45 namespace xla { 46 namespace { 47 48 #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 49 // Tests both F32 and BF16. 50 static std::array<bool, 2> use_bfloat16_params{false, true}; 51 #else 52 // Only tests F32. 53 static std::array<bool, 1> use_bfloat16_params{false}; 54 #endif 55 56 class ReduceWindowTestBase : public ClientLibraryTestBase { 57 public: 58 ErrorSpec DefaultErrorSpec() const { 59 if (use_bfloat16()) { 60 return ErrorSpec(2e-1, 6e-2); 61 } else { 62 return ErrorSpec(1e-3, 1e-3); 63 } 64 } 65 }; 66 67 class ReduceWindowTest : public ::testing::WithParamInterface<bool>, 68 public ReduceWindowTestBase { 69 public: 70 ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); } 71 72 void ReduceWindowAdd(const XlaOp& input, 73 absl::Span<const int64> window_dimensions, 74 absl::Span<const int64> window_strides, 75 Padding padding) { 76 auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), 77 &builder_); 78 ReduceWindow(input, init, 79 CreateScalarAddComputation(FloatType(), &builder_), 80 window_dimensions, window_strides, padding); 81 } 82 83 void ReduceWindowMax(const XlaOp& input, 84 absl::Span<const int64> window_dimensions, 85 absl::Span<const int64> window_strides, 86 Padding padding) { 87 auto init = 88 CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_); 89 ReduceWindow(input, init, 90 CreateScalarMaxComputation(FloatType(), &builder_), 91 window_dimensions, window_strides, padding); 92 } 93 94 void ReduceWindowMin(const XlaOp& input, 95 absl::Span<const int64> window_dimensions, 96 absl::Span<const int64> window_strides, 97 Padding padding) { 98 auto init = 99 CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_); 100 ReduceWindow(input, init, 101 CreateScalarMinComputation(FloatType(), &builder_), 102 window_dimensions, window_strides, padding); 103 } 104 105 XlaBuilder builder_; 106 }; 107 108 TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { 109 const auto input = CreateConstantFromLiteral( 110 LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_); 111 const auto init_value = 112 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_); 113 TF_ASSERT_OK(builder_.first_error()); 114 ReduceWindow(input, init_value, 115 CreateScalarAddComputation(FloatType(), &builder_), 116 /*window_dimensions=*/{1, 2}, 117 /*window_strides=*/{1}, Padding::kValid); 118 ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) 119 << builder_.first_error(); 120 ASSERT_THAT(builder_.first_error().error_message(), 121 ::testing::HasSubstr("Want input dimensions size")); 122 } 123 124 // Regression test for b/68964348. 125 TEST_P(ReduceWindowTest, R0ReduceWindow) { 126 const auto input = 127 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_); 128 const auto init = 129 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(1.0), &builder_); 130 ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), 131 /*window_dimensions=*/{}, 132 /*window_strides=*/{}, Padding::kSame); 133 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(43.0), {}, 134 ErrorSpec(0.00001)); 135 } 136 137 TEST_P(ReduceWindowTest, Min3In5Stride2) { 138 const auto input = CreateConstantFromLiteral( 139 LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_); 140 ReduceWindowMin(input, {3}, {2}, Padding::kValid); 141 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({100, 1}), 142 {}, ErrorSpec(0.00001)); 143 } 144 145 TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { 146 const auto input = CreateConstantFromLiteral( 147 LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_); 148 ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, 149 Padding::kSame); 150 ComputeAndCompareLiteral(&builder_, 151 LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}), 152 {}, ErrorSpec(0.00001)); 153 } 154 155 XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { 156 Array4D<float> input_array(1, 0, 2, 1); 157 const auto input = CreateConstantFromArray(input_array, &builder_); 158 Padding padding = Padding::kSame; 159 ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); 160 161 auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, 162 {1, 1, 1, 1}, padding); 163 164 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, 165 DefaultErrorSpec()); 166 } 167 168 TEST_P(ReduceWindowTest, NonSquareSmall) { 169 Array4D<float> input_array(1, 2, 2, 1); 170 input_array.FillRandom(2.f, 2.f); 171 const auto input = CreateConstantFromArray(input_array, &builder_); 172 173 Padding padding = Padding::kSame; 174 ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); 175 176 auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, 177 {1, 1, 1, 1}, padding); 178 179 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, 180 DefaultErrorSpec()); 181 } 182 183 TEST_P(ReduceWindowTest, MiddleDimsSmall) { 184 Array4D<float> input_array(1, 3, 3, 1); 185 input_array.FillRandom(2.f, 2.f); 186 const auto input = CreateConstantFromArray(input_array, &builder_); 187 Padding padding = Padding::kSame; 188 ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); 189 190 auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, 191 {1, 2, 2, 1}, padding); 192 193 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, 194 DefaultErrorSpec()); 195 } 196 197 TEST_P(ReduceWindowTest, Along2ndMinorDim) { 198 Array4D<float> input_array(3, 6, 7, 32); 199 input_array.FillRandom(2.f, 2.f); 200 const auto input = CreateConstantFromArray(input_array, &builder_); 201 202 // The parameters of this reduction mimic feature norm (e.g. LRN). 203 int lrn_diameter = 7; // diameter = 2*radius + 1 --> must be odd 204 Padding padding = Padding::kSame; 205 ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); 206 207 auto res = ReferenceUtil::ReduceWindow4DAdd( 208 input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); 209 210 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, 211 DefaultErrorSpec()); 212 } 213 214 TEST_P(ReduceWindowTest, AmongMajor2Dims) { 215 Array4D<float> input_array(4, 4, 6, 8); 216 input_array.FillWithMinorDimNum(); 217 const auto input_data_handle = 218 CreateConstantFromArray(input_array, &builder_); 219 220 int win_len = 3; 221 int win_stride = 1; 222 223 Padding padding = Padding::kSame; 224 // Reduce only along the x and y dimensions, according to the win_len. 225 ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, 226 {win_stride, win_stride, 1, 1}, padding); 227 228 auto result = ReferenceUtil::ReduceWindow4DAdd( 229 input_array, 0.0f, {win_len, win_len, 1, 1}, 230 {win_stride, win_stride, 1, 1}, padding); 231 232 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, 233 DefaultErrorSpec()); 234 } 235 236 TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { 237 Array4D<float> input_array(9, 12, 4, 89); 238 input_array.FillRandom(2.f, 2.f); 239 240 int win_len = 3; 241 int win_stride = 2; 242 243 const auto input_data_handle = 244 CreateConstantFromArray(input_array, &builder_); 245 246 Padding padding = Padding::kSame; 247 // Reduce only along the x and y dimensions, according to the win_len. 248 ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, 249 {win_stride, win_stride, 1, 1}, padding); 250 251 auto result = ReferenceUtil::ReduceWindow4DAdd( 252 input_array, 0.0f, {win_len, win_len, 1, 1}, 253 {win_stride, win_stride, 1, 1}, padding); 254 255 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, 256 DefaultErrorSpec()); 257 } 258 259 // Tests the super windowing logic w.r.t handling prime number of windows in a 260 // major dimension with reduction. 261 TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { 262 Array4D<float> input_array(15, 15, 4, 128); 263 input_array.FillRandom(2.f, 4.f); 264 265 int win_len = 3; 266 int win_stride = 2; 267 268 const auto input_data_handle = 269 CreateConstantFromArray(input_array, &builder_); 270 271 Padding padding = Padding::kSame; 272 // Reduce only along the x and y dimensions, according to the win_len. 273 ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, 274 {win_stride, win_stride, 1, 1}, padding); 275 276 auto result = ReferenceUtil::ReduceWindow4DAdd( 277 input_array, 0.0f, {win_len, win_len, 1, 1}, 278 {win_stride, win_stride, 1, 1}, padding); 279 280 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, 281 DefaultErrorSpec()); 282 } 283 284 TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { 285 Array4D<float> input_array(19, 17, 8, 256); 286 input_array.FillWithMinorDimNum(); 287 288 const auto input_data_handle = 289 CreateConstantFromArray(input_array, &builder_); 290 291 Padding padding = Padding::kSame; 292 ReduceWindowAdd(input_data_handle, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); 293 294 auto result = ReferenceUtil::ReduceWindow4DAdd( 295 input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); 296 297 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, 298 DefaultErrorSpec()); 299 } 300 301 // Tests a reduction function that is not a simple add/min/max/etc. 302 XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { 303 Array4D<float> input_array(1, 2, 2, 1); 304 input_array(0, 0, 0, 0) = 1; 305 input_array(0, 0, 1, 0) = 2; 306 input_array(0, 1, 0, 0) = 3; 307 input_array(0, 1, 1, 0) = 4; 308 const auto input = CreateConstantFromArray(input_array, &builder_); 309 310 Padding padding = Padding::kValid; 311 const Shape scalar = ShapeUtil::MakeShape(FloatType(), {}); 312 auto b = builder_.CreateSubBuilder("unusual"); 313 auto lhs = Parameter(b.get(), 0, scalar, "lhs"); 314 auto rhs = Parameter(b.get(), 1, scalar, "rhs"); 315 Min(Add(lhs, rhs), 316 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get())); 317 XlaComputation reduce_fn = b->BuildAndNoteError(); 318 319 ReduceWindow( 320 input, 321 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_), 322 reduce_fn, 323 /*window_dimensions=*/{1, 1, 2, 1}, 324 /*window_strides=*/{1, 1, 1, 1}, padding); 325 326 const auto reduce_func = [](float arg1, float arg2) { 327 return std::min<float>(arg1 + arg2, 8.0f); 328 }; 329 330 auto expected = 331 ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func, 332 /*window=*/{1, 1, 2, 1}, 333 /*stride=*/{1, 1, 1, 1}, padding); 334 335 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected), 336 {}, DefaultErrorSpec()); 337 } 338 339 TEST_P(ReduceWindowTest, R4UnitWindow) { 340 Array4D<float> input_array(13, 12, 8, 15); 341 input_array.FillRandom(2.f, 2.f); 342 Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( 343 input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); 344 XlaOp input; 345 auto input_data = CreateParameterAndTransferLiteral( 346 0, input_literal, "parameter", &builder_, &input); 347 348 Padding padding = Padding::kSame; 349 ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); 350 351 auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, 352 {1, 4, 1, 1}, padding); 353 354 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), 355 {input_data.get()}, DefaultErrorSpec()); 356 } 357 358 XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { 359 std::vector<int64> input_dims(6, 8); 360 auto shape = ShapeUtil::MakeShape(F32, input_dims); 361 362 Literal arg_literal(shape); 363 arg_literal.PopulateWithValue(1.0f); 364 const auto input = CreateConstantFromLiteral(arg_literal, &builder_); 365 366 Padding padding = Padding::kValid; 367 ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); 368 369 std::vector<int64> output_layout = {1, 5, 3, 2, 0, 4}; 370 std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8}; 371 Shape result_shape = 372 ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); 373 Literal expected(result_shape); 374 expected.PopulateWithValue(27.0f); 375 ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); 376 } 377 378 XLA_TEST_P(ReduceWindowTest, R6Add) { 379 std::vector<int64> input_dims(6, 8); 380 auto shape = ShapeUtil::MakeShape(F32, input_dims); 381 382 Literal arg_literal = 383 LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f); 384 385 const auto input = CreateConstantFromLiteral(arg_literal, &builder_); 386 387 Padding padding = Padding::kValid; 388 ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); 389 390 std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8}; 391 Literal expected = 392 LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f); 393 394 ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); 395 } 396 397 XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { 398 Array4D<float> input_array(2, 1, 27, 119); 399 input_array.FillRandom(2.0f); 400 Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( 401 input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); 402 XlaOp input; 403 auto input_data = CreateParameterAndTransferLiteral( 404 0, input_literal, "parameter", &builder_, &input); 405 406 int win_len = 1; 407 int stride = 8; 408 Padding padding = Padding::kSame; 409 ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); 410 411 auto res = ReferenceUtil::ReduceWindow4DAdd( 412 input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); 413 414 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), 415 {input_data.get()}, DefaultErrorSpec()); 416 } 417 418 XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { 419 Array4D<float> input_array(3, 2, 4, 64); 420 input_array.FillRandom(2.0f); 421 Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( 422 input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); 423 XlaOp input; 424 auto input_data = CreateParameterAndTransferLiteral( 425 0, input_literal, "parameter", &builder_, &input); 426 427 int win_len = 3; 428 int stride = 1; 429 Padding padding = Padding::kSame; 430 ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); 431 432 auto res = ReferenceUtil::ReduceWindow4DAdd( 433 input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); 434 435 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), 436 {input_data.get()}, DefaultErrorSpec()); 437 } 438 439 XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { 440 Array4D<float> input_array(1, 3, 12, 200); 441 input_array.FillRandom(2.0f); 442 Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( 443 input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); 444 XlaOp input; 445 auto input_data = CreateParameterAndTransferLiteral( 446 0, input_literal, "parameter", &builder_, &input); 447 448 int win_len = 8; 449 int stride = 5; 450 Padding padding = Padding::kSame; 451 ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); 452 453 auto res = ReferenceUtil::ReduceWindow4DAdd( 454 input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); 455 456 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), 457 {input_data.get()}, DefaultErrorSpec()); 458 } 459 460 TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { 461 Array4D<float> input_array(6, 4, 10, 130); 462 input_array.FillRandom(2.0f); 463 464 int win_len = 3; 465 int win_stride = 2; 466 467 Padding padding = Padding::kSame; 468 const auto input_data_handle = 469 CreateConstantFromArray(input_array, &builder_); 470 // Reduce only along the x and y dimensions, according to the win_len. 471 ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, 472 {win_stride, win_stride, 1, 1}, padding); 473 474 auto result = ReferenceUtil::ReduceWindow4DAdd( 475 input_array, 0.0f, {win_len, win_len, 1, 1}, 476 {win_stride, win_stride, 1, 1}, padding); 477 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, 478 DefaultErrorSpec()); 479 } 480 481 XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { 482 std::vector<float> input_vector(128 * 9, 1); 483 const auto input = CreateConstantFromLiteral( 484 LiteralUtil::CreateR1<float>(input_vector), &builder_); 485 ReduceWindowAdd(input, {32}, {128}, Padding::kValid); 486 ComputeAndCompareLiteral( 487 &builder_, 488 LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, 489 DefaultErrorSpec()); 490 } 491 492 XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { 493 std::vector<float> input_vector{ 494 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 495 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 496 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 497 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 498 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 499 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 500 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 501 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; 502 const auto input = CreateConstantFromLiteral( 503 LiteralUtil::CreateR1<float>(input_vector), &builder_); 504 ReduceWindowAdd(input, {128}, {128}, Padding::kValid); 505 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {}, 506 DefaultErrorSpec()); 507 } 508 509 XLA_TEST_P(ReduceWindowTest, Add128In128) { 510 std::vector<float> input_vector{ 511 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 512 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 513 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 514 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 515 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 516 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 517 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 518 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; 519 const auto input = CreateConstantFromLiteral( 520 LiteralUtil::CreateR1<float>(input_vector), &builder_); 521 ReduceWindowAdd(input, {128}, {1}, Padding::kValid); 522 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {}, 523 DefaultErrorSpec()); 524 } 525 526 // Regression test for a bug that appeared in Inception (b/34784899). 527 TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { 528 Array2D<float> input_array(14, 14, 1.0f); 529 const auto input = CreateConstantFromArray(input_array, &builder_); 530 531 int win_len = 3; 532 int stride = 1; 533 Padding padding = Padding::kSame; 534 ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding); 535 536 auto res = ReferenceUtil::ReduceWindow2DAdd( 537 input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); 538 539 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res), 540 {}, DefaultErrorSpec()); 541 } 542 543 TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { 544 Array2D<float> input_array(6, 4, 1.0f); 545 XlaOp input = Broadcast( 546 CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4}); 547 548 Padding padding = Padding::kSame; 549 ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); 550 551 auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, 552 padding); 553 554 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res), 555 {}, DefaultErrorSpec()); 556 } 557 558 INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, 559 ::testing::ValuesIn(use_bfloat16_params)); 560 561 enum Reducer { kAdd, kMax }; 562 563 struct R4ReduceWindowTestData { 564 int64 base_bounds[4]; 565 int64 window_bounds[4]; 566 int64 strides[4]; 567 int64 pad_low[4]; 568 int64 pad_high[4]; 569 int64 layout[4]; 570 571 Reducer reducer; 572 }; 573 574 string R4ReduceWindowTestDataToString( 575 const ::testing::TestParamInfo< 576 ::testing::tuple<R4ReduceWindowTestData, bool>>& data) { 577 const auto& param = ::testing::get<0>(data.param); 578 string str = absl::StrCat( 579 "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // 580 "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // 581 "__strides_", absl::StrJoin(param.strides, "x"), // 582 "__pad_low_", absl::StrJoin(param.pad_low, "x"), // 583 "__pad_high_", absl::StrJoin(param.pad_high, "x"), // 584 "__layout_", absl::StrJoin(param.layout, "_"), // 585 (param.reducer == kAdd) ? "_add" : "_max"); 586 CHECK(param.reducer == kAdd || param.reducer == kMax); 587 588 // Test names are not allowed to contain the '-' character. 589 std::replace(str.begin(), str.end(), '-', 'n'); 590 if (::testing::get<1>(data.param)) { 591 absl::StrAppend(&str, "_bfloat16"); 592 } 593 return str; 594 } 595 596 class R4ReduceWindowTest : public ReduceWindowTestBase, 597 public ::testing::WithParamInterface< 598 ::testing::tuple<R4ReduceWindowTestData, bool>> { 599 protected: 600 R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } 601 602 void DoIt() { 603 XlaBuilder b(TestName()); 604 const auto& param = ::testing::get<0>(GetParam()); 605 606 const float kInitValue = 0.0f; 607 608 Array4D<float> input(param.base_bounds[0], param.base_bounds[1], 609 param.base_bounds[2], param.base_bounds[3]); 610 // Choose a prime iota length so that each window sees a unique set of 611 // values. (Technically, the requirement is that the iota length is 612 // relatively prime to all of the dimensions involved in the reduce-window.) 613 input.FillRepeatedIota(0, 137); 614 // Floating point sum reduction requires higher localized precision. We need 615 // the following normalization in order to enable testing of kAdd on large 616 // windows. 617 input.Each([&](absl::Span<const int64> /*indices*/, float* value) { 618 *value = *value / 10000000000.f; 619 }); 620 Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( 621 input, LayoutUtil::MakeLayout(param.layout)); 622 XlaOp parameter; 623 auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", 624 &b, ¶meter); 625 626 std::vector<std::pair<int64, int64>> padding(4); 627 for (int i = 0; i < 4; ++i) { 628 padding[i] = {param.pad_low[i], param.pad_high[i]}; 629 } 630 631 auto init_value = 632 CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); 633 CHECK(param.reducer == kAdd || param.reducer == kMax); 634 auto reducer = param.reducer; 635 auto computation = reducer == kAdd 636 ? CreateScalarAddComputation(FloatType(), &b) 637 : CreateScalarMaxComputation(FloatType(), &b); 638 ReduceWindowWithGeneralPadding( 639 /*operand=*/parameter, 640 /*init_value=*/init_value, 641 /*computation=*/computation, 642 /*window_dimensions=*/param.window_bounds, 643 /*window_strides=*/param.strides, 644 /*base_dilations=*/{}, 645 /*window_dilations=*/{}, 646 /*padding=*/padding); 647 648 CHECK(reducer == kAdd || reducer == kMax); 649 auto reduce_func = reducer == kAdd 650 ? +[](float a, float b) { return a + b; } 651 : +[](float a, float b) { return std::max(a, b); }; 652 std::unique_ptr<Array4D<float>> expected = 653 ReferenceUtil::ReduceWindow4DGeneric( 654 /*operand=*/input, 655 /*init=*/kInitValue, 656 /*reduce_func=*/reduce_func, 657 /*window=*/param.window_bounds, 658 /*stride=*/param.strides, 659 /*padding=*/padding); 660 Literal expected_literal = LiteralUtil::CreateFromArray(*expected); 661 const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( 662 input_literal.shape().element_type(), 663 AsInt64Slice(expected_literal.shape().dimensions()), param.layout); 664 ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()}, 665 DefaultErrorSpec(), &expected_shape_with_layout); 666 } 667 }; 668 669 TEST_P(R4ReduceWindowTest, DoIt) { DoIt(); } 670 671 // base_bounds, window_bounds, strides, pad_low, pad_high 672 const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { 673 // Minimal edge case. 674 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1}, 675 /*window_bounds=*/{1, 1, 1, 1}, 676 /*strides=*/{1, 1, 1, 1}, 677 /*pad_low=*/{0, 0, 0, 0}, 678 /*pad_high=*/{0, 0, 0, 0}, 679 /*layout=*/{3, 2, 1, 0}, 680 /*reducer=*/kAdd}, 681 682 // Arbitrary padding (not kSame or kValid). 683 R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89}, 684 /*window_bounds=*/{3, 3, 1, 1}, 685 /*strides=*/{2, 2, 1, 1}, 686 /*pad_low=*/{4, 4, 0, 0}, 687 /*pad_high=*/{4, 4, 0, 0}, 688 /*layout=*/{3, 2, 1, 0}, 689 /*reducer=*/kAdd}, 690 691 // Zero base bound edge case. 692 R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1}, 693 /*window_bounds=*/{1, 1, 1, 1}, 694 /*strides=*/{1, 1, 1, 1}, 695 /*pad_low=*/{0, 0, 0, 0}, 696 /*pad_high=*/{0, 0, 0, 0}, 697 /*layout=*/{3, 2, 1, 0}, 698 /*reducer=*/kAdd}, 699 700 // With max instead of add. 701 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, 702 /*window_bounds=*/{2, 3, 1, 1}, 703 /*strides=*/{1, 1, 1, 1}, 704 /*pad_low=*/{0, 0, 0, 0}, 705 /*pad_high=*/{0, 0, 0, 0}, 706 /*layout=*/{3, 2, 1, 0}, 707 /*reducer=*/kMax}, 708 709 // With stride. 710 R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140}, 711 /*window_bounds=*/{3, 2, 1, 1}, 712 /*strides=*/{2, 4, 1, 1}, 713 /*pad_low=*/{0, 0, 0, 0}, 714 /*pad_high=*/{0, 0, 0, 0}, 715 /*layout=*/{3, 2, 1, 0}, 716 /*reducer=*/kAdd}, 717 718 // With low padding. 719 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, 720 /*window_bounds=*/{3, 2, 1, 1}, 721 /*strides=*/{2, 2, 1, 1}, 722 /*pad_low=*/{3, 2, 0, 0}, 723 /*pad_high=*/{0, 0, 0, 0}, 724 /*layout=*/{3, 2, 1, 0}, 725 /*reducer=*/kAdd}, 726 727 // With high padding. 728 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, 729 /*window_bounds=*/{3, 2, 1, 1}, 730 /*strides=*/{2, 2, 1, 1}, 731 /*pad_low=*/{0, 0, 0, 0}, 732 /*pad_high=*/{2, 3, 0, 0}, 733 /*layout=*/{3, 2, 1, 0}, 734 /*reducer=*/kAdd}, 735 736 // Window touches both sides of the padding simultaneously. 737 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140}, 738 /*window_bounds=*/{3, 3, 1, 1}, 739 /*strides=*/{1, 1, 1, 1}, 740 /*pad_low=*/{1, 1, 0, 0}, 741 /*pad_high=*/{1, 1, 0, 0}, 742 /*layout=*/{3, 2, 1, 0}, 743 /*reducer=*/kAdd}, 744 745 // Window is entirely in the padding for some positions. 746 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140}, 747 /*window_bounds=*/{3, 3, 1, 1}, 748 /*strides=*/{1, 1, 1, 1}, 749 /*pad_low=*/{4, 4, 0, 0}, 750 /*pad_high=*/{4, 4, 0, 0}, 751 /*layout=*/{3, 2, 1, 0}, 752 /*reducer=*/kAdd}, 753 754 // Zero base bound with padding edge case. 755 R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4}, 756 /*window_bounds=*/{1, 1, 1, 1}, 757 /*strides=*/{1, 1, 1, 1}, 758 /*pad_low=*/{0, 1, 0, 0}, 759 /*pad_high=*/{0, 0, 0, 0}, 760 /*layout=*/{3, 2, 1, 0}, 761 /*reducer=*/kAdd}, 762 763 // With stride, low padding and high padding. 764 R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140}, 765 /*window_bounds=*/{3, 4, 1, 1}, 766 /*strides=*/{3, 1, 1, 1}, 767 /*pad_low=*/{10, 1, 0, 0}, 768 /*pad_high=*/{2, 3, 0, 0}, 769 /*layout=*/{3, 2, 1, 0}, 770 /*reducer=*/kAdd}, 771 772 // With minor dimension == 129. 773 R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129}, 774 /*window_bounds=*/{1, 1, 1, 1}, 775 /*strides=*/{1, 1, 1, 1}, 776 /*pad_low=*/{0, 0, 0, 0}, 777 /*pad_high=*/{0, 0, 0, 0}, 778 /*layout=*/{3, 2, 1, 0}, 779 /*reducer=*/kAdd}, 780 781 // With minor dims reduction and non-overlapped stride. 782 R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16}, 783 /*window_bounds=*/{1, 1, 2, 2}, 784 /*strides=*/{1, 1, 2, 2}, 785 /*pad_low=*/{0, 0, 0, 0}, 786 /*pad_high=*/{0, 0, 0, 0}, 787 /*layout=*/{3, 2, 1, 0}, 788 /*reducer=*/kAdd}, 789 790 // With minor dims reduction and overlapped stride. 791 R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16}, 792 /*window_bounds=*/{1, 1, 4, 4}, 793 /*strides=*/{1, 1, 2, 2}, 794 /*pad_low=*/{0, 0, 0, 0}, 795 /*pad_high=*/{1, 0, 0, 0}, 796 /*layout=*/{3, 2, 1, 0}, 797 /*reducer=*/kAdd}, 798 799 R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3}, 800 /*window_bounds=*/{1, 64, 64, 1}, 801 /*strides=*/{1, 64, 64, 1}, 802 /*pad_low=*/{0, 0, 0, 0}, 803 /*pad_high=*/{0, 0, 0, 0}, 804 /*layout=*/{3, 0, 2, 1}, 805 /*reducer=*/kAdd}, 806 807 R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64}, 808 /*window_bounds=*/{112, 112, 1, 8}, 809 /*strides=*/{112, 112, 1, 8}, 810 /*pad_low=*/{0, 0, 0, 0}, 811 /*pad_high=*/{0, 0, 0, 0}, 812 /*layout=*/{3, 2, 1, 0}, 813 /*reducer=*/kMax}, 814 815 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, 816 /*window_bounds=*/{2, 3, 4, 5}, 817 /*strides=*/{1, 1, 1, 1}, 818 /*pad_low=*/{0, 0, 0, 0}, 819 /*pad_high=*/{0, 0, 0, 0}, 820 /*layout=*/{3, 2, 1, 0}, 821 /*reducer=*/kAdd}, 822 823 // With 0321 layout. 824 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, 825 /*window_bounds=*/{2, 3, 4, 5}, 826 /*strides=*/{1, 2, 3, 4}, 827 /*pad_low=*/{0, 0, 0, 0}, 828 /*pad_high=*/{0, 0, 0, 0}, 829 /*layout=*/{0, 3, 2, 1}, 830 /*reducer=*/kAdd}, 831 832 // With 0123 layout. 833 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17}, 834 /*window_bounds=*/{2, 3, 7, 9}, 835 /*strides=*/{1, 2, 5, 8}, 836 /*pad_low=*/{0, 0, 0, 0}, 837 /*pad_high=*/{0, 0, 0, 0}, 838 /*layout=*/{0, 1, 2, 3}, 839 /*reducer=*/kAdd}, 840 }; 841 842 INSTANTIATE_TEST_CASE_P( 843 R4ReduceWindowTestInstantiation, R4ReduceWindowTest, 844 ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues), 845 ::testing::ValuesIn(use_bfloat16_params)), 846 R4ReduceWindowTestDataToString); 847 848 class R4ReduceWindowLargeTest : public R4ReduceWindowTest {}; 849 850 XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); } 851 852 // Test cases that are large/slow/failed. 853 const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { 854 R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128}, 855 /*window_bounds=*/{3, 3, 1, 5}, 856 /*strides=*/{1, 1, 1, 5}, 857 /*pad_low=*/{1, 1, 0, 0}, 858 /*pad_high=*/{1, 1, 0, 0}, 859 /*layout=*/{3, 2, 1, 0}, 860 /*reducer=*/kMax}, 861 862 R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128}, 863 /*window_bounds=*/{3, 3, 1, 1}, 864 /*strides=*/{2, 2, 1, 1}, 865 /*pad_low=*/{0, 0, 0, 0}, 866 /*pad_high=*/{1, 1, 0, 0}, 867 /*layout=*/{3, 2, 1, 0}, 868 /*reducer=*/kAdd}, 869 870 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2}, 871 /*window_bounds=*/{1, 1, 4, 1}, 872 /*strides=*/{1, 1, 4, 1}, 873 /*pad_low=*/{0, 0, 1, 0}, 874 /*pad_high=*/{0, 0, 2, 0}, 875 /*layout=*/{3, 2, 1, 0}, 876 /*reducer=*/kMax}, 877 878 // Patterns generated by cumsum/cumprod. 879 R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16}, 880 /*window_bounds=*/{1021, 1, 1, 1}, 881 /*strides=*/{1, 1, 1, 1}, 882 /*pad_low=*/{1020, 0, 0, 0}, 883 /*pad_high=*/{0, 0, 0, 0}, 884 /*layout=*/{3, 2, 1, 0}, 885 /*reducer=*/kAdd}, 886 887 R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16}, 888 /*window_bounds=*/{1, 1, 1021, 1}, 889 /*strides=*/{1, 1, 1, 1}, 890 /*pad_low=*/{0, 0, 1020, 0}, 891 /*pad_high=*/{0, 0, 0, 0}, 892 /*layout=*/{3, 2, 1, 0}, 893 /*reducer=*/kAdd}, 894 895 R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021}, 896 /*window_bounds=*/{1, 1, 1, 1021}, 897 /*strides=*/{1, 1, 1, 1}, 898 /*pad_low=*/{0, 0, 0, 1020}, 899 /*pad_high=*/{0, 0, 0, 0}, 900 /*layout=*/{3, 2, 1, 0}, 901 /*reducer=*/kAdd}, 902 903 R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16}, 904 /*window_bounds=*/{1021, 1, 1, 1}, 905 /*strides=*/{1, 1, 1, 1}, 906 /*pad_low=*/{1021, 0, 0, 0}, 907 /*pad_high=*/{0, 0, 0, 0}, 908 /*layout=*/{3, 2, 1, 0}, 909 /*reducer=*/kAdd}, 910 911 R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 1021, 16}, 912 /*window_bounds=*/{1, 1, 1021, 1}, 913 /*strides=*/{1, 1, 1, 1}, 914 /*pad_low=*/{0, 0, 1021, 0}, 915 /*pad_high=*/{0, 0, 0, 0}, 916 /*layout=*/{3, 2, 1, 0}, 917 /*reducer=*/kAdd}, 918 919 R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021}, 920 /*window_bounds=*/{1, 1, 1, 1021}, 921 /*strides=*/{1, 1, 1, 1}, 922 /*pad_low=*/{0, 0, 0, 1021}, 923 /*pad_high=*/{0, 0, 0, 0}, 924 /*layout=*/{3, 2, 1, 0}, 925 /*reducer=*/kAdd}, 926 }; 927 928 INSTANTIATE_TEST_CASE_P( 929 R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest, 930 ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues), 931 ::testing::ValuesIn(use_bfloat16_params)), 932 R4ReduceWindowTestDataToString); 933 934 struct R3ReduceWindowTestData { 935 int64 base_bounds[3]; 936 int64 window_bounds[3]; 937 int64 strides[3]; 938 int64 layout[3]; 939 Padding padding; 940 Reducer reducer; 941 } kR3TestCases[] = { 942 {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2}, 943 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, 944 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, 945 {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2}, 946 /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, 947 /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, 948 {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2}, 949 /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, 950 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, 951 {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, 952 /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0}, 953 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, 954 {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1}, 955 /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0}, 956 /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, 957 {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, 958 /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2}, 959 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, 960 {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, 961 /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2}, 962 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, 963 {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251}, 964 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, 965 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, 966 {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3}, 967 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, 968 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, 969 {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64}, 970 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, 971 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, 972 {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257}, 973 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, 974 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, 975 {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, 976 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, 977 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, 978 {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1}, 979 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, 980 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, 981 {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, 982 /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, 983 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, 984 }; 985 986 string R3ReduceWindowTestDataToString( 987 const ::testing::TestParamInfo< 988 ::testing::tuple<R3ReduceWindowTestData, bool>>& data) { 989 const auto& param = ::testing::get<0>(data.param); 990 string str = absl::StrCat( 991 "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_", 992 absl::StrJoin(param.window_bounds, "x"), "__strides_", 993 absl::StrJoin(param.strides, "x"), "__padding_", 994 param.padding == Padding::kSame ? "same" : "valid", "__layout_", 995 param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", 996 param.reducer == kAdd ? "add" : "max"); 997 if (::testing::get<1>(data.param)) { 998 absl::StrAppend(&str, "_bfloat16"); 999 } 1000 return str; 1001 } 1002 1003 class R3ReduceWindowTest : public ReduceWindowTestBase, 1004 public ::testing::WithParamInterface< 1005 ::testing::tuple<R3ReduceWindowTestData, bool>> { 1006 protected: 1007 R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } 1008 }; 1009 1010 TEST_P(R3ReduceWindowTest, DoIt) { 1011 XlaBuilder b(TestName()); 1012 const auto& param = ::testing::get<0>(GetParam()); 1013 1014 const float kInitValue = 0.0f; 1015 Array3D<float> input(param.base_bounds[0], param.base_bounds[1], 1016 param.base_bounds[2]); 1017 // Choose a prime iota length so that each window sees a unique set of values. 1018 // (Technically, the requirement is that the iota length is relatively prime 1019 // to all of the dimensions involved in the reduce-window.) 1020 input.FillRepeatedIota(0, 137); 1021 Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( 1022 input, LayoutUtil::MakeLayout(param.layout)); 1023 auto reducer = param.reducer; 1024 if (use_bfloat16()) { 1025 input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); 1026 1027 // To avoid numerical issues, force the reducer to be kMax for bf16 1028 // inputs. 1029 reducer = kMax; 1030 } 1031 1032 XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input"); 1033 auto init_value = 1034 CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); 1035 1036 auto computation = reducer == kAdd 1037 ? CreateScalarAddComputation(FloatType(), &b) 1038 : CreateScalarMaxComputation(FloatType(), &b); 1039 1040 ReduceWindow(/*operand=*/parameter, 1041 /*init_value=*/init_value, 1042 /*computation=*/computation, 1043 /*window_dimensions=*/param.window_bounds, 1044 /*window_strides=*/param.strides, /*padding=*/param.padding); 1045 1046 ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec()); 1047 } 1048 1049 INSTANTIATE_TEST_CASE_P( 1050 R3ReduceWindowTestInstantiation, R3ReduceWindowTest, 1051 ::testing::Combine(::testing::ValuesIn(kR3TestCases), 1052 ::testing::ValuesIn(use_bfloat16_params)), 1053 R3ReduceWindowTestDataToString); 1054 1055 struct R2ReduceWindowTestData { 1056 int64 base_bounds[2]; 1057 int64 window_bounds[2]; 1058 int64 strides[2]; 1059 int64 pad_low[2]; 1060 int64 pad_high[2]; 1061 int64 layout[2]; 1062 Reducer reducer; 1063 } kR2TestCases[] = { 1064 {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4}, 1065 /*strides=*/{1, 2}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, 1066 /*layout=*/{0, 1}, 1067 /*reducer=*/Reducer::kAdd}, 1068 {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4}, 1069 /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2}, 1070 /*layout=*/{0, 1}, 1071 /*reducer=*/Reducer::kAdd}, 1072 {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3}, 1073 /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, 1074 /*layout=*/{0, 1}, 1075 /*reducer=*/Reducer::kAdd}, 1076 {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100}, 1077 /*strides=*/{2, 99}, /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35}, 1078 /*layout=*/{0, 1}, 1079 /*reducer=*/Reducer::kAdd}, 1080 // TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a 1081 // ptxas bug. 1082 #ifndef XLA_TEST_BACKEND_GPU 1083 {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25}, 1084 /*strides=*/{5, 4}, /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11}, 1085 /*layout=*/{0, 1}, 1086 /*reducer=*/Reducer::kAdd}, 1087 #endif 1088 {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2}, 1089 /*strides=*/{3, 3}, /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1}, 1090 /*layout=*/{0, 1}, 1091 /*reducer=*/Reducer::kAdd}, 1092 {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36}, 1093 /*strides=*/{4, 5}, /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17}, 1094 /*layout=*/{1, 0}, 1095 /*reducer=*/Reducer::kAdd}, 1096 {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93}, 1097 /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46}, 1098 /*layout=*/{1, 0}, 1099 /*reducer=*/Reducer::kAdd}, 1100 // Regression test for a bug that appeared in Inception (b/34784899). 1101 {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3}, 1102 /*strides=*/{1, 1}, /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1}, 1103 /*layout=*/{1, 0}, 1104 /*reducer=*/Reducer::kAdd}, 1105 {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, 1106 /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, 1107 /*layout=*/{1, 0}, 1108 /*reducer=*/Reducer::kAdd}, 1109 // Regression test for a bug that appeared in Inception (b/34784899). 1110 {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2}, 1111 /*strides=*/{2, 2}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, 1112 /*layout=*/{1, 0}, 1113 /*reducer=*/Reducer::kAdd}, 1114 // Regression test for b/73903312: bf16 lacks precision to store result of 1115 // very large windows. Testing with a reasonable window larger than 128. 1116 {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130}, 1117 /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0}, 1118 /*layout=*/{1, 0}, 1119 /*reducer=*/Reducer::kAdd}, 1120 {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4}, 1121 /*strides=*/{1, 64}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, 1122 /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, 1123 {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4}, 1124 /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, 1125 /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, 1126 // Regression test for b/72234705: bf16 lacks precision to store incremental 1127 // results on very large windows. Using smaller window with minor dim 128. 1128 {/*base_bounds=*/{8, 128}, /*window_bounds=*/{2, 128}, 1129 /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, 1130 /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, 1131 }; 1132 1133 string R2ReduceWindowTestDataToString( 1134 const ::testing::TestParamInfo< 1135 ::testing::tuple<R2ReduceWindowTestData, bool>>& data) { 1136 const auto& param = ::testing::get<0>(data.param); 1137 string str = absl::StrCat( 1138 "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // 1139 "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // 1140 "__strides_", absl::StrJoin(param.strides, "x"), // 1141 "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_", 1142 absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_", 1143 param.layout[1], // 1144 "__reducer_", param.reducer == kAdd ? "add" : "max"); 1145 if (::testing::get<1>(data.param)) { 1146 absl::StrAppend(&str, "_bfloat16"); 1147 } 1148 return str; 1149 } 1150 1151 class R2ReduceWindowTest : public ReduceWindowTestBase, 1152 public ::testing::WithParamInterface< 1153 ::testing::tuple<R2ReduceWindowTestData, bool>> { 1154 protected: 1155 R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } 1156 1157 void DoIt() { 1158 XlaBuilder b(TestName()); 1159 const auto& param = ::testing::get<0>(GetParam()); 1160 1161 const float kInitValue = 0.0f; 1162 Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f); 1163 Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout( 1164 input, LayoutUtil::MakeLayout(param.layout)); 1165 1166 XlaOp parameter; 1167 auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", 1168 &b, ¶meter); 1169 std::vector<std::pair<int64, int64>> padding(2); 1170 for (int i = 0; i < 2; ++i) { 1171 padding[i] = {param.pad_low[i], param.pad_high[i]}; 1172 } 1173 auto computation = param.reducer == kAdd 1174 ? CreateScalarAddComputation(FloatType(), &b) 1175 : CreateScalarMaxComputation(FloatType(), &b); 1176 auto init_value = 1177 CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); 1178 ReduceWindowWithGeneralPadding( 1179 /*operand=*/parameter, 1180 /*init_value=*/init_value, 1181 /*computation=*/computation, 1182 /*window_dimensions=*/param.window_bounds, 1183 /*window_strides=*/param.strides, 1184 /*base_dilations=*/{}, 1185 /*window_dilations=*/{}, 1186 /*padding=*/padding); 1187 1188 auto reduce_func = param.reducer == kAdd 1189 ? +[](float a, float b) { return a + b; } 1190 : +[](float a, float b) { return std::max(a, b); }; 1191 auto expected = ReferenceUtil::ReduceWindow2DGeneric( 1192 /*operand=*/input, /*init=*/kInitValue, /*reduce_func=*/reduce_func, 1193 /*window=*/param.window_bounds, 1194 /*stride=*/param.strides, /*padding=*/padding); 1195 1196 ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected), 1197 {input_arg.get()}, DefaultErrorSpec()); 1198 } 1199 }; 1200 1201 TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); } 1202 1203 INSTANTIATE_TEST_CASE_P( 1204 R2ReduceWindowTestInstantiation, R2ReduceWindowTest, 1205 ::testing::Combine(::testing::ValuesIn(kR2TestCases), 1206 ::testing::ValuesIn(use_bfloat16_params)), 1207 R2ReduceWindowTestDataToString); 1208 1209 struct R1ReduceWindowTestData { 1210 int64 base_bounds[1]; 1211 int64 window_bounds[1]; 1212 int64 strides[1]; 1213 int64 pad_low[1]; 1214 int64 pad_high[1]; 1215 Reducer reducer; 1216 } kR1TestCases[] = { 1217 {/*base_bounds=*/{1}, /*window_bounds=*/{1}, 1218 /*strides=*/{1}, 1219 /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first}, 1220 /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second}, 1221 /*reducer=*/Reducer::kAdd}, 1222 1223 {/*base_bounds=*/{3}, /*window_bounds=*/{3}, 1224 /*strides=*/{1}, 1225 /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first}, 1226 /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second}, 1227 /*reducer=*/Reducer::kAdd}, 1228 1229 {/*base_bounds=*/{3}, /*window_bounds=*/{2}, 1230 /*strides=*/{1}, 1231 /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first}, 1232 /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second}, 1233 /*reducer=*/Reducer::kAdd}, 1234 1235 {/*base_bounds=*/{5}, /*window_bounds=*/{1}, 1236 /*strides=*/{1}, 1237 /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first}, 1238 /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second}, 1239 /*reducer=*/Reducer::kMax}, 1240 1241 {/*base_bounds=*/{16}, /*window_bounds=*/{4}, 1242 /*strides=*/{4}, 1243 /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first}, 1244 /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second}, 1245 /*reducer=*/Reducer::kMax}, 1246 1247 {/*base_bounds=*/{16}, /*window_bounds=*/{4}, 1248 /*strides=*/{3}, 1249 /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first}, 1250 /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second}, 1251 /*reducer=*/Reducer::kAdd}, 1252 1253 {/*base_bounds=*/{128 * 2}, 1254 /*window_bounds=*/{30}, 1255 /*strides=*/{27}, 1256 /*pad_low=*/ 1257 {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first}, 1258 /*pad_high=*/ 1259 {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second}, 1260 /*reducer=*/Reducer::kAdd}, 1261 1262 {/*base_bounds=*/{128 * 17}, 1263 /*window_bounds=*/{7}, 1264 /*strides=*/{64}, 1265 /*pad_low=*/ 1266 {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first}, 1267 /*pad_high=*/ 1268 {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second}, 1269 /*reducer=*/Reducer::kAdd}, 1270 1271 {/*base_bounds=*/{128 * 2}, 1272 /*window_bounds=*/{32}, 1273 /*strides=*/{56}, 1274 /*pad_low=*/ 1275 {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first}, 1276 /*pad_high=*/ 1277 {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second}, 1278 /*reducer=*/Reducer::kAdd}, 1279 1280 {/*base_bounds=*/{3}, /*window_bounds=*/{2}, 1281 /*strides=*/{1}, 1282 /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first}, 1283 /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second}, 1284 /*reducer=*/Reducer::kAdd}, 1285 1286 {/*base_bounds=*/{5}, /*window_bounds=*/{3}, 1287 /*strides=*/{2}, 1288 /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first}, 1289 /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second}, 1290 /*reducer=*/Reducer::kAdd}, 1291 1292 {/*base_bounds=*/{16}, /*window_bounds=*/{4}, 1293 /*strides=*/{3}, 1294 /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first}, 1295 /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second}, 1296 /*reducer=*/Reducer::kAdd}, 1297 1298 {/*base_bounds=*/{5}, /*window_bounds=*/{5}, 1299 /*strides=*/{1}, 1300 /*pad_low=*/{0}, 1301 /*pad_high=*/{5}, 1302 /*reducer=*/Reducer::kAdd}, 1303 1304 {/*base_bounds=*/{5}, /*window_bounds=*/{5}, 1305 /*strides=*/{1}, 1306 /*pad_low=*/{5}, 1307 /*pad_high=*/{0}, 1308 /*reducer=*/Reducer::kAdd}, 1309 1310 // The pattern generated by inclusive scan (cumsum/cumprod). 1311 {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, 1312 /*strides=*/{1}, 1313 /*pad_low=*/{4095}, 1314 /*pad_high=*/{0}, 1315 /*reducer=*/Reducer::kMax}, 1316 1317 // The pattern generated by exclusive scan (cumsum/cumprod). 1318 {/*base_bounds=*/{4095}, /*window_bounds=*/{4095}, 1319 /*strides=*/{1}, 1320 /*pad_low=*/{4095}, 1321 /*pad_high=*/{0}, 1322 /*reducer=*/Reducer::kMax}, 1323 }; 1324 1325 string R1ReduceWindowTestDataToString( 1326 const ::testing::TestParamInfo< 1327 ::testing::tuple<R1ReduceWindowTestData, bool>>& data) { 1328 const auto& param = ::testing::get<0>(data.param); 1329 string str = 1330 absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"), 1331 "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), 1332 "__strides_", absl::StrJoin(param.strides, "x"), 1333 "__pad_low_", absl::StrJoin(param.pad_low, "x"), 1334 "__pad_high_", absl::StrJoin(param.pad_high, "x"), 1335 "__reducer_", param.reducer == kAdd ? "add" : "max"); 1336 if (::testing::get<1>(data.param)) { 1337 absl::StrAppend(&str, "_bfloat16"); 1338 } 1339 return str; 1340 } 1341 1342 class R1ReduceWindowTest : public ReduceWindowTestBase, 1343 public ::testing::WithParamInterface< 1344 ::testing::tuple<R1ReduceWindowTestData, bool>> { 1345 protected: 1346 R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } 1347 }; 1348 1349 TEST_P(R1ReduceWindowTest, DoIt) { 1350 XlaBuilder b(TestName()); 1351 const auto& param = ::testing::get<0>(GetParam()); 1352 CHECK(param.reducer == kAdd || param.reducer == kMax); 1353 1354 const float kInitValue = 0.0f; 1355 std::vector<float> input_vector(param.base_bounds[0]); 1356 std::iota(std::begin(input_vector), std::end(input_vector), 0); 1357 Literal input_literal = 1358 LiteralUtil::CreateR1(absl::Span<const float>(input_vector)); 1359 XlaOp parameter; 1360 auto input_arg = 1361 CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); 1362 1363 std::vector<std::pair<int64, int64>> padding(1); 1364 padding[0] = {param.pad_low[0], param.pad_high[0]}; 1365 1366 auto computation = param.reducer == kAdd 1367 ? CreateScalarAddComputation(FloatType(), &b) 1368 : CreateScalarMaxComputation(FloatType(), &b); 1369 auto init_value = 1370 CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); 1371 ReduceWindowWithGeneralPadding( 1372 /*operand=*/parameter, 1373 /*init_value=*/init_value, 1374 /*computation=*/computation, 1375 /*window_dimensions=*/param.window_bounds, 1376 /*window_strides=*/param.strides, 1377 /*base_dilations=*/{}, 1378 /*window_dilations=*/{}, 1379 /*padding=*/padding); 1380 1381 auto reduce_func = param.reducer == kAdd 1382 ? +[](float a, float b) { return a + b; } 1383 : +[](float a, float b) { return std::max(a, b); }; 1384 auto expected = ReferenceUtil::ReduceWindow1DGeneric( 1385 /*operand=*/absl::Span<const float>(input_vector), 1386 /*init=*/kInitValue, 1387 /*reduce_func=*/reduce_func, 1388 /*window=*/param.window_bounds, 1389 /*stride=*/param.strides, 1390 /*padding=*/padding); 1391 1392 ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*expected), 1393 {input_arg.get()}, DefaultErrorSpec()); 1394 } 1395 1396 INSTANTIATE_TEST_CASE_P( 1397 R1ReduceWindowTestInstantiation, R1ReduceWindowTest, 1398 ::testing::Combine(::testing::ValuesIn(kR1TestCases), 1399 ::testing::ValuesIn(use_bfloat16_params)), 1400 R1ReduceWindowTestDataToString); 1401 1402 // Test class for text-based test cases. Note that this compares with the 1403 // results on the interpreter backend. 1404 class ReduceWindowTextTest : public HloTestBase {}; 1405 1406 XLA_TEST_F(ReduceWindowTextTest, R2General256x384) { 1407 const string hlo_string = R"( 1408 HloModule R2Window 1409 mul { 1410 lhs = f32[] parameter(0) 1411 rhs = f32[] parameter(1) 1412 ROOT mul = f32[] multiply(lhs, rhs) 1413 } 1414 ENTRY R2Window { 1415 operand = f32[256,384]{1,0} parameter(0) 1416 constant = f32[] constant(1) 1417 ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul 1418 } 1419 )"; 1420 EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); 1421 } 1422 1423 XLA_TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { 1424 const string hlo_string = R"( 1425 HloModule R2Window 1426 mul { 1427 lhs = f32[] parameter(0) 1428 rhs = f32[] parameter(1) 1429 ROOT mul = f32[] multiply(lhs, rhs) 1430 } 1431 ENTRY R2Window { 1432 operand = f32[256,384]{0,1} parameter(0) 1433 constant = f32[] constant(1) 1434 ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul 1435 } 1436 )"; 1437 EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); 1438 } 1439 1440 XLA_TEST_F(ReduceWindowTextTest, R2General2x5) { 1441 const string hlo_string = R"( 1442 HloModule R2Window 1443 mul { 1444 lhs = f32[] parameter(0) 1445 rhs = f32[] parameter(1) 1446 ROOT mul = f32[] multiply(lhs, rhs) 1447 } 1448 ENTRY R2Window { 1449 operand = f32[2,5]{1,0} parameter(0) 1450 constant = f32[] constant(1) 1451 ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul 1452 } 1453 )"; 1454 EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); 1455 } 1456 1457 XLA_TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { 1458 const string hlo_string = R"( 1459 HloModule R2Window 1460 mul { 1461 lhs = f32[] parameter(0) 1462 rhs = f32[] parameter(1) 1463 ROOT mul = f32[] multiply(lhs, rhs) 1464 } 1465 ENTRY R2Window { 1466 operand = f32[1,1]{1,0} parameter(0) 1467 negate = f32[1,1]{1,0} negate(operand) 1468 constant = f32[] constant(1) 1469 ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul 1470 } 1471 )"; 1472 EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); 1473 } 1474 1475 XLA_TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { 1476 const string hlo_string = R"( 1477 HloModule R3Window 1478 mul { 1479 lhs = f32[] parameter(0) 1480 rhs = f32[] parameter(1) 1481 ROOT mul = f32[] multiply(lhs, rhs) 1482 } 1483 ENTRY R3Window { 1484 operand = f32[1,1,1]{2,1,0} parameter(0) 1485 negate = f32[1,1,1]{2,1,0} negate(operand) 1486 constant = f32[] constant(1) 1487 ROOT reduce-window = f32[1,1,1]{2,1,0} reduce-window(negate, constant), window={size=1x1x1 pad=0_0x0_0x0_0}, to_apply=mul 1488 } 1489 )"; 1490 EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); 1491 } 1492 1493 XLA_TEST_F(HloTestBase, ReduceWindowIdentity) { 1494 const string hlo_string = R"( 1495 HloModule ReduceWindowIdentity 1496 identity.pad_to_reduce_window { 1497 param0 = f32[] parameter(0) 1498 ROOT param1 = f32[] parameter(1) 1499 } 1500 ENTRY reduce-window-identity { 1501 operand = f32[1,32,64]{2,1,0} parameter(0) 1502 constant.4466 = f32[] constant(0) 1503 ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window 1504 } 1505 1506 )"; 1507 EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); 1508 } 1509 1510 XLA_TEST_F(HloTestBase, ReduceWindowS32) { 1511 const string hlo_string = R"( 1512 HloModule reduce-window 1513 1514 %identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] { 1515 %param0 = s32[] parameter(0) 1516 ROOT %param1 = s32[] parameter(1) 1517 } 1518 1519 ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { 1520 %parameter.0 = s32[81,8]{1,0} parameter(0) 1521 %parameter.1 = s32[] parameter(1) 1522 ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window 1523 } 1524 1525 )"; 1526 EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); 1527 } 1528 1529 XLA_TEST_F(HloTestBase, ReduceWindowS64) { 1530 const string hlo_string = R"( 1531 HloModule reduce-window 1532 1533 %identity.pad_to_reduce_window (param0: s64[], param1: s64[]) -> s64[] { 1534 %param0 = s64[] parameter(0) 1535 ROOT %param1 = s64[] parameter(1) 1536 } 1537 1538 ENTRY %reduce-window (parameter.0: s64[81,8], parameter.1: s64[]) -> s64[82,8] { 1539 %parameter.0 = s64[81,8]{1,0} parameter(0) 1540 %parameter.1 = s64[] parameter(1) 1541 ROOT %reduce-window = s64[82,8]{1,0} reduce-window(s64[81,8]{1,0} %parameter.0, s64[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window 1542 } 1543 1544 )"; 1545 EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); 1546 } 1547 1548 XLA_TEST_F(HloTestBase, ReduceWindowF16) { 1549 const string hlo_string = R"( 1550 HloModule reduce-window 1551 1552 %identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] { 1553 %param0 = f16[] parameter(0) 1554 ROOT %param1 = f16[] parameter(1) 1555 } 1556 1557 ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] { 1558 %parameter.0 = f16[81,8]{1,0} parameter(0) 1559 %parameter.1 = f16[] parameter(1) 1560 ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window 1561 } 1562 1563 )"; 1564 EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); 1565 } 1566 1567 } // namespace 1568 } // namespace xla 1569