1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cmath> 17 #include <limits> 18 #include <memory> 19 #include <numeric> 20 #include <vector> 21 22 #include "absl/base/casts.h" 23 #include "absl/types/span.h" 24 #include "tensorflow/compiler/xla/array2d.h" 25 #include "tensorflow/compiler/xla/array3d.h" 26 #include "tensorflow/compiler/xla/array4d.h" 27 #include "tensorflow/compiler/xla/client/global_data.h" 28 #include "tensorflow/compiler/xla/client/local_client.h" 29 #include "tensorflow/compiler/xla/client/xla_builder.h" 30 #include "tensorflow/compiler/xla/layout_util.h" 31 #include "tensorflow/compiler/xla/literal.h" 32 #include "tensorflow/compiler/xla/statusor.h" 33 #include "tensorflow/compiler/xla/test.h" 34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 35 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 36 #include "tensorflow/compiler/xla/tests/test_macros.h" 37 #include "tensorflow/compiler/xla/types.h" 38 #include "tensorflow/core/platform/types.h" 39 40 namespace xla { 41 namespace { 42 43 class ArrayElementwiseOpTest : public ClientLibraryTestBase { 44 public: 45 ErrorSpec error_spec_{0.0001, 0.0001}; 46 }; 47 48 class ArrayElementwiseOpTestParamCount 49 : public ArrayElementwiseOpTest, 50 public ::testing::WithParamInterface<int> {}; 51 52 
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) { 53 XlaBuilder builder(TestName()); 54 auto a = ConstantR1<float>(&builder, {}); 55 Neg(a); 56 57 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 58 } 59 60 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { 61 XlaBuilder builder(TestName()); 62 auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); 63 Neg(a); 64 65 ComputeAndCompareR1<float>(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {}, 66 error_spec_); 67 } 68 69 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { 70 XlaBuilder builder(TestName()); 71 auto a = ConstantR1<int32>(&builder, 72 {-1, 0, 1, 324, std::numeric_limits<int32>::min(), 73 std::numeric_limits<int32>::max()}); 74 Neg(a); 75 76 // -min == min for int32 due to an overflow. In C++ it is undefined behavior 77 // to do this calculation. For XLA we have not specified that, so it 78 // ought to work. 79 ComputeAndCompareR1<int32>(&builder, 80 {1, 0, -1, -324, std::numeric_limits<int32>::min(), 81 -std::numeric_limits<int32>::max()}, 82 {}); 83 } 84 85 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { 86 XlaBuilder builder(TestName()); 87 auto a = ConstantR1<complex64>(&builder, {}); 88 Neg(a); 89 90 ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); 91 } 92 93 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { 94 XlaBuilder builder(TestName()); 95 auto a = ConstantR1<complex64>( 96 &builder, {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); 97 Neg(a); 98 99 ComputeAndCompareR1<complex64>( 100 &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, 101 {}, error_spec_); 102 } 103 104 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { 105 XlaBuilder builder(TestName()); 106 auto a = 107 ConstantR1<int64>(&builder, { 108 -1, 109 1, 110 0, 111 0x12345678, 112 static_cast<int64>(0xffffffff12345678l), 113 static_cast<int64>(0x8000000000000000LL), 114 static_cast<int64>(0x8000000000000001LL), 
115 }); 116 Neg(a); 117 LOG(INFO) << -static_cast<int64>(0x7FFFFFFFFFFFFFFFLL); 118 119 ComputeAndCompareR1<int64>(&builder, 120 { 121 1, 122 -1, 123 0, 124 -0x12345678, 125 0xedcba988, 126 static_cast<int64>(0x8000000000000000LL), 127 -static_cast<int64>(0x8000000000000001LL), 128 }, 129 {}); 130 } 131 132 XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { 133 XlaBuilder builder(TestName()); 134 auto a = ConstantR1<float>(&builder, {}); 135 IsFinite(a); 136 137 ComputeAndCompareR1<bool>(&builder, {}, {}); 138 } 139 140 // A non-canonical quiet NaN value. 141 static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234); 142 143 XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { 144 XlaBuilder builder(TestName()); 145 IsFinite(ConstantR0<float>(&builder, NAN)); 146 ComputeAndCompareR0<bool>(&builder, false, {}); 147 148 EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); 149 IsFinite(ConstantR0<float>(&builder, kNonCanonicalNaN)); 150 ComputeAndCompareR0<bool>(&builder, false, {}); 151 152 const float inf = std::numeric_limits<float>::infinity(); 153 IsFinite(ConstantR0<float>(&builder, inf)); 154 ComputeAndCompareR0<bool>(&builder, false, {}); 155 156 IsFinite(ConstantR0<float>(&builder, -inf)); 157 ComputeAndCompareR0<bool>(&builder, false, {}); 158 159 IsFinite(ConstantR0<float>(&builder, 0.0f)); 160 ComputeAndCompareR0<bool>(&builder, true, {}); 161 } 162 163 XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { 164 XlaBuilder builder(TestName()); 165 const float inf = std::numeric_limits<float>::infinity(); 166 EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); 167 auto a = ConstantR1<float>(&builder, 168 {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); 169 IsFinite(a); 170 171 ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false}, 172 {}); 173 } 174 175 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { 176 XlaBuilder builder(TestName()); 177 auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); 178 
auto b = ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); 179 Add(a, b); 180 181 ComputeAndCompareR1<float>(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {}, 182 error_spec_); 183 } 184 185 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { 186 XlaBuilder builder(TestName()); 187 auto a = ConstantR1<float>(&builder, {}); 188 auto b = ConstantR1<float>(&builder, {}); 189 Add(a, b); 190 191 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 192 } 193 194 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { 195 XlaBuilder builder(TestName()); 196 auto a = ConstantR1<complex64>( 197 &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); 198 auto b = ConstantR1<complex64>( 199 &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); 200 Add(a, b); 201 202 ComputeAndCompareR1<complex64>( 203 &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, 204 error_spec_); 205 } 206 207 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { 208 XlaBuilder builder(TestName()); 209 auto a = ConstantR1<complex64>(&builder, {}); 210 auto b = ConstantR1<complex64>(&builder, {}); 211 Add(a, b); 212 213 ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); 214 } 215 216 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 217 XlaBuilder b(TestName()); 218 219 std::vector<uint64> lhs{0xFFFFFFFF, 220 static_cast<uint64>(-1), 221 0, 222 0, 223 0x7FFFFFFFFFFFFFFFLL, 224 0x7FFFFFFFFFFFFFFLL, 225 0x8000000000000000LL, 226 0x8000000000000000LL, 227 1}; 228 Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs}); 229 auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); 230 std::unique_ptr<GlobalData> lhs_data = 231 client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); 232 233 std::vector<uint64> rhs{1, 234 0x7FFFFFFFFFFFFFFLL, 235 0x7FFFFFFFFFFFFFFFLL, 236 0x8000000000000000LL, 237 0, 238 static_cast<uint64>(-1), 239 0, 240 1, 241 
0x8000000000000000LL}; 242 Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs}); 243 auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); 244 std::unique_ptr<GlobalData> rhs_data = 245 client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); 246 247 Add(lhs_param, rhs_param); 248 249 std::vector<uint64> expected(lhs.size()); 250 for (int64 i = 0; i < lhs.size(); ++i) { 251 expected[i] = lhs[i] + rhs[i]; 252 } 253 254 ComputeAndCompareR1<uint64>(&b, expected, {lhs_data.get(), rhs_data.get()}); 255 } 256 257 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 258 XlaBuilder b(TestName()); 259 260 std::vector<int64> lhs{static_cast<int64>(0x8000000000000000LL), 261 static_cast<int64>(0x8000000000000000LL), 262 -1, 263 0x7FFFFFFFFFFFFFFLL, 264 0x7FFFFFFFFFFFFFFFLL, 265 1, 266 0, 267 -1}; 268 Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs}); 269 auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); 270 std::unique_ptr<GlobalData> lhs_data = 271 client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); 272 273 std::vector<int64> rhs{-1, 274 0, 275 static_cast<int64>(0x8000000000000000LL), 276 1, 277 0, 278 0x7FFFFFFFFFFFFFFLL, 279 0x7FFFFFFFFFFFFFFFLL, 280 0x7FFFFFFFFFFFFFFFLL}; 281 Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs}); 282 auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); 283 std::unique_ptr<GlobalData> rhs_data = 284 client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); 285 286 Sub(lhs_param, rhs_param); 287 288 std::vector<int64> expected(lhs.size()); 289 for (int64 i = 0; i < lhs.size(); ++i) { 290 expected[i] = lhs[i] - rhs[i]; 291 } 292 293 ComputeAndCompareR1<int64>(&b, expected, {lhs_data.get(), rhs_data.get()}); 294 } 295 296 XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) { 297 XlaBuilder b(TestName()); 298 299 std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)}; 300 Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs}); 301 auto 
lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); 302 303 std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)}; 304 Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs}); 305 auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); 306 307 Lt(lhs_param, rhs_param); 308 309 ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)}); 310 } 311 312 TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { 313 const int count = GetParam(); 314 XlaBuilder builder(TestName()); 315 std::vector<float> a_values; 316 std::vector<float> b_values; 317 for (int i = 0; i < count; ++i) { 318 a_values.push_back(i / static_cast<float>(count)); 319 b_values.push_back(2 * i / static_cast<float>(count + 2)); 320 } 321 322 Literal a_literal = LiteralUtil::CreateR1<float>({a_values}); 323 std::unique_ptr<GlobalData> a_data = 324 client_->TransferToServer(a_literal).ConsumeValueOrDie(); 325 auto a_constant = ConstantR1<float>(&builder, a_values); 326 auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param"); 327 328 Literal b_literal = LiteralUtil::CreateR1<float>({b_values}); 329 std::unique_ptr<GlobalData> b_data = 330 client_->TransferToServer(b_literal).ConsumeValueOrDie(); 331 auto b_param = Parameter(&builder, 1, a_literal.shape(), "b_param"); 332 auto b_constant = ConstantR1<float>(&builder, b_values); 333 334 auto sum1 = Add(a_constant, b_param); 335 auto sum2 = Add(a_constant, b_constant); 336 auto sum3 = Add(a_param, b_param); 337 auto sum4 = Add(a_param, b_constant); 338 339 auto sum = Add(sum1, sum2); 340 sum = Add(sum, sum3); 341 sum = Add(sum, sum4); 342 343 std::vector<float> expected; 344 for (int64 i = 0; i < count; ++i) { 345 expected.push_back(4 * (a_values[i] + b_values[i])); 346 } 347 348 ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()}, 349 error_spec_); 350 } 351 352 XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) { 353 XlaBuilder builder(TestName()); 354 
std::vector<float> values(30, 0.0); 355 auto a_literal = LiteralUtil::CreateR1<float>(values); 356 auto a = Parameter(&builder, 0, a_literal.shape(), "x"); 357 auto b_literal = LiteralUtil::CreateR1<float>(values); 358 auto b = Parameter(&builder, 1, b_literal.shape(), "x"); 359 360 // Construct a sequence of diamond-shaped gadgets like this: 361 // 362 // add 363 // / \ 364 // slice slice 365 // \ / 366 // add 367 // 368 // Each 'left' slice removes the last element, each 'right' slice removes the 369 // first element. In this way, we index into the add with different 370 // multi-dimensional index arrays, which defeats the caching we use to avoid 371 // exponential compile time. 372 std::function<XlaOp(int64)> generate_recursive = 373 [&](int64 slice_size) -> XlaOp { 374 if (slice_size == values.size()) { 375 return Add(a, b); 376 } 377 XlaOp param = generate_recursive(slice_size + 1); 378 auto slice1 = Slice(param, {0}, {slice_size}, {1}); 379 auto slice2 = Slice(param, {1}, {slice_size + 1}, {1}); 380 return Add(slice1, slice2); 381 }; 382 generate_recursive(1); 383 auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); 384 auto b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); 385 ComputeAndCompareR1<float>(&builder, {0.0}, {a_data.get(), b_data.get()}); 386 } 387 388 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { 389 XlaBuilder builder(TestName()); 390 auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); 391 auto b = ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); 392 Sub(a, b); 393 394 ComputeAndCompareR1<float>(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f}, 395 {}, error_spec_); 396 } 397 398 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) { 399 XlaBuilder builder(TestName()); 400 auto a = ConstantR1<float>(&builder, {}); 401 auto b = ConstantR1<float>(&builder, {}); 402 Sub(a, b); 403 404 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 
405 } 406 407 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) { 408 XlaBuilder builder(TestName()); 409 auto a = ConstantR1<int32>(&builder, {-1, 0, 2, 1000000000}); 410 auto b = ConstantR1<int32>(&builder, {-1, 2, 1, -1}); 411 Sub(a, b); 412 413 ComputeAndCompareR1<int32>(&builder, {0, -2, 1, 1000000001}, {}); 414 } 415 416 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { 417 XlaBuilder builder(TestName()); 418 auto a = ConstantR1<int32>(&builder, {}); 419 auto b = ConstantR1<int32>(&builder, {}); 420 Sub(a, b); 421 422 ComputeAndCompareR1<int32>(&builder, {}, {}); 423 } 424 425 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { 426 XlaBuilder builder(TestName()); 427 auto a = ConstantR1<complex64>(&builder, 428 {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); 429 auto b = ConstantR1<complex64>( 430 &builder, {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); 431 Sub(a, b); 432 433 ComputeAndCompareR1<complex64>( 434 &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, 435 error_spec_); 436 } 437 438 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { 439 XlaBuilder builder(TestName()); 440 auto a = ConstantR1<complex64>(&builder, {}); 441 auto b = ConstantR1<complex64>(&builder, {}); 442 Sub(a, b); 443 444 ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); 445 } 446 447 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { 448 XlaBuilder builder(TestName()); 449 auto a = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); 450 auto b = ConstantR1<float>(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); 451 Div(a, b); 452 453 ComputeAndCompareR1<float>(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {}, 454 error_spec_); 455 } 456 457 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { 458 XlaBuilder builder(TestName()); 459 auto a = ConstantR1<float>(&builder, {}); 460 auto b = ConstantR1<float>(&builder, {}); 461 Div(a, b); 462 463 
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 464 } 465 466 class IntegerDivideOpTest : public ArrayElementwiseOpTest { 467 protected: 468 template <typename T> 469 void TestDivRem(absl::Span<const T> dividends, absl::Span<const T> divisors, 470 absl::Span<const T> quotients, 471 absl::Span<const T> remainders) { 472 { 473 XlaBuilder builder(TestName()); 474 XlaOp dividend; 475 XlaOp divisor; 476 auto dividend_data = 477 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd); 478 auto divisor_data = 479 CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor); 480 Div(dividend, divisor); 481 482 ComputeAndCompareR1<T>(&builder, quotients, 483 {dividend_data.get(), divisor_data.get()}); 484 } 485 486 // Test with a compile-time constant divisor. 487 { 488 XlaBuilder builder(TestName()); 489 XlaOp dividend; 490 auto dividend_data = 491 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd); 492 Div(dividend, ConstantR1<T>(&builder, divisors)); 493 494 ComputeAndCompareR1<T>(&builder, quotients, {dividend_data.get()}); 495 } 496 497 { 498 XlaBuilder builder(TestName()); 499 XlaOp dividend; 500 XlaOp divisor; 501 auto dividend_data = 502 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd); 503 auto divisor_data = 504 CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor); 505 Rem(dividend, divisor); 506 507 ComputeAndCompareR1<T>(&builder, remainders, 508 {dividend_data.get(), divisor_data.get()}); 509 } 510 511 // Test with a compile-time constant divisor. 512 { 513 XlaBuilder builder(TestName()); 514 XlaOp dividend; 515 auto dividend_data = 516 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd); 517 Rem(dividend, ConstantR1<T>(&builder, divisors)); 518 519 ComputeAndCompareR1<T>(&builder, remainders, {dividend_data.get()}); 520 } 521 } 522 }; 523 524 XLA_TEST_F(IntegerDivideOpTest, DivS32s) { 525 // clang-format off 526 // Some interesting values to test. 
527 std::vector<int32> vals = { 528 INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff, 529 -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101, 530 7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX}; 531 // clang-format on 532 533 std::vector<int32> dividends, divisors, quotients, remainders; 534 for (int32 divisor : vals) { 535 if (divisor != 0) { 536 for (int32 dividend : vals) { 537 // Avoid integer overflow. 538 if (dividend != INT32_MIN || divisor != -1) { 539 dividends.push_back(dividend); 540 divisors.push_back(divisor); 541 quotients.push_back(dividend / divisor); 542 remainders.push_back(dividend % divisor); 543 } 544 } 545 } 546 } 547 548 TestDivRem<int32>(dividends, divisors, quotients, remainders); 549 } 550 551 XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) { 552 std::vector<int32> dividends = {5, INT32_MIN}, divisors = {0, -1}, 553 quotients = {-1, INT32_MIN}, remainders = {5, 0}; 554 555 TestDivRem<int32>(dividends, divisors, quotients, remainders); 556 } 557 558 XLA_TEST_F(IntegerDivideOpTest, DivU32s) { 559 // clang-format off 560 // Some interesting values to test. 
561 std::vector<uint32> vals = { 562 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000, 563 0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX}; 564 // clang-format on 565 566 std::vector<uint32> dividends, divisors, quotients, remainders; 567 for (uint32 divisor : vals) { 568 if (divisor != 0) { 569 for (uint32 dividend : vals) { 570 dividends.push_back(dividend); 571 divisors.push_back(divisor); 572 quotients.push_back(dividend / divisor); 573 remainders.push_back(dividend % divisor); 574 } 575 } 576 } 577 578 TestDivRem<uint32>(dividends, divisors, quotients, remainders); 579 } 580 581 XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) { 582 std::vector<int32> dividends = {5}, divisors = {0}, quotients = {-1}, 583 remainders = {5}; 584 585 TestDivRem<int32>(dividends, divisors, quotients, remainders); 586 } 587 588 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { 589 XlaBuilder builder(TestName()); 590 auto a = ConstantR1<complex64>( 591 &builder, {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); 592 auto b = ConstantR1<complex64>(&builder, 593 {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); 594 Div(a, b); 595 596 ComputeAndCompareR1<complex64>( 597 &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); 598 } 599 600 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { 601 XlaBuilder builder(TestName()); 602 auto a = ConstantR1<complex64>(&builder, {}); 603 auto b = ConstantR1<complex64>(&builder, {}); 604 Div(a, b); 605 606 ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); 607 } 608 609 XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { 610 XlaBuilder builder(TestName()); 611 auto a = ConstantR1<float>( 612 &builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); 613 auto b = ConstantR1<float>( 614 &builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); 615 Rem(a, b); 616 617 ComputeAndCompareR1<float>( 618 &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 
1.0f, -1.0f, -0.0f}, {}, 619 error_spec_); 620 } 621 622 XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { 623 XlaBuilder builder(TestName()); 624 auto a = ConstantR1<float>(&builder, {}); 625 auto b = ConstantR1<float>(&builder, {}); 626 Rem(a, b); 627 628 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 629 } 630 631 XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { 632 XlaBuilder builder(TestName()); 633 auto a = ConstantR1<double>( 634 &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); 635 auto b = ConstantR1<double>( 636 &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); 637 Rem(a, b); 638 639 ComputeAndCompareR1<double>( 640 &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, 641 error_spec_); 642 } 643 644 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { 645 XlaBuilder builder(TestName()); 646 auto a = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); 647 auto b = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); 648 Mul(a, b); 649 650 ComputeAndCompareR1<float>(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, 651 {}, error_spec_); 652 } 653 654 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { 655 XlaBuilder builder(TestName()); 656 auto a = ConstantR1<float>(&builder, {}); 657 auto b = ConstantR1<float>(&builder, {}); 658 Mul(a, b); 659 660 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 661 } 662 663 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { 664 std::vector<int32> data = {0, 665 1, 666 -1, 667 1234, 668 0x1a243514, 669 std::numeric_limits<int32>::max(), 670 std::numeric_limits<int32>::min()}; 671 // Form the test data set using all products of 'data' with itself. 
672 std::vector<int32> a_data, b_data, expected; 673 for (int32 a : data) { 674 for (int32 b : data) { 675 a_data.push_back(a); 676 b_data.push_back(b); 677 expected.push_back(static_cast<uint32>(a) * static_cast<uint32>(b)); 678 } 679 } 680 681 XlaBuilder builder(TestName()); 682 auto a = ConstantR1<int32>(&builder, a_data); 683 auto b = ConstantR1<int32>(&builder, b_data); 684 Mul(a, b); 685 686 ComputeAndCompareR1<int32>(&builder, expected, {}); 687 } 688 689 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { 690 XlaBuilder builder(TestName()); 691 auto a = ConstantR1<int32>(&builder, {}); 692 auto b = ConstantR1<int32>(&builder, {}); 693 Mul(a, b); 694 695 ComputeAndCompareR1<int32>(&builder, {}, {}); 696 } 697 698 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { 699 std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234, 700 0x1a243514, 0xFFFFFFFF, 0x80808080}; 701 702 // Form the test data set using all products of 'data' with itself. 703 std::vector<uint32> a_data, b_data, expected; 704 for (uint32 a : data) { 705 for (uint32 b : data) { 706 a_data.push_back(a); 707 b_data.push_back(b); 708 expected.push_back(a * b); 709 } 710 } 711 712 XlaBuilder builder(TestName()); 713 auto a = ConstantR1<uint32>(&builder, a_data); 714 auto b = ConstantR1<uint32>(&builder, b_data); 715 Mul(a, b); 716 717 ComputeAndCompareR1<uint32>(&builder, expected, {}); 718 } 719 720 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { 721 XlaBuilder builder(TestName()); 722 auto a = ConstantR1<complex64>( 723 &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); 724 auto b = ConstantR1<complex64>(&builder, 725 {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); 726 Mul(a, b); 727 728 ComputeAndCompareR1<complex64>( 729 &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, 730 error_spec_); 731 } 732 733 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { 734 XlaBuilder builder(TestName()); 735 auto a = 
ConstantR1<complex64>(&builder, {}); 736 auto b = ConstantR1<complex64>(&builder, {}); 737 Mul(a, b); 738 739 ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); 740 } 741 742 XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { 743 XlaBuilder builder(TestName()); 744 auto a = ConstantR1<bool>(&builder, {false, false, true, true}); 745 auto b = ConstantR1<bool>(&builder, {false, true, false, true}); 746 And(a, b); 747 748 ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {}); 749 } 750 751 XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { 752 XlaBuilder builder(TestName()); 753 auto a = ConstantR2<bool>(&builder, {{false, false}, {true, true}}); 754 auto b = ConstantR2<bool>(&builder, {{false, true}, {false, true}}); 755 And(a, b); 756 757 Array2D<bool> expected_array({{false, false}, {false, true}}); 758 ComputeAndCompareR2<bool>(&builder, expected_array, {}); 759 } 760 761 XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { 762 XlaBuilder builder(TestName()); 763 auto a = ConstantR1<bool>(&builder, {}); 764 auto b = ConstantR1<bool>(&builder, {}); 765 And(a, b); 766 767 ComputeAndCompareR1<bool>(&builder, {}, {}); 768 } 769 770 XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { 771 XlaBuilder builder(TestName()); 772 auto a = ConstantR1<int32>(&builder, {0, -1, -8}); 773 auto b = ConstantR1<int32>(&builder, {5, -7, 12}); 774 And(a, b); 775 776 ComputeAndCompareR1<int32>(&builder, {0, -7, 8}, {}); 777 } 778 779 XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { 780 XlaBuilder builder(TestName()); 781 auto a = ConstantR2<int32>(&builder, {{0, -5}, {-1, 5}}); 782 auto b = ConstantR2<int32>(&builder, {{1, -6}, {4, 5}}); 783 And(a, b); 784 785 Array2D<int32> expected_array({{0, -6}, {4, 5}}); 786 ComputeAndCompareR2<int32>(&builder, expected_array, {}); 787 } 788 789 XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { 790 XlaBuilder builder(TestName()); 791 auto a = ConstantR1<int32>(&builder, {}); 792 auto b = ConstantR1<int32>(&builder, {}); 
793 And(a, b); 794 795 ComputeAndCompareR1<int32>(&builder, {}, {}); 796 } 797 798 XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { 799 XlaBuilder builder(TestName()); 800 auto a = ConstantR1<int32>(&builder, {0, 1, 8}); 801 auto b = ConstantR1<int32>(&builder, {5, 7, 12}); 802 And(a, b); 803 804 ComputeAndCompareR1<int32>(&builder, {0, 1, 8}, {}); 805 } 806 807 XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { 808 XlaBuilder builder(TestName()); 809 auto a = ConstantR2<uint32>(&builder, {{0, 1}, {3, 8}}); 810 auto b = ConstantR2<uint32>(&builder, {{1, 0}, {7, 6}}); 811 And(a, b); 812 813 Array2D<uint32> expected_array({{0, 0}, {3, 0}}); 814 ComputeAndCompareR2<uint32>(&builder, expected_array, {}); 815 } 816 817 XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { 818 XlaBuilder builder(TestName()); 819 auto a = ConstantR1<uint32>(&builder, {}); 820 auto b = ConstantR1<uint32>(&builder, {}); 821 And(a, b); 822 823 ComputeAndCompareR1<uint32>(&builder, {}, {}); 824 } 825 826 XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { 827 XlaBuilder builder(TestName()); 828 auto a = ConstantR1<bool>(&builder, {false, false, true, true}); 829 auto b = ConstantR1<bool>(&builder, {false, true, false, true}); 830 Or(a, b); 831 832 ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {}); 833 } 834 835 XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { 836 XlaBuilder builder(TestName()); 837 auto a = ConstantR2<bool>(&builder, {{false, false}, {true, true}}); 838 auto b = ConstantR2<bool>(&builder, {{false, true}, {false, true}}); 839 Or(a, b); 840 841 Array2D<bool> expected_array({{false, true}, {true, true}}); 842 ComputeAndCompareR2<bool>(&builder, expected_array, {}); 843 } 844 845 XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { 846 XlaBuilder builder(TestName()); 847 auto a = ConstantR1<bool>(&builder, {}); 848 auto b = ConstantR1<bool>(&builder, {}); 849 Or(a, b); 850 851 ComputeAndCompareR1<bool>(&builder, {}, {}); 852 } 853 854 
// Bitwise OR on signed 32-bit vectors; -1 (all ones) absorbs any operand.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {0, -1, 8});
  const XlaOp rhs = ConstantR1<int32>(&builder, {5, -7, 4});
  Or(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, {5, -1, 12}, {});
}

// Bitwise OR on a 2x2 signed matrix.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
  const XlaOp rhs = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
  Or(lhs, rhs);

  const Array2D<int32> expected({{5, -1}, {12, 9}});
  ComputeAndCompareR2<int32>(&builder, expected, {});
}

// OR of two empty signed vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {});
  const XlaOp rhs = ConstantR1<int32>(&builder, {});
  Or(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}

// Bitwise OR on unsigned 32-bit vectors.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&builder, {0, 1, 8});
  const XlaOp rhs = ConstantR1<uint32>(&builder, {5, 7, 4});
  Or(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, {5, 7, 12}, {});
}

// Bitwise OR on a 2x2 unsigned matrix.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<uint32>(&builder, {{0, 1}, {8, 8}});
  const XlaOp rhs = ConstantR2<uint32>(&builder, {{5, 7}, {4, 1}});
  Or(lhs, rhs);

  const Array2D<uint32> expected({{5, 7}, {12, 9}});
  ComputeAndCompareR2<uint32>(&builder, expected, {});
}

// OR of two empty unsigned vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&builder, {});
  const XlaOp rhs = ConstantR1<uint32>(&builder, {});
  Or(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, {}, {});
}

// Logical XOR over the full boolean truth table.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  const XlaOp rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  Xor(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false}, {});
}

// Logical XOR truth table laid out as a 2x2 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  const XlaOp rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Xor(lhs, rhs);

  const Array2D<bool> expected({{false, true}, {true, false}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}

// XOR of two empty predicate vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<bool>(&builder, {});
  const XlaOp rhs = ConstantR1<bool>(&builder, {});
  Xor(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}

// Bitwise XOR on signed vectors; -1 ^ -7 == ~(-7) == 6.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {0, -1, 8});
  const XlaOp rhs = ConstantR1<int32>(&builder, {5, -7, 4});
  Xor(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, {5, 6, 12}, {});
}

// Bitwise XOR on a 2x2 signed matrix.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
  const XlaOp rhs = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
  Xor(lhs, rhs);

  const Array2D<int32> expected({{5, 6}, {12, 9}});
  ComputeAndCompareR2<int32>(&builder, expected, {});
}

// XOR of two empty signed vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {});
  const XlaOp rhs = ConstantR1<int32>(&builder, {});
  Xor(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}

// Bitwise XOR on unsigned 32-bit vectors.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&builder, {0, 1, 8});
  const XlaOp rhs = ConstantR1<uint32>(&builder, {5, 7, 4});
  Xor(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, {5, 6, 12}, {});
}

// Unsigned XOR on a 2x2 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<uint32>(&builder, {{0, 1}, {8, 8}});
  auto rhs = ConstantR2<uint32>(&builder, {{5, 7}, {4, 1}});
  Xor(lhs, rhs);

  const Array2D<uint32> expected({{5, 6}, {12, 9}});
  ComputeAndCompareR2<uint32>(&builder, expected, {});
}

// XOR of two empty uint32 vectors.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, {});
  auto rhs = ConstantR1<uint32>(&builder, {});
  Xor(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, {}, {});
}

// Logical NOT flips every predicate.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<bool>(&builder, {false, true, true, false});
  Not(operand);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
}

// Logical NOT on a 2x2 predicate matrix.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR2<bool>(&builder, {{false, true}, {true, false}});
  Not(operand);

  const Array2D<bool> expected({{true, false}, {false, true}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}

// NOT of an empty predicate vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<bool>(&builder, {});
  Not(operand);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}

// On signed integers Not is a bitwise complement, so ~x == -x - 1.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int32>(&builder, {-1, 0, 1});
  Not(operand);

  const std::vector<int32> expected = {0, -1, -2};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// Bitwise complement of a 2x2 int32 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR2<int32>(&builder, {{-1, 0}, {1, 8}});
  Not(operand);

  const Array2D<int32> expected({{0, -1}, {-2, -9}});
  ComputeAndCompareR2<int32>(&builder, expected, {});
}

// Bitwise complement of an empty int32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int32>(&builder, {});
  Not(operand);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}

// Bitwise complement of uint32 extremes: 0 <-> 4294967295.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<uint32>(&builder, {0, 4294967295});
  Not(operand);

  const std::vector<uint32> expected = {4294967295, 0};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}

// Bitwise complement of a 2x2 uint32 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR2<uint32>(&builder, {{0, 4294967295}, {1, 4294967294}});
  Not(operand);

  const Array2D<uint32> expected({{4294967295, 0}, {4294967294, 1}});
  ComputeAndCompareR2<uint32>(&builder, expected, {});
}

// Bitwise complement of an empty uint32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<uint32>(&builder, {});
  Not(operand);

  ComputeAndCompareR1<uint32>(&builder, {}, {});
}

// ShiftLeft on int32. The last three shift amounts (32, 100, -1) are out of
// range and the expected result for each of them is 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int32>(
      &builder, {static_cast<int32>(0x12345678), static_cast<int32>(0xF0001000),
                 1, 3, 77, 1, -3, 77});
  auto shift = ConstantR1<int32>(&builder, {4, 8, 2, 7, 15, 32, 100, -1});
  ShiftLeft(operand, shift);

  const std::vector<int32> expected = {static_cast<int32>(0x23456780),
                                       0x00100000, 0x4, 0x180, 2523136, 0, 0, 0};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// Arithmetic right shift on int32 preserves the sign. For out-of-range shift
// amounts the expected result is 0 for non-negative inputs and -1 for
// negative inputs.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int32>(
      &builder, {static_cast<int32>(0x92345678), static_cast<int32>(0x10001000),
                 1, 3, 77, 1, -3, 77});
  auto shift = ConstantR1<int32>(&builder, {4, 8, 2, 7, 2, 32, 100, -1});
  ShiftRightArithmetic(operand, shift);

  const std::vector<int32> expected = {static_cast<int32>(0xF9234567),
                                       static_cast<int32>(0x00100010),
                                       0,
                                       0,
                                       19,
                                       0,
                                       -1,
                                       0};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// Logical right shift on int32 fills with zeros, regardless of sign.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int32>(
      &builder, {static_cast<int32>(0x92345678), static_cast<int32>(0x10001000),
                 1, 3, 77, 1, -3, 77});
  auto shift = ConstantR1<int32>(&builder, {4, 8, 2, 7, 5, 32, 100, -1});
  ShiftRightLogical(operand, shift);

  const std::vector<int32> expected = {0x09234567, 0x00100010, 0, 0,
                                       2,          0,          0, 0};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// ShiftLeft on uint32, including out-of-range shift amounts.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<uint32>(
      &builder, {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77});
  auto shift = ConstantR1<uint32>(&builder, {4, 8, 2, 7, 15, 32, 100, ~0u});
  ShiftLeft(operand, shift);

  const std::vector<uint32> expected = {0x23456780, 0x00100000, 0x4, 0x180,
                                        2523136,    0,          0,   0};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}

// ShiftRightArithmetic on uint32. Per the expected values, the bit pattern
// ~3u shifted by 100 produces ~0u, i.e. the top bit is replicated.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<uint32>(
      &builder, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  auto shift = ConstantR1<uint32>(&builder, {4, 8, 2, 7, 2, 32, 100, ~0u});
  ShiftRightArithmetic(operand, shift);

  const std::vector<uint32> expected = {0xF9234567, 0x00100010, 0,   0,
                                        19,         0,          ~0u, 0};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}

// Logical right shift on uint32.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<uint32>(
      &builder, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  auto shift = ConstantR1<uint32>(&builder, {4, 8, 2, 7, 5, 32, 100, ~0u});
  ShiftRightLogical(operand, shift);

  const std::vector<uint32> expected = {0x09234567, 0x00100010, 0, 0,
                                        2,          0,          0, 0};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}

// Float equality. NaN on either side compares unequal, so fast-math is off.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, 10.0f, NAN});
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}

// Float equality over empty vectors.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {});
  auto y = ConstantR1<float>(&builder, {});
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}

// Float >=. Any comparison involving NaN yields false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Ge(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}

// Float >. Any comparison involving NaN yields false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Gt(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}

// Float <=, including an exactly-equal pair (5.0 <= 5.0).
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {-2.5f, 5.0f, 2.25f, NAN, 6.0f});
  auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Le(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
}

// Float <. Any comparison involving NaN yields false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Lt(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
}

// int32 equality across the full range: every pairing of {min, 0, max} plus
// neighbours of zero.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  auto y = ConstantR1<int32>(
      &builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Eq(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}

// int32 equality over empty vectors.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, {});
  auto y = ConstantR1<int32>(&builder, {});
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}

// Complex equality requires both parts to match. A NaN in either part makes
// the pair unequal, so fast-math is disabled.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto x = ConstantR1<complex64>(&builder, {{-2.5f, 10.0f},
                                            {1.0f, 25.5f},
                                            {2.25f, -3.0f},
                                            {NAN, 0.0f},
                                            {1.0f, 6.0f}});
  auto y = ConstantR1<complex64>(&builder, {{0.0f, 10.0f},
                                            {1.0f, 5.0f},
                                            {2.25f, -3.0f},
                                            {10.0f, 0.0f},
                                            {1.0f, NAN}});
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}

// Complex equality over empty vectors.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<complex64>(&builder, {});
  auto y = ConstantR1<complex64>(&builder, {});
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}

// Complex inequality; the inputs are the same as in CompareEqC64s.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
  // Fast-math must be off because some inputs contain NaN.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  auto x = ConstantR1<complex64>(&builder, {{-2.5f, 10.0f},
                                            {1.0f, 25.5f},
                                            {2.25f, -3.0f},
                                            {NAN, 0.0f},
                                            {1.0f, 6.0f}});
  auto y = ConstantR1<complex64>(&builder, {{0.0f, 10.0f},
                                            {1.0f, 5.0f},
                                            {2.25f, -3.0f},
                                            {10.0f, 0.0f},
                                            {1.0f, NAN}});
  Ne(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {});
}

// Float inequality. A NaN on either side makes the pair unequal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
  // Fast-math must be off because some inputs contain NaN.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto y = ConstantR1<float>(&builder, {10.0f, 25.5f, 1.0f, 10.0f, NAN});
  Ne(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
}

// int32 inequality across the full range.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  auto y = ConstantR1<int32>(
      &builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Ne(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}

// int32 >= across the full range.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  auto y = ConstantR1<int32>(
      &builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Ge(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true}, {});
}

1308 1309 XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { 1310 const int32 min = std::numeric_limits<int32>::min(); 1311 const int32 max = std::numeric_limits<int32>::max(); 1312 XlaBuilder builder(TestName()); 1313 auto lhs = 1314 ConstantR1<int32>(&builder, {min, min, min, 0, 0, 0, max, max, max}); 1315 auto rhs = ConstantR1<int32>(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); 1316 Gt(lhs, rhs); 1317 1318 ComputeAndCompareR1<bool>( 1319 &builder, {false, false, false, true, false, false, true, true, false}, 1320 {}); 1321 } 1322 1323 XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { 1324 const int32 min = std::numeric_limits<int32>::min(); 1325 const int32 max = std::numeric_limits<int32>::max(); 1326 XlaBuilder builder(TestName()); 1327 auto lhs = 1328 ConstantR1<int32>(&builder, {min, min, min, 0, 0, 0, max, max, max}); 1329 auto rhs = ConstantR1<int32>(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); 1330 Le(lhs, rhs); 1331 1332 ComputeAndCompareR1<bool>( 1333 &builder, {true, true, true, false, true, true, false, false, true}, {}); 1334 } 1335 1336 XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { 1337 const int32 min = std::numeric_limits<int32>::min(); 1338 const int32 max = std::numeric_limits<int32>::max(); 1339 XlaBuilder builder(TestName()); 1340 auto lhs = 1341 ConstantR1<int32>(&builder, {min, min, min, 0, 0, 0, max, max, max}); 1342 auto rhs = ConstantR1<int32>(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); 1343 Lt(lhs, rhs); 1344 1345 ComputeAndCompareR1<bool>( 1346 &builder, {false, true, true, false, false, true, false, false, false}, 1347 {}); 1348 } 1349 1350 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { 1351 const uint32 max = std::numeric_limits<uint32>::max(); 1352 XlaBuilder builder(TestName()); 1353 auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); 1354 auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); 1355 Eq(lhs, rhs); 1356 1357 ComputeAndCompareR1<bool>( 1358 &builder, 
{true, false, false, false, true, false, false, false, true}, 1359 {}); 1360 } 1361 1362 XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { 1363 const uint32 max = std::numeric_limits<uint32>::max(); 1364 XlaBuilder builder(TestName()); 1365 auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); 1366 auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); 1367 Ne(lhs, rhs); 1368 1369 ComputeAndCompareR1<bool>( 1370 &builder, {false, true, true, true, false, true, true, true, false}, {}); 1371 } 1372 1373 XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { 1374 const uint32 max = std::numeric_limits<uint32>::max(); 1375 XlaBuilder builder(TestName()); 1376 auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); 1377 auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); 1378 Ge(lhs, rhs); 1379 1380 ComputeAndCompareR1<bool>( 1381 &builder, {true, false, false, true, true, false, true, true, true}, {}); 1382 } 1383 1384 XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { 1385 const uint32 max = std::numeric_limits<uint32>::max(); 1386 XlaBuilder builder(TestName()); 1387 auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); 1388 auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); 1389 Gt(lhs, rhs); 1390 1391 ComputeAndCompareR1<bool>( 1392 &builder, {false, false, false, true, false, false, true, true, false}, 1393 {}); 1394 } 1395 1396 XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { 1397 const uint32 max = std::numeric_limits<uint32>::max(); 1398 XlaBuilder builder(TestName()); 1399 auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); 1400 auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); 1401 Le(lhs, rhs); 1402 1403 ComputeAndCompareR1<bool>( 1404 &builder, {true, true, true, false, true, true, false, false, true}, {}); 1405 } 1406 1407 XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { 1408 
const uint32 max = std::numeric_limits<uint32>::max(); 1409 XlaBuilder builder(TestName()); 1410 auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); 1411 auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); 1412 Lt(lhs, rhs); 1413 1414 ComputeAndCompareR1<bool>( 1415 &builder, {false, true, true, false, false, true, false, false, false}, 1416 {}); 1417 } 1418 1419 XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { 1420 SetFastMathDisabled(true); 1421 XlaBuilder builder(TestName()); 1422 auto lhs = 1423 ConstantR1<float>(&builder, {4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); 1424 auto rhs = 1425 ConstantR1<float>(&builder, {2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); 1426 Pow(lhs, rhs); 1427 1428 ComputeAndCompareR1<float>( 1429 &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); 1430 } 1431 1432 XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { 1433 SetFastMathDisabled(true); 1434 XlaBuilder builder(TestName()); 1435 auto lhs = ConstantR1<float>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f}); 1436 auto rhs = ConstantR1<float>(&builder, {0.5f, 0.6f, -0.6f, -0.6f}); 1437 Pow(lhs, rhs); 1438 1439 ComputeAndCompareR1<float>(&builder, {NAN, NAN, NAN, INFINITY}, {}, 1440 error_spec_); 1441 } 1442 1443 XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) { 1444 SetFastMathDisabled(true); 1445 XlaBuilder builder(TestName()); 1446 auto lhs = 1447 ConstantR1<complex64>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f}); 1448 auto rhs = 1449 ConstantR1<complex64>(&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f}); 1450 Pow(lhs, rhs); 1451 1452 ComputeAndCompareR1<complex64>(&builder, 1453 { 1454 {0, 1.41421356}, 1455 {-2.27443288e-01, 0.69999846}, 1456 {-4.19847531e-01, -1.29215783}, 1457 {0, 0}, 1458 {0, 0}, 1459 {1, 0}, 1460 }, 1461 {}, error_spec_); 1462 } 1463 1464 XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { 1465 XlaBuilder builder(TestName()); 1466 auto lhs = ConstantR1<float>(&builder, {}); 1467 auto rhs = 
ConstantR1<float>(&builder, {}); 1468 Pow(lhs, rhs); 1469 1470 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 1471 } 1472 1473 // Some Pow cases that can be implemented more efficiently. 1474 XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { 1475 XlaBuilder b(TestName()); 1476 1477 std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f}; 1478 std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1479 1480 Literal param_literal = LiteralUtil::CreateR1<float>(values); 1481 std::unique_ptr<GlobalData> param_data = 1482 client_->TransferToServer(param_literal).ConsumeValueOrDie(); 1483 1484 auto sum = ConstantR0<float>(&b, 0.0f); 1485 auto param = Parameter(&b, 0, param_literal.shape(), "param"); 1486 for (float exponent : exponents) { 1487 sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent))); 1488 } 1489 1490 std::vector<float> expected; 1491 for (auto value : values) { 1492 float sum = 0.0f; 1493 for (float exponent : exponents) { 1494 sum += std::pow(value, exponent); 1495 } 1496 expected.push_back(sum); 1497 } 1498 1499 ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_); 1500 } 1501 1502 XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { 1503 XlaBuilder b(TestName()); 1504 1505 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; 1506 std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1507 1508 Literal literal0 = LiteralUtil::CreateR1<float>(values0); 1509 std::unique_ptr<GlobalData> data0 = 1510 client_->TransferToServer(literal0).ConsumeValueOrDie(); 1511 Literal literal1 = LiteralUtil::CreateR1<float>(values1); 1512 std::unique_ptr<GlobalData> data1 = 1513 client_->TransferToServer(literal1).ConsumeValueOrDie(); 1514 auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); 1515 auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); 1516 Pow(Exp(param0), param1); 1517 1518 std::vector<float> expected(values0.size()); 1519 for (int64 i = 0; i < values0.size(); ++i) { 
1520 expected[i] = std::pow(std::exp(values0[i]), values1[i]); 1521 } 1522 1523 ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()}, 1524 error_spec_); 1525 } 1526 1527 XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { 1528 XlaBuilder b(TestName()); 1529 1530 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; 1531 std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1532 1533 Literal literal0 = LiteralUtil::CreateR1<float>(values0); 1534 std::unique_ptr<GlobalData> data0 = 1535 client_->TransferToServer(literal0).ConsumeValueOrDie(); 1536 Literal literal1 = LiteralUtil::CreateR1<float>(values1); 1537 std::unique_ptr<GlobalData> data1 = 1538 client_->TransferToServer(literal1).ConsumeValueOrDie(); 1539 auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); 1540 auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); 1541 Log(Pow(param0, param1)); 1542 1543 std::vector<float> expected(values0.size()); 1544 for (int64 i = 0; i < values0.size(); ++i) { 1545 expected[i] = std::log(std::pow(values0[i], values1[i])); 1546 } 1547 1548 ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()}, 1549 error_spec_); 1550 } 1551 1552 XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { 1553 XlaBuilder b(TestName()); 1554 1555 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; 1556 std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1557 1558 Literal literal0 = LiteralUtil::CreateR1<float>(values0); 1559 std::unique_ptr<GlobalData> data0 = 1560 client_->TransferToServer(literal0).ConsumeValueOrDie(); 1561 Literal literal1 = LiteralUtil::CreateR1<float>(values1); 1562 std::unique_ptr<GlobalData> data1 = 1563 client_->TransferToServer(literal1).ConsumeValueOrDie(); 1564 auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); 1565 auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); 1566 Mul(Exp(param0), Exp(param1)); 1567 1568 std::vector<float> expected(values0.size()); 
1569 for (int64 i = 0; i < values0.size(); ++i) { 1570 expected[i] = std::exp(values0[i]) * std::exp(values1[i]); 1571 } 1572 1573 ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()}, 1574 error_spec_); 1575 } 1576 1577 XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { 1578 XlaBuilder b(TestName()); 1579 1580 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; 1581 std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1582 1583 Literal literal0 = LiteralUtil::CreateR1<float>(values0); 1584 std::unique_ptr<GlobalData> data0 = 1585 client_->TransferToServer(literal0).ConsumeValueOrDie(); 1586 Literal literal1 = LiteralUtil::CreateR1<float>(values1); 1587 std::unique_ptr<GlobalData> data1 = 1588 client_->TransferToServer(literal1).ConsumeValueOrDie(); 1589 auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); 1590 auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); 1591 Div(param0, Exp(param1)); 1592 1593 std::vector<float> expected(values0.size()); 1594 for (int64 i = 0; i < values0.size(); ++i) { 1595 expected[i] = values0[i] / std::exp(values1[i]); 1596 } 1597 1598 ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()}, 1599 error_spec_); 1600 } 1601 1602 XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { 1603 XlaBuilder b(TestName()); 1604 1605 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; 1606 std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1607 std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; 1608 1609 Literal literal0 = LiteralUtil::CreateR1<float>(values0); 1610 std::unique_ptr<GlobalData> data0 = 1611 client_->TransferToServer(literal0).ConsumeValueOrDie(); 1612 1613 Literal literal1 = LiteralUtil::CreateR1<float>(values1); 1614 std::unique_ptr<GlobalData> data1 = 1615 client_->TransferToServer(literal1).ConsumeValueOrDie(); 1616 1617 Literal literal2 = LiteralUtil::CreateR1<float>(values2); 1618 
std::unique_ptr<GlobalData> data2 = 1619 client_->TransferToServer(literal2).ConsumeValueOrDie(); 1620 auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); 1621 auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); 1622 auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); 1623 Div(Div(param0, param1), param2); 1624 1625 std::vector<float> expected(values0.size()); 1626 for (int64 i = 0; i < values0.size(); ++i) { 1627 expected[i] = (values0[i] / values1[i]) / values2[i]; 1628 } 1629 1630 ComputeAndCompareR1<float>( 1631 &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); 1632 } 1633 1634 XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { 1635 XlaBuilder b(TestName()); 1636 1637 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; 1638 std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1639 std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; 1640 1641 Literal literal0 = LiteralUtil::CreateR1<float>(values0); 1642 std::unique_ptr<GlobalData> data0 = 1643 client_->TransferToServer(literal0).ConsumeValueOrDie(); 1644 1645 Literal literal1 = LiteralUtil::CreateR1<float>(values1); 1646 std::unique_ptr<GlobalData> data1 = 1647 client_->TransferToServer(literal1).ConsumeValueOrDie(); 1648 1649 Literal literal2 = LiteralUtil::CreateR1<float>(values2); 1650 std::unique_ptr<GlobalData> data2 = 1651 client_->TransferToServer(literal2).ConsumeValueOrDie(); 1652 1653 auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); 1654 auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); 1655 auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); 1656 Div(param0, Div(param1, param2)); 1657 1658 std::vector<float> expected(values0.size()); 1659 for (int64 i = 0; i < values0.size(); ++i) { 1660 expected[i] = values0[i] / (values1[i] / values2[i]); 1661 } 1662 1663 ComputeAndCompareR1<float>( 1664 &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); 1665 } 1666 1667 
// Div(x, Pow(y, z)) against the host reference x / pow(y, z). Every base y is
// positive, so the reference is finite.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  Div(param0, Pow(param1, param2));

  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = values0[i] / std::pow(values1[i], values2[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
}

// Division of two quotients, (w / x) / (y / z), against the host reference.
XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
  std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();

  Literal literal3 = LiteralUtil::CreateR1<float>(values3);
  std::unique_ptr<GlobalData> data3 =
      client_->TransferToServer(literal3).ConsumeValueOrDie();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  // Previously this was also named "param2" (copy-paste slip). Each parameter
  // now gets a distinct name that matches its index.
  auto param3 = Parameter(&b, 3, literal3.shape(), "param3");
  Div(Div(param0, param1), Div(param2, param3));

  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()},
      error_spec_);
}

// Squares `count` values evenly spaced in [0, 1) via Pow(x, 2). The element
// count comes from the test parameter.
TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());
  std::vector<float> values;
  values.reserve(count);
  for (int i = 0; i < count; ++i) {
    values.push_back(i / static_cast<float>(count));
  }
  auto x = ConstantR1<float>(&builder, values);
  Pow(x, ConstantR0<float>(&builder, 2.0f));

  std::vector<float> expected;
  expected.reserve(values.size());
  for (float value : values) {
    expected.push_back(value * value);
  }

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Pow(x, 2) on a 2x2x2x2 array filled with i / num_elements.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
  XlaBuilder builder(TestName());
  Array4D<float> values(2, 2, 2, 2);

  std::vector<float> values_vector;
  std::vector<float> expected_vector;
  for (int i = 0; i < values.num_elements(); ++i) {
    values_vector.push_back(static_cast<float>(i) / values.num_elements());
    expected_vector.push_back(values_vector.back() * values_vector.back());
  }
  values.SetValues(values_vector);

  Array4D<float> expected(2, 2, 2, 2, expected_vector);

  auto x = ConstantR4FromArray4D<float>(&builder, values);
  Pow(x, ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

// Pow(x, 2) on a rank-4 array that has a zero-sized dimension.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
  XlaBuilder builder(TestName());
  Array4D<float> values(2, 2, 0, 2);
  Array4D<float> expected(2, 2, 0, 2);

  auto x = ConstantR4FromArray4D<float>(&builder, values);
  Pow(x, ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

// Float Min. A NaN in either operand propagates to the result.
XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);
  auto lhs = ConstantR1<float>(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f});
  auto rhs = ConstantR1<float>(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN});
  Min(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {1.0f, -5.0f, 1.0f, NAN, NAN}, {},
                             error_spec_);
}

// Float Min over empty vectors.
XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<float>(&builder, {});
  auto rhs = ConstantR1<float>(&builder, {});
  Min(lhs, rhs);
  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}

// Double Min with NaN propagation.
XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);
  auto lhs = ConstantR1<double>(&builder, {1.0, 1.0, 2.25, NAN, 6.0});
  auto rhs = ConstantR1<double>(&builder, {2.0, -5.0, 1.0, 10.0, NAN});
  Min(lhs, rhs);

  ComputeAndCompareR1<double>(&builder, {1.0, -5.0, 1.0, NAN, NAN}, {},
                              error_spec_);
}

// Float Max. A NaN in either operand propagates to the result.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);
  auto lhs = ConstantR1<float>(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f});
  auto rhs = ConstantR1<float>(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN});
  Max(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {2.0f, 1.0f, 2.25f, NAN, NAN}, {},
                             error_spec_);
}

// Float Max over empty vectors.
XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<float>(&builder, {});
  auto rhs = ConstantR1<float>(&builder, {});
  Max(lhs, rhs);
  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}

// Double Max with NaN propagation.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);
  auto lhs = ConstantR1<double>(&builder, {1.0, 1.0, 2.25, NAN, 6.0});
  auto rhs = ConstantR1<double>(&builder, {2.0, -5.0, 1.0, 10.0, NAN});
  Max(lhs, rhs);

  ComputeAndCompareR1<double>(&builder, {2.0, 1.0, 2.25, NAN, NAN}, {},
                              error_spec_);
}

// int32 Max across the full range, including min/max extremes.
XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
  const int32 min = std::numeric_limits<int32>::min();
  const int32 max = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(
      &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
  auto y = ConstantR1<int32>(
      &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
  Max(x, y);

  std::vector<int32> expected = {min, max, 0,  -1,  0,   0,  0,
                                 1,   1,   10, max, max, max};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// int32 Min across the full range, including min/max extremes.
XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
  const int32 min = std::numeric_limits<int32>::min();
  const int32 max = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(
      &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
  auto y = ConstantR1<int32>(
      &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
  Min(x, y);

  std::vector<int32> expected = {min, min, min, -10, -1,  -1, 0,
                                 0,   0,   1,   0,   max, min};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// uint32 Max; max must be treated as the largest value, not as -1.
XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
  const uint32 max = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  auto x = ConstantR1<uint32>(&builder, {0, 0, 1, 1, 1, max, max, max});
  auto y = ConstantR1<uint32>(&builder, {0, 1, 0, 1, 10, 0, 234234, max});
  Max(x, y);

  std::vector<uint32> expected = {0, 1, 1, 1, 10, max, max, max};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}

// uint32 Min; max must be treated as the largest value, not as -1.
XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
  const uint32 max = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  auto x = ConstantR1<uint32>(&builder, {0, 0, 1, 1, 1, max, max, max});
  auto y = ConstantR1<uint32>(&builder, {0, 1, 0, 1, 10, 0, 234234, max});
  Min(x, y);

  std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, max};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}

// Float Max on ten values with alternating signs, including -0.0.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(
      &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
  auto y = ConstantR1<float>(
      &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
  Max(x, y);

  std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
                                 5.0,  6.0, 7.0, 8.0, 9.0};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Max of a one-element vector with an empty vector yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
  XlaBuilder builder(TestName());
  auto u = ConstantR1<float>(&builder, {3.5});
  auto v = ConstantR1<float>(&builder, {});
  Max(u, v);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}

// Max of a one-element vector with a [0,2] matrix, broadcasting the vector
// along each matrix dimension in turn. The result is expected to be an empty
// [0,2] matrix. (Despite "R1S0" in the name, u holds one element.)
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
  for (int broadcast_dim : {0, 1}) {
    XlaBuilder builder(TestName());
    auto u = ConstantR1<float>(&builder, {3.5});
    auto v = ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 2));
    Max(u, v, /*broadcast_dimensions=*/{broadcast_dim});

    ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
  }
}

// Broadcasts v (size 3) along dimension 1 of the [2,3] matrix m, so v[j] is
// compared with every m[i][j].
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  auto m = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  Max(v, m, /*broadcast_dimensions=*/{1});

  Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// Vector/matrix max over broadcast dimension 1 where both operands have zero
// elements.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {});
  auto m = ConstantR2<float>(&builder, {{}, {}});
  Max(v, m, /*broadcast_dimensions=*/{1});

  Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// A scalar (empty broadcast_dimensions) is max'ed against every element of a
// 3D array.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
  XlaBuilder builder(TestName());
  auto scalar = ConstantR0<int32>(&builder, 2);
  Array3D<int32> a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
  auto array = ConstantR3FromArray3D<int32>(&builder, a_3d);
  Max(array, scalar, /*broadcast_dimensions=*/{});

  Array3D<int32> expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
  ComputeAndCompareR3<int32>(&builder, expected, {});
}

// Scalar/3D max where the 3D array has a zero-sized middle dimension.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
  XlaBuilder builder(TestName());
  auto scalar = ConstantR0<int32>(&builder, 2);
  Array3D<int32> a_3d(2, 0, 3);
  auto array = ConstantR3FromArray3D<int32>(&builder, a_3d);
  Max(array, scalar, /*broadcast_dimensions=*/{});

  Array3D<int32> expected(2, 0, 3);
  ComputeAndCompareR3<int32>(&builder, expected, {});
}

// Broadcasts v along dimension 0 of m, so v[i] is compared with every element
// of row i.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
  XlaBuilder builder(TestName());
  auto m = ConstantR2<float>(&builder,
                             {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
  auto v = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(m, v, /*broadcast_dimensions=*/{0});

  Array2D<float> expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// Row-wise broadcast min where m has shape [2,0].
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto m = ConstantR2<float>(&builder, {{}, {}});
  auto v = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(m, v, /*broadcast_dimensions=*/{0});

  Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// Maps dimensions 0 and 1 of the [2,3] array2d onto dimensions 1 and 3 of the
// [2,2,1,3] array4d.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
  XlaBuilder builder(TestName());
  auto array2d =
      ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  auto array4d = ConstantR4FromArray4D<float>(
      &builder, {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
                 {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
  Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});

  Array4D<float> expected(
      {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
       {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

// Same dimension mapping with a 4D operand of shape [2,2,0,3].
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto array2d =
      ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  Array4D<float> arg(2, 2, 0, 3);
  auto array4d = ConstantR4FromArray4D<float>(&builder, arg);
  Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});

  Array4D<float> expected(2, 2, 0, 3);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

// Elementwise S32 min of 0..9 against 9..0.
XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  auto y = ConstantR1<int32>(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
  Min(x, y);

  std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// Elementwise S32 max of 0..9 against 9..0.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  auto y = ConstantR1<int32>(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
  Max(x, y);

  std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// S32 remainder. The expected values show the result takes the sign of the
// dividend (-3 rem 10 == -3, 1 rem -10 == 1).
XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<int32>(&builder, {-3, 26, 2, -1, 1});
  auto b = ConstantR1<int32>(&builder, {10, 5, 1, 10, -10});
  Rem(a, b);

  ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {});
}

// Elementwise Clamp(min, x, max) with no NaN operands.
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
  XlaBuilder builder(TestName());
  auto minimum = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto argument =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  auto maximum =
      ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {},
                             error_spec_);
}

// Clamp with NaN bounds: the expected output is NaN wherever the min or max
// bound is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto minimum = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN});
  auto argument =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  auto maximum = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f});
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, NAN, NAN}, {},
                             error_spec_);
}

// Scalar min/max bounds applied to every element of a vector argument.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
  XlaBuilder builder(TestName());
  auto minimum = ConstantR0<float>(&builder, 0.0f);
  auto argument = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto maximum = ConstantR0<float>(&builder, 5.0f);
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {},
                             error_spec_);
}

// Sums four F32 clamps that mix scalar and vector bounds, so each bound shape
// combination contributes to the checked result.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<float>(&builder, 0.0f);
  auto min_vector =
      ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto arg_vector =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto max_scalar = ConstantR0<float>(&builder, 3.0f);
  auto max_vector =
      ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
  // Perform clamp with broadcasted scalar and vector.
  Add(Add(Clamp(min_vector, arg_vector, max_scalar),
          Clamp(min_scalar, arg_vector, max_vector)),
      Add(Clamp(min_vector, arg_vector, max_vector),
          Clamp(min_scalar, arg_vector, max_scalar)));

  ComputeAndCompareR1<float>(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {},
                             error_spec_);
}

// Elementwise S32 clamp. The last element has min > x > max and is expected
// to yield max.
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
  XlaBuilder builder(TestName());
  auto min_vector = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0, -5});
  auto arg_vector = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4, 10});
  auto max_vector = ConstantR1<int32>(&builder, {3, 0, 25, 5, 123, -1});
  Clamp(min_vector, arg_vector, max_vector);

  ComputeAndCompareR1<int32>(&builder, {2, 0, 1, 2, 4, -1}, {});
}

// Sums four S32 clamps that mix scalar and vector bounds.
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<int32>(&builder, 0);
  auto min_vector = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0});
  auto arg_vector = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4});
  auto max_scalar = ConstantR0<int32>(&builder, 3);
  auto max_vector = ConstantR1<int32>(&builder, {3, 1, 25, 5, 123});
  // Perform clamp with broadcasted scalar and vector.
  Add(Add(Clamp(min_vector, arg_vector, max_scalar),
          Clamp(min_scalar, arg_vector, max_vector)),
      Add(Clamp(min_vector, arg_vector, max_vector),
          Clamp(min_scalar, arg_vector, max_scalar)));

  ComputeAndCompareR1<int32>(&builder, {8, 8, 2, 6, 14}, {});
}

// Elementwise U32 clamp, including bounds near the largest uint32 value.
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
  XlaBuilder builder(TestName());
  auto min_vector = ConstantR1<uint32>(&builder, {1, 2, 1, 2, 0, ~0u - 4});
  auto arg_vector = ConstantR1<uint32>(&builder, {2, 10, 5, 1, 4, 10});
  auto max_vector = ConstantR1<uint32>(&builder, {3, 5, 25, 5, 123, ~0u});
  Clamp(min_vector, arg_vector, max_vector);

  ComputeAndCompareR1<uint32>(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {});
}

// Sums four U32 clamps that mix scalar and vector bounds.
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<uint32>(&builder, 0);
  auto min_vector = ConstantR1<uint32>(&builder, {1, 0, 1, 2, 0});
  auto arg_vector = ConstantR1<uint32>(&builder, {2, 10, 0, 1, 4});
  auto max_scalar = ConstantR0<uint32>(&builder, 3);
  auto max_vector = ConstantR1<uint32>(&builder, {3, 1, 25, 5, 123});
  // Perform clamp with broadcasted scalar and vector.
  Add(Add(Clamp(min_vector, arg_vector, max_scalar),
          Clamp(min_scalar, arg_vector, max_vector)),
      Add(Clamp(min_vector, arg_vector, max_vector),
          Clamp(min_scalar, arg_vector, max_scalar)));

  ComputeAndCompareR1<uint32>(&builder, {8, 8, 2, 6, 14}, {});
}

// Adds two F32 parameters whose values are transferred to the server first.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
  XlaBuilder builder(TestName());

  Literal param0_literal =
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();

  Literal param1_literal =
      LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
  Add(p0, p1);

  ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
                             {param0_data.get(), param1_data.get()},
                             error_spec_);
}

// Adds two [0,7,0] parameters; the result is an empty array of that shape.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
  XlaBuilder builder(TestName());

  Literal param0_literal =
      LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();

  Literal param1_literal =
      LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
  Add(p0, p1);

  Array3D<float> expected(0, 7, 0);
  ComputeAndCompareR3<float>(
      &builder, expected, {param0_data.get(), param1_data.get()}, error_spec_);
}

// Adds a constant to a parameter.
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
  XlaBuilder builder(TestName());

  Literal param0_literal =
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();

  auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto p = Parameter(&builder, 0, param0_literal.shape(), "param0");
  Add(a, p);

  ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
                             {param0_data.get()}, error_spec_);
}

// Cosine at approximately pi, 0, pi/2 and -pi/4.
XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f});
  Cos(a);

  ComputeAndCompareR1<float>(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {},
                             error_spec_);
}

// Sine at approximately pi, 0, pi/2 and -pi/4.
XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f});
  Sin(a);

  ComputeAndCompareR1<float>(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {},
                             error_spec_);
}

// Atan2(a, b) for points on each axis direction and on both diagonals.
XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f});
  auto b = ConstantR1<float>(&builder, {6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f});
  Atan2(a, b);

  ComputeAndCompareR1<float>(
      &builder,
      {0.0f, 1.57079633f, 3.14159265f, -1.57079633f, 0.78539816f, -0.78539816f},
      {}, error_spec_);
}

// Tanh on a short vector.
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f});
  Tanh(a);

  ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
  // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that
  // the input tensor is large enough to exercise the vectorized tanh
  // implementation on XLA CPU.
  XlaBuilder builder(TestName());
  auto input_literal = LiteralUtil::CreateR1<float>(
      {1.02,  -0.32, 0.85,  0.90,  1.23,  -0.91, -0.49, 0.80,  -0.67, 0.16,
       -0.07, 0.39,  -0.41, 0.04,  1.36,  1.25,  0.41,  0.65,  -1.08, 0.32,
       -1.45, -0.77, -1.09, 0.91,  -1.03, -0.30, -1.11, -1.17, 1.50,  -0.85,
       0.04,  1.02,  0.34,  -0.61, 0.41,  0.07,  -0.02, 1.42,  -0.62, 0.81,
       0.08,  0.81,  -0.30, 1.17,  -0.65, -0.44, 0.92,  1.26,  -1.29, 1.35,
       0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
       -0.79, 1.41,  1.21,  1.05});
  TF_ASSERT_OK_AND_ASSIGN(auto input_data,
                          client_->TransferToServer(input_literal));

  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
  Tanh(input);

  ComputeAndCompareR1<float>(
      &builder,
      {0.77009583,  -0.30665702, 0.69070244,  0.71401149,  0.84400684,
       -0.71985596, -0.45764771, 0.66664988,  -0.58278900, 0.16050975,
       -0.06770509, 0.36843640,  -0.38476998, 0.04018109,  0.87562293,
       0.84788644,  0.38603750,  0.57294142,  -0.79140943, 0.31032649,
       -0.89590985, -0.64770776, -0.79625875, 0.72234446,  -0.77389336,
       -0.28871772, -0.80428445, -0.82541436, 0.90456349,  -0.68856895,
       0.03877772,  0.76877952,  0.32561871,  -0.54546672, 0.39072621,
       0.07273290,  -0.01924866, 0.88924897,  -0.55283129, 0.67183107,
       0.08006320,  0.66944766,  -0.29068485, 0.82573754,  -0.57170743,
       -0.41581789, 0.72739530,  0.85025692,  -0.85931867, 0.87357593,
       0.07782833,  -0.84597743, -0.72748238, 0.45396307,  0.82449573,
       -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558,
       -0.65565848, 0.88789743,  0.83566397,  0.78287679},
      {input_data.get()},
      // The error spec is unusually high here to account for the fact that we
      // use a rational interpolant to approximate tanh.
      ErrorSpec(0.004, 0.004));
}

XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
  // The input tensor is large enough to exercise the vectorized exp
  // implementation on XLA CPU.
  XlaBuilder builder(TestName());

  // Just to help make sense of the scales here -- exp(89) is beyond the
  // float32 range (the inputs stop at 87.9) and exp(-10) is smaller than our
  // error spec.
  Literal input_literal = LiteralUtil::CreateR1<float>(
      {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49, 0.8,    -1.31,
       -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,  -195.6, -194.5,
       -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5,  -17.4,
       -16.3,  -15.2,  -14.1,  -13.0,  -11.9,  -10.8,  -9.7,  -8.6,   -7.5,
       -6.4,   -5.3,   -4.2,   -3.1,   -2.0,   -0.9,   0.2,   1.3,    2.4,
       3.5,    4.6,    5.7,    6.8,    7.9,    9.0,    10.1,  11.2,   12.3,
       13.4,   14.5,   15.6,   16.7,   17.8,   18.9,   20.0,  21.1,   22.2,
       23.3,   24.4,   25.5,   26.6,   27.7,   28.8,   29.9,  31.0,   32.1,
       68.4,   69.5,   70.6,   71.7,   72.8,   73.9,   75.0,  76.1,   77.2,
       78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,  85.2,   86.3,
       86.4,   86.5,   87.6,   87.7,   87.8,   87.9});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(input_literal));

  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
  Exp(input);

  // Reference values come from the host's std::exp on the same float inputs.
  std::vector<float> expected_result;
  int64 input_size = input_literal.shape().dimensions(0);
  expected_result.reserve(input_size);
  for (int64 i = 0; i < input_size; i++) {
    expected_result.push_back(std::exp(input_literal.Get<float>({i})));
  }

  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
  // The input tensor is large enough to exercise the vectorized log
  // implementation on XLA CPU. The first eight inputs are negative, so their
  // reference values from std::log are NaN.
  XlaBuilder builder(TestName());

  Literal input_literal = LiteralUtil::CreateR1<float>(
      {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
       -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
       198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
       1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
       1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
       1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
       1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
       1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
       2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
       1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
       1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
       1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
       1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
       1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
       1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(input_literal));

  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
  Log(input);

  // Reference values come from the host's std::log on the same float inputs.
  std::vector<float> expected_result;
  int64 input_size = input_literal.shape().dimensions(0);
  expected_result.reserve(input_size);
  for (int64 i = 0; i < input_size; i++) {
    expected_result.push_back(std::log(input_literal.Get<float>({i})));
  }

  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
                             error_spec_);
}

// Count-leading-zeros on U32 values, from 0 (all 32 bits zero) to a value
// with the top bit set.
XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<uint32>(
      &builder, {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678});
  Clz(a);

  ComputeAndCompareR1<uint32>(&builder, {32, 31, 27, 15, 9, 3, 0}, {});
}

// Count-leading-zeros on S64 values; -1 has the sign bit set, giving 0.
XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) {
  XlaBuilder builder(TestName());
  auto a =
      ConstantR1<int64>(&builder, {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1});
  Clz(a);

  ComputeAndCompareR1<int64>(&builder, {64, 63, 32, 1, 0}, {});
}

XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
  // a ------ (add) --------- (add)
  //         /               /
  // b -----/               /
  // c---------------------/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  auto add = Add(a, b);
  Add(add, c);

  ComputeAndCompareR1<float>(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
  // b ------ (add) --------- (add)
  //         /               /
  // c -----/               /
  // a---------------------/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  auto add = Add(b, c);
  Add(a, add);

  ComputeAndCompareR1<float>(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
  // a ----- (neg) ----- (add)
  //                    /
  // b ----- (neg) ----/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});

  auto neg_a = Neg(a);
  auto neg_b = Neg(b);
  Add(neg_a, neg_b);

  ComputeAndCompareR1<float>(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {},
                             error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
  // a ------ (add) ------------+
  //          /                 |
  // b -----/                 (add)
  //                            |
  // c ------ (add) ------------+
  //          /
  // d -----/
  // (No diagram line ends in a backslash: that would splice the following
  // source line into the comment and trigger -Wcomment.)
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});
  auto d = ConstantR1<float>(&builder, {-19.0f, 10.0f, -40.0f, 20.2f});

  auto add_ab = Add(a, b);
  auto add_cd = Add(c, d);
  Add(add_ab, add_cd);

  ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
                             error_spec_);
}

// Elementwise add of two [2,3] matrices.
XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto b = ConstantR2<float>(&builder,
                             {{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  Add(a, b);

  Array2D<float> expected_array(
      {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
  // Add a scalar + matrix.
  XlaBuilder builder(TestName());
  auto a = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto scalar = ConstantR0<float>(&builder, 3.0f);
  Add(scalar, a);

  Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
  // Add a matrix + scalar.
  XlaBuilder builder(TestName());
  auto a = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto scalar = ConstantR0<float>(&builder, 3.0f);
  Add(a, scalar);

  Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
  // Test simple broadcasting of a R1F32 over R2F32. The vector's size (3)
  // matches only dimension 1 of the [2,3] matrix, so it is mapped onto that
  // dimension and v[j] is added to every m[i][j].
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f, 60.0f});
  // clang-format off
  auto m = ConstantR2<float>(&builder, {
    {-2.5f, 3.14f, 1.0f},
    {2.25f, -10.0f, 3.33f}});
  // clang-format on
  Add(v, m, /*broadcast_dimensions=*/{1});
  Array2D<float> expected_array(
      {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
  // Test broadcasting in Eq comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {42, 73});
  auto m = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});

  // This test exercises both possible broadcast dimensions for a vector/matrix
  // comparison. Each local is named after the matrix dimension v is mapped to.
  auto cmp_over_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{1});
  auto cmp_over_dim_0 = Eq(v, m, /*broadcast_dimensions=*/{0});
  Tuple(&builder, {cmp_over_dim_1, cmp_over_dim_0});

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
       LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
  // Test broadcasting in Ne comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {42, 73});
  auto m = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});
  Ne(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,2] {
  { 0, 0 },
  { 0, 1 }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
  // Test broadcasting in Ge comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Ge(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 1, 1, 0, 0 },
  { 0, 0, 0, 1 }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
  // Test broadcasting in Gt comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Gt(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 0, 1, 0, 0 },
  { 0, 0, 0, 0 }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
  // Test broadcasting in Le comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Le(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 1, 0, 1, 1 },
  { 1, 1, 1, 1 }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
  // Test broadcasting in Lt comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Lt(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 0, 0, 1, 1 },
  { 1, 1, 1, 0 }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
  // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op
  // arguments is reversed.
  XlaBuilder builder(TestName());
  auto m =
      ConstantR2<float>(&builder, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
  auto v = ConstantR1<float>(&builder, {2.0f, 4.0f, 6.0f});
  Mul(m, v, /*broadcast_dimensions=*/{1});
  Array2D<float> expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  XlaBuilder builder(TestName());
  // m has shape [2,3].
  // md has shape [1,3]: a single row.
  // The result has shape [2,3]; md's row is added to each row of m.
  auto m = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto md = ConstantR2<float>(&builder, {{10.0f, 20.0f, 30.0f}});
  Add(m, md);
  Array2D<float> expected_array(
      {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  XlaBuilder builder(TestName());
  // m has shape [2,3].
  // md has shape [2,1]: a single column.
  // The result has shape [2,3]; md[i][0] is added to every element of row i.
  auto m = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto md = ConstantR2<float>(&builder, {{10.0f}, {20.0f}});
  Add(m, md);
  Array2D<float> expected_array(
      {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
  // Tests broadcasting for two degenerate arrays. This kind of broadcasting
  // effectively creates an "outer product" operation.
  // This is taken from the Numpy docs example at:
  // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
  XlaBuilder builder(TestName());
  // a has shape [4,1] (a column).
  // b has shape [1,3] (a row).
  // The result has shape [4,3], with result[i][j] = a[i][0] + b[0][j].
  auto a = ConstantR2<float>(&builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}});
  auto b = ConstantR2<float>(&builder, {{1.0f, 2.0f, 3.0f}});
  Add(a, b);
  Array2D<float> expected_array({{1.0f, 2.0f, 3.0f},
                                 {11.0f, 12.0f, 13.0f},
                                 {21.0f, 22.0f, 23.0f},
                                 {31.0f, 32.0f, 33.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
  // Add together a (2,2) array and a (2) array, mapping the vector onto
  // dimension 1 so v[j] is added to every m[i][j] (these shapes could also
  // be broadcast over dimension 0).
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto m = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(v, m, /*broadcast_dimensions=*/{1});
  Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
  // Add together a (2,2) array and a (2) array, mapping the vector onto
  // dimension 0 so v[i] is added to every element of row i (these shapes
  // could also be broadcast over dimension 1).
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto m = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(v, m, /*broadcast_dimensions=*/{0});
  Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
  // Binary add of two R3s together
  XlaBuilder builder(TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);

  Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
                       {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
  auto b = ConstantR3FromArray3D<float>(&builder, b_3d);
  Add(a, b);

  Array3D<float> expected_3d(
      {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
       {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
  // Add together a (2, 3, 2) array with a (2) array, mapping the vector onto
  // dimension 2 so v[k] is added to every a[i][j][k] (these shapes could also
  // be broadcast over dimension 0).
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto v = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(a, v, /*broadcast_dimensions=*/{2});

  Array3D<float> expected_3d(
      {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
       {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
  // Add together a (2, 3, 2) array with a (2) array, mapping the vector onto
  // dimension 0 so v[i] is added to every element of a[i] (these shapes could
  // also be broadcast over dimension 2).
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto v = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(a, v, /*broadcast_dimensions=*/{0});

  // clang-format off
  Array3D<float> expected_3d({
    {{11.0f, 12.0f},
     {13.0f, 14.0f},
     {15.0f, 16.0f}},
    {{27.0f, 28.0f},
     {29.0f, 30.0f},
     {31.0f, 32.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
  // Add together a (2, 3, 2) array with a (3, 2) array, using dimensions {1,2}
  // for broadcasting.
2771 XlaBuilder builder(TestName()); 2772 // clang-format off 2773 Array3D<float> a_3d({ 2774 {{1.0f, 2.0f}, 2775 {3.0f, 4.0f}, 2776 {5.0f, 6.0f}}, 2777 {{7.0f, 8.0f}, 2778 {9.0f, 10.0f}, 2779 {11.0f, 12.0f}}, 2780 }); 2781 auto a = ConstantR3FromArray3D<float>(&builder, a_3d); 2782 auto m = ConstantR2<float>(&builder, { 2783 {10.0f, 20.0f, 30.0f}, 2784 {40.0f, 50.0f, 60.0f}, 2785 }); 2786 Add(a, m, /*broadcast_dimensions=*/{0, 1}); 2787 2788 Array3D<float> expected_3d({ 2789 {{11.0f, 12.0f}, 2790 {23.0f, 24.0f}, 2791 {35.0f, 36.0f}}, 2792 {{47.0f, 48.0f}, 2793 {59.0f, 60.0f}, 2794 {71.0f, 72.0f}}, 2795 }); 2796 // clang-format on 2797 ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_); 2798 } 2799 2800 XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { 2801 // Comparison between two 3D arrays of compatible shapes: 2802 // (2, 3, 2) and (2, 3, 1): expected to produce a (2, 3, 2) shape of PREDs. 2803 XlaBuilder builder(TestName()); 2804 Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, 2805 {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); 2806 auto a = ConstantR3FromArray3D<float>(&builder, a_3d); 2807 2808 Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}}); 2809 auto b = ConstantR3FromArray3D<float>(&builder, b_3d); 2810 2811 Gt(a, b); 2812 2813 Array3D<int> expected_3d( 2814 {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); 2815 const string expected = R"(pred[2,3,2] { 2816 { 2817 { 0, 1 }, 2818 { 0, 0 }, 2819 { 0, 0 } 2820 }, 2821 { 2822 { 0, 1 }, 2823 { 1, 0 }, 2824 { 0, 1 } 2825 } 2826 })"; 2827 EXPECT_EQ(expected, ExecuteToString(&builder, {})); 2828 } 2829 2830 XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { 2831 XlaBuilder builder(TestName()); 2832 2833 std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5)); 2834 std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5)); 2835 std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 
3, 4, 5)); 2836 float value = 0.0; 2837 for (int64 p = 0; p < 2; ++p) { 2838 for (int64 z = 0; z < 3; ++z) { 2839 for (int64 y = 0; y < 4; ++y) { 2840 for (int64 x = 0; x < 5; ++x) { 2841 (*operand_a_4d)(p, z, y, x) = value; 2842 (*operand_b_4d)(p, z, y, x) = 2.0 * value; 2843 (*expected_4d)(p, z, y, x) = 3.0 * value; 2844 value += 0.1; 2845 } 2846 } 2847 } 2848 } 2849 2850 auto a = ConstantR4FromArray4D<float>(&builder, *operand_a_4d); 2851 auto b = ConstantR4FromArray4D<float>(&builder, *operand_b_4d); 2852 Add(a, b); 2853 2854 ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_); 2855 } 2856 2857 XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { 2858 XlaBuilder builder(TestName()); 2859 2860 std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5)); 2861 std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5)); 2862 std::vector<float> operand_b_1d(3); 2863 std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0); 2864 2865 float value = 0.0; 2866 for (int64 p = 0; p < 2; ++p) { 2867 for (int64 z = 0; z < 3; ++z) { 2868 for (int64 y = 0; y < 4; ++y) { 2869 for (int64 x = 0; x < 5; ++x) { 2870 (*operand_a_4d)(p, z, y, x) = value; 2871 (*expected_4d)(p, z, y, x) = value + operand_b_1d[z]; 2872 value += 0.1; 2873 } 2874 } 2875 } 2876 } 2877 2878 auto a = ConstantR4FromArray4D<float>(&builder, *operand_a_4d); 2879 auto b = ConstantR1<float>(&builder, operand_b_1d); 2880 Add(a, b, {1}); 2881 2882 ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_); 2883 } 2884 2885 XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { 2886 constexpr int d0 = 16; 2887 constexpr int d1 = 16; 2888 constexpr int d2 = 2; 2889 constexpr int d3 = 2; 2890 Array4D<float> r4(d0, d1, d2, d3); 2891 r4.Fill(1.0); 2892 std::vector<float> r1(d1); 2893 std::iota(r1.begin(), r1.end(), 1.0); 2894 2895 XlaBuilder builder(TestName()); 2896 Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout( 2897 r4, 
LayoutUtil::MakeLayout({0, 1, 2, 3})); 2898 auto a = ConstantLiteral(&builder, a_literal); 2899 auto b = ConstantR1<float>(&builder, r1); 2900 Add(a, b, {1}); 2901 2902 for (int i0 = 0; i0 < d0; ++i0) { 2903 for (int i1 = 0; i1 < d1; ++i1) { 2904 for (int i2 = 0; i2 < d2; ++i2) { 2905 for (int i3 = 0; i3 < d3; ++i3) { 2906 r4(i0, i1, i2, i3) += r1[i1]; 2907 } 2908 } 2909 } 2910 } 2911 ComputeAndCompareR4<float>(&builder, r4, {}, error_spec_); 2912 } 2913 2914 // Show that we can't add two opaques. 2915 XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { 2916 XlaBuilder builder(TestName()); 2917 auto shape = ShapeUtil::MakeOpaqueShape(); 2918 auto x = Parameter(&builder, 0, shape, "x"); 2919 Add(x, x); 2920 auto computation_status = builder.Build(); 2921 ASSERT_FALSE(computation_status.ok()); 2922 EXPECT_THAT(computation_status.status().ToString(), 2923 ::testing::ContainsRegex( 2924 "Expected array argument for lhs of binary operation")); 2925 } 2926 2927 XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { 2928 XlaBuilder builder(TestName()); 2929 auto a = ConstantR2<float>(&builder, 2930 {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); 2931 auto b = ConstantR2<float>(&builder, 2932 {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); 2933 Add(a, b, /*broadcast_dimensions=*/{0, 1}); 2934 2935 Array2D<float> expected_array( 2936 {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); 2937 ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); 2938 } 2939 2940 XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { 2941 XlaBuilder builder(TestName()); 2942 auto a = ConstantR2<float>(&builder, 2943 {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); 2944 auto b = ConstantR2<float>(&builder, 2945 {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); 2946 Add(a, b, /*broadcast_dimensions=*/{1, 0}); 2947 2948 auto computation_status = builder.Build(); 2949 ASSERT_FALSE(computation_status.ok()); 2950 
EXPECT_THAT(computation_status.status().error_message(), 2951 ::testing::ContainsRegex("must.*be the identity")); 2952 } 2953 2954 // Regression test for b/31927799. "slice - y" is fused and requires implicit 2955 // broadcast. 2956 XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { 2957 XlaBuilder builder(TestName()); 2958 auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3}); 2959 auto y_literal = LiteralUtil::CreateR1<float>({4, 5}); 2960 auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); 2961 auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); 2962 2963 auto x = Parameter(&builder, 0, x_literal.shape(), "x"); 2964 auto y = Parameter(&builder, 1, y_literal.shape(), "y"); 2965 auto slice = Slice(x, {1}, {2}, {1}); 2966 Sub(slice, y); 2967 2968 ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()}, 2969 error_spec_); 2970 } 2971 2972 INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount, 2973 ArrayElementwiseOpTestParamCount, 2974 ::testing::Values(127, 128, 129, 17 * 4096)); 2975 2976 } // namespace 2977 } // namespace xla 2978