1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cmath> 17 #include <limits> 18 #include <memory> 19 #include <numeric> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/array2d.h" 23 #include "tensorflow/compiler/xla/array3d.h" 24 #include "tensorflow/compiler/xla/array4d.h" 25 #include "tensorflow/compiler/xla/client/computation_builder.h" 26 #include "tensorflow/compiler/xla/client/global_data.h" 27 #include "tensorflow/compiler/xla/client/local_client.h" 28 #include "tensorflow/compiler/xla/layout_util.h" 29 #include "tensorflow/compiler/xla/literal_util.h" 30 #include "tensorflow/compiler/xla/statusor.h" 31 #include "tensorflow/compiler/xla/test.h" 32 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 33 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 34 #include "tensorflow/compiler/xla/tests/test_macros.h" 35 #include "tensorflow/compiler/xla/types.h" 36 #include "tensorflow/compiler/xla/xla_data.pb.h" 37 #include "tensorflow/core/lib/core/casts.h" 38 #include "tensorflow/core/platform/types.h" 39 40 namespace xla { 41 namespace { 42 43 class ArrayElementwiseOpTest : public ClientLibraryTestBase { 44 public: 45 ErrorSpec error_spec_{0.0001, 0.0001}; 46 }; 47 48 class ArrayElementwiseOpTestParamCount 49 : public ArrayElementwiseOpTest, 50 public 
::testing::WithParamInterface<int> {}; 51 52 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) { 53 ComputationBuilder builder(client_, TestName()); 54 auto a = builder.ConstantR1<float>({}); 55 auto result = builder.Neg(a); 56 57 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 58 } 59 60 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { 61 ComputationBuilder builder(client_, TestName()); 62 auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); 63 auto result = builder.Neg(a); 64 65 ComputeAndCompareR1<float>(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {}, 66 error_spec_); 67 } 68 69 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { 70 ComputationBuilder builder(client_, TestName()); 71 auto a = builder.ConstantR1<int32>({-1, 0, 1, 324, 72 std::numeric_limits<int32>::min(), 73 std::numeric_limits<int32>::max()}); 74 auto result = builder.Neg(a); 75 76 // -min == min for int32 due to an overflow. In C++ it is undefined behavior 77 // to do this calculation. For XLA we have not specified that, so it 78 // ought to work. 
79 ComputeAndCompareR1<int32>(&builder, 80 {1, 0, -1, -324, std::numeric_limits<int32>::min(), 81 -std::numeric_limits<int32>::max()}, 82 {}); 83 } 84 85 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { 86 ComputationBuilder builder(client_, TestName()); 87 auto a = builder.ConstantR1<complex64>({}); 88 auto result = builder.Neg(a); 89 90 ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); 91 } 92 93 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { 94 ComputationBuilder builder(client_, TestName()); 95 auto a = builder.ConstantR1<complex64>( 96 {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); 97 auto result = builder.Neg(a); 98 99 ComputeAndCompareR1<complex64>( 100 &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, 101 {}, error_spec_); 102 } 103 104 XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { 105 ComputationBuilder builder(client_, TestName()); 106 auto a = builder.ConstantR1<float>({}); 107 auto result = builder.IsFinite(a); 108 109 ComputeAndCompareR1<bool>(&builder, {}, {}); 110 } 111 112 // A non-canonical quiet NaN value. 
113 static const float kNonCanonicalNaN = tensorflow::bit_cast<float>(0x7FD01234); 114 115 XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { 116 ComputationBuilder builder(client_, TestName()); 117 auto result = builder.IsFinite(builder.ConstantR0<float>(NAN)); 118 ComputeAndCompareR0<bool>(&builder, false, {}); 119 120 EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); 121 auto result_non_canonical = 122 builder.IsFinite(builder.ConstantR0<float>(kNonCanonicalNaN)); 123 ComputeAndCompareR0<bool>(&builder, false, {}); 124 125 const float inf = std::numeric_limits<float>::infinity(); 126 auto result_inf = builder.IsFinite(builder.ConstantR0<float>(inf)); 127 ComputeAndCompareR0<bool>(&builder, false, {}); 128 129 auto result_neg_inf = builder.IsFinite(builder.ConstantR0<float>(-inf)); 130 ComputeAndCompareR0<bool>(&builder, false, {}); 131 132 auto result_zero = builder.IsFinite(builder.ConstantR0<float>(0.0f)); 133 ComputeAndCompareR0<bool>(&builder, true, {}); 134 } 135 136 XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { 137 ComputationBuilder builder(client_, TestName()); 138 const float inf = std::numeric_limits<float>::infinity(); 139 EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); 140 auto a = builder.ConstantR1<float>( 141 {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); 142 auto result = builder.IsFinite(a); 143 144 ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false}, 145 {}); 146 } 147 148 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { 149 ComputationBuilder builder(client_, TestName()); 150 auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); 151 auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); 152 auto add = builder.Add(a, b); 153 154 ComputeAndCompareR1<float>(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {}, 155 error_spec_); 156 } 157 158 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { 159 ComputationBuilder builder(client_, TestName()); 160 auto a = 
builder.ConstantR1<float>({}); 161 auto b = builder.ConstantR1<float>({}); 162 auto add = builder.Add(a, b); 163 164 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 165 } 166 167 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { 168 ComputationBuilder builder(client_, TestName()); 169 auto a = builder.ConstantR1<complex64>( 170 {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); 171 auto b = builder.ConstantR1<complex64>( 172 {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); 173 auto add = builder.Add(a, b); 174 175 ComputeAndCompareR1<complex64>( 176 &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, 177 error_spec_); 178 } 179 180 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { 181 ComputationBuilder builder(client_, TestName()); 182 auto a = builder.ConstantR1<complex64>({}); 183 auto b = builder.ConstantR1<complex64>({}); 184 auto add = builder.Add(a, b); 185 186 ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); 187 } 188 189 TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { 190 const int count = GetParam(); 191 ComputationBuilder builder(client_, TestName()); 192 std::vector<float> a_values; 193 std::vector<float> b_values; 194 for (int i = 0; i < count; ++i) { 195 a_values.push_back(i / static_cast<float>(count)); 196 b_values.push_back(2 * i / static_cast<float>(count + 2)); 197 } 198 199 std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({a_values}); 200 std::unique_ptr<GlobalData> a_data = 201 client_->TransferToServer(*a_literal).ConsumeValueOrDie(); 202 auto a_constant = builder.ConstantR1<float>(a_values); 203 auto a_param = builder.Parameter(0, a_literal->shape(), "a_param"); 204 205 std::unique_ptr<Literal> b_literal = Literal::CreateR1<float>({b_values}); 206 std::unique_ptr<GlobalData> b_data = 207 client_->TransferToServer(*b_literal).ConsumeValueOrDie(); 208 auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param"); 209 
auto b_param = builder.ConstantR1<float>(b_values); 210 211 auto sum1 = builder.Add(a_constant, b_constant); 212 auto sum2 = builder.Add(a_constant, b_param); 213 auto sum3 = builder.Add(a_param, b_constant); 214 auto sum4 = builder.Add(a_param, b_param); 215 216 auto sum = builder.Add(sum1, sum2); 217 sum = builder.Add(sum, sum3); 218 sum = builder.Add(sum, sum4); 219 220 std::vector<float> expected; 221 for (int64 i = 0; i < count; ++i) { 222 expected.push_back(4 * (a_values[i] + b_values[i])); 223 } 224 225 ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()}, 226 error_spec_); 227 } 228 229 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { 230 ComputationBuilder builder(client_, TestName()); 231 auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); 232 auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); 233 auto add = builder.Sub(a, b); 234 235 ComputeAndCompareR1<float>(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f}, 236 {}, error_spec_); 237 } 238 239 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) { 240 ComputationBuilder builder(client_, TestName()); 241 auto a = builder.ConstantR1<float>({}); 242 auto b = builder.ConstantR1<float>({}); 243 auto add = builder.Sub(a, b); 244 245 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 246 } 247 248 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) { 249 ComputationBuilder builder(client_, TestName()); 250 auto a = builder.ConstantR1<int32>({-1, 0, 2, 1000000000}); 251 auto b = builder.ConstantR1<int32>({-1, 2, 1, -1}); 252 auto add = builder.Sub(a, b); 253 254 ComputeAndCompareR1<int32>(&builder, {0, -2, 1, 1000000001}, {}); 255 } 256 257 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { 258 ComputationBuilder builder(client_, TestName()); 259 auto a = builder.ConstantR1<int32>({}); 260 auto b = builder.ConstantR1<int32>({}); 261 auto add = builder.Sub(a, b); 262 263 
ComputeAndCompareR1<int32>(&builder, {}, {}); 264 } 265 266 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { 267 ComputationBuilder builder(client_, TestName()); 268 auto a = builder.ConstantR1<complex64>( 269 {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); 270 auto b = builder.ConstantR1<complex64>( 271 {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); 272 auto add = builder.Sub(a, b); 273 274 ComputeAndCompareR1<complex64>( 275 &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, 276 error_spec_); 277 } 278 279 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { 280 ComputationBuilder builder(client_, TestName()); 281 auto a = builder.ConstantR1<complex64>({}); 282 auto b = builder.ConstantR1<complex64>({}); 283 auto add = builder.Sub(a, b); 284 285 ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); 286 } 287 288 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { 289 ComputationBuilder builder(client_, TestName()); 290 auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); 291 auto b = builder.ConstantR1<float>({10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); 292 auto add = builder.Div(a, b); 293 294 ComputeAndCompareR1<float>(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {}, 295 error_spec_); 296 } 297 298 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { 299 ComputationBuilder builder(client_, TestName()); 300 auto a = builder.ConstantR1<float>({}); 301 auto b = builder.ConstantR1<float>({}); 302 auto add = builder.Div(a, b); 303 304 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 305 } 306 307 XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { 308 // clang-format off 309 // Some interesting values to test. 
310 std::vector<int32> vals = { 311 INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff, 312 -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101, 313 7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX}; 314 // clang-format on 315 316 std::vector<int32> dividends, divisors, quotients, remainders; 317 for (int32 divisor : vals) { 318 if (divisor != 0) { 319 for (int32 dividend : vals) { 320 // Avoid integer overflow. 321 if (dividend != INT32_MIN || divisor != -1) { 322 dividends.push_back(dividend); 323 divisors.push_back(divisor); 324 quotients.push_back(dividend / divisor); 325 remainders.push_back(dividend % divisor); 326 } 327 } 328 } 329 } 330 331 { 332 ComputationBuilder builder(client_, TestName()); 333 ComputationDataHandle dividend; 334 ComputationDataHandle divisor; 335 auto dividend_data = 336 CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, ÷nd); 337 auto divisor_data = 338 CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor); 339 builder.Div(dividend, divisor); 340 341 ComputeAndCompareR1<int32>(&builder, quotients, 342 {dividend_data.get(), divisor_data.get()}); 343 } 344 345 // Test with a compile-time constant divisor. 
346 { 347 ComputationBuilder builder(client_, TestName()); 348 ComputationDataHandle dividend; 349 auto dividend_data = 350 CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, ÷nd); 351 builder.Div(dividend, builder.ConstantR1<int32>(divisors)); 352 353 ComputeAndCompareR1<int32>(&builder, quotients, {dividend_data.get()}); 354 } 355 356 { 357 ComputationBuilder builder(client_, TestName()); 358 ComputationDataHandle dividend; 359 ComputationDataHandle divisor; 360 auto dividend_data = 361 CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, ÷nd); 362 auto divisor_data = 363 CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor); 364 builder.Rem(dividend, divisor); 365 366 ComputeAndCompareR1<int32>(&builder, remainders, 367 {dividend_data.get(), divisor_data.get()}); 368 } 369 370 // Test with a compile-time constant divisor. 371 { 372 ComputationBuilder builder(client_, TestName()); 373 ComputationDataHandle dividend; 374 auto dividend_data = 375 CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, ÷nd); 376 builder.Rem(dividend, builder.ConstantR1<int32>(divisors)); 377 378 ComputeAndCompareR1<int32>(&builder, remainders, {dividend_data.get()}); 379 } 380 } 381 382 XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { 383 // clang-format off 384 // Some interesting values to test. 
385 std::vector<uint32> vals = { 386 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000, 387 0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX}; 388 // clang-format on 389 390 std::vector<uint32> dividends, divisors, quotients, remainders; 391 for (uint32 divisor : vals) { 392 if (divisor != 0) { 393 for (uint32 dividend : vals) { 394 dividends.push_back(dividend); 395 divisors.push_back(divisor); 396 quotients.push_back(dividend / divisor); 397 remainders.push_back(dividend % divisor); 398 } 399 } 400 } 401 402 { 403 ComputationBuilder builder(client_, TestName()); 404 ComputationDataHandle dividend; 405 ComputationDataHandle divisor; 406 auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend", 407 &builder, ÷nd); 408 auto divisor_data = 409 CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor); 410 builder.Div(dividend, divisor); 411 412 ComputeAndCompareR1<uint32>(&builder, quotients, 413 {dividend_data.get(), divisor_data.get()}); 414 } 415 416 { 417 ComputationBuilder builder(client_, TestName()); 418 ComputationDataHandle dividend; 419 auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend", 420 &builder, ÷nd); 421 builder.Div(dividend, builder.ConstantR1<uint32>(divisors)); 422 423 ComputeAndCompareR1<uint32>(&builder, quotients, {dividend_data.get()}); 424 } 425 426 { 427 ComputationBuilder builder(client_, TestName()); 428 ComputationDataHandle dividend; 429 ComputationDataHandle divisor; 430 auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend", 431 &builder, ÷nd); 432 auto divisor_data = 433 CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor); 434 builder.Rem(dividend, divisor); 435 436 ComputeAndCompareR1<uint32>(&builder, remainders, 437 {dividend_data.get(), divisor_data.get()}); 438 } 439 440 { 441 ComputationBuilder builder(client_, TestName()); 442 ComputationDataHandle dividend; 443 auto dividend_data = CreateR1Parameter<uint32>(dividends, 
0, "dividend", 444 &builder, ÷nd); 445 builder.Rem(dividend, builder.ConstantR1<uint32>(divisors)); 446 447 ComputeAndCompareR1<uint32>(&builder, remainders, {dividend_data.get()}); 448 } 449 } 450 451 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { 452 ComputationBuilder builder(client_, TestName()); 453 auto a = builder.ConstantR1<complex64>( 454 {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); 455 auto b = builder.ConstantR1<complex64>( 456 {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); 457 auto div = builder.Div(a, b); 458 459 ComputeAndCompareR1<complex64>( 460 &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); 461 } 462 463 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { 464 ComputationBuilder builder(client_, TestName()); 465 auto a = builder.ConstantR1<complex64>({}); 466 auto b = builder.ConstantR1<complex64>({}); 467 auto div = builder.Div(a, b); 468 469 ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); 470 } 471 472 XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { 473 ComputationBuilder builder(client_, TestName()); 474 auto a = builder.ConstantR1<float>( 475 {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); 476 auto b = builder.ConstantR1<float>( 477 {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); 478 auto add = builder.Rem(a, b); 479 480 ComputeAndCompareR1<float>( 481 &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {}, 482 error_spec_); 483 } 484 485 XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { 486 ComputationBuilder builder(client_, TestName()); 487 auto a = builder.ConstantR1<float>({}); 488 auto b = builder.ConstantR1<float>({}); 489 auto add = builder.Rem(a, b); 490 491 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 492 } 493 494 XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { 495 ComputationBuilder builder(client_, TestName()); 496 auto a = builder.ConstantR1<double>( 497 {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, 
-1.0, -8.0}); 498 auto b = builder.ConstantR1<double>( 499 {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); 500 auto add = builder.Rem(a, b); 501 502 ComputeAndCompareR1<double>( 503 &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, 504 error_spec_); 505 } 506 507 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { 508 ComputationBuilder builder(client_, TestName()); 509 auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); 510 auto b = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); 511 auto add = builder.Mul(a, b); 512 513 ComputeAndCompareR1<float>(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, 514 {}, error_spec_); 515 } 516 517 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { 518 ComputationBuilder builder(client_, TestName()); 519 auto a = builder.ConstantR1<float>({}); 520 auto b = builder.ConstantR1<float>({}); 521 auto add = builder.Mul(a, b); 522 523 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 524 } 525 526 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { 527 std::vector<int32> data = {0, 528 1, 529 -1, 530 1234, 531 0x1a243514, 532 std::numeric_limits<int32>::max(), 533 std::numeric_limits<int32>::min()}; 534 // Form the test data set using all products of 'data' with itself. 
535 std::vector<int32> a_data, b_data, expected; 536 for (int32 a : data) { 537 for (int32 b : data) { 538 a_data.push_back(a); 539 b_data.push_back(b); 540 expected.push_back(static_cast<uint32>(a) * static_cast<uint32>(b)); 541 } 542 } 543 544 ComputationBuilder builder(client_, TestName()); 545 auto a = builder.ConstantR1<int32>(a_data); 546 auto b = builder.ConstantR1<int32>(b_data); 547 auto add = builder.Mul(a, b); 548 549 ComputeAndCompareR1<int32>(&builder, expected, {}); 550 } 551 552 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { 553 ComputationBuilder builder(client_, TestName()); 554 auto a = builder.ConstantR1<int32>({}); 555 auto b = builder.ConstantR1<int32>({}); 556 auto add = builder.Mul(a, b); 557 558 ComputeAndCompareR1<int32>(&builder, {}, {}); 559 } 560 561 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { 562 std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234, 563 0x1a243514, 0xFFFFFFFF, 0x80808080}; 564 565 // Form the test data set using all products of 'data' with itself. 
566 std::vector<uint32> a_data, b_data, expected; 567 for (uint32 a : data) { 568 for (uint32 b : data) { 569 a_data.push_back(a); 570 b_data.push_back(b); 571 expected.push_back(a * b); 572 } 573 } 574 575 ComputationBuilder builder(client_, TestName()); 576 auto a = builder.ConstantR1<uint32>(a_data); 577 auto b = builder.ConstantR1<uint32>(b_data); 578 auto add = builder.Mul(a, b); 579 580 ComputeAndCompareR1<uint32>(&builder, expected, {}); 581 } 582 583 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { 584 ComputationBuilder builder(client_, TestName()); 585 auto a = builder.ConstantR1<complex64>( 586 {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); 587 auto b = builder.ConstantR1<complex64>( 588 {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); 589 auto add = builder.Mul(a, b); 590 591 ComputeAndCompareR1<complex64>( 592 &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, 593 error_spec_); 594 } 595 596 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { 597 ComputationBuilder builder(client_, TestName()); 598 auto a = builder.ConstantR1<complex64>({}); 599 auto b = builder.ConstantR1<complex64>({}); 600 auto add = builder.Mul(a, b); 601 602 ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); 603 } 604 605 XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { 606 ComputationBuilder builder(client_, TestName()); 607 auto a = builder.ConstantR1<bool>({false, false, true, true}); 608 auto b = builder.ConstantR1<bool>({false, true, false, true}); 609 auto out = builder.And(a, b); 610 611 ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {}); 612 } 613 614 XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { 615 ComputationBuilder builder(client_, TestName()); 616 auto a = builder.ConstantR2<bool>({{false, false}, {true, true}}); 617 auto b = builder.ConstantR2<bool>({{false, true}, {false, true}}); 618 auto out = builder.And(a, b); 619 620 Array2D<bool> expected_array({{false, false}, {false, true}}); 621 
ComputeAndCompareR2<bool>(&builder, expected_array, {}); 622 } 623 624 XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { 625 ComputationBuilder builder(client_, TestName()); 626 auto a = builder.ConstantR1<bool>({}); 627 auto b = builder.ConstantR1<bool>({}); 628 auto out = builder.And(a, b); 629 630 ComputeAndCompareR1<bool>(&builder, {}, {}); 631 } 632 633 XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { 634 ComputationBuilder builder(client_, TestName()); 635 auto a = builder.ConstantR1<int32>({0, -1, -8}); 636 auto b = builder.ConstantR1<int32>({5, -7, 12}); 637 auto out = builder.And(a, b); 638 639 ComputeAndCompareR1<int32>(&builder, {0, -7, 8}, {}); 640 } 641 642 XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { 643 ComputationBuilder builder(client_, TestName()); 644 auto a = builder.ConstantR2<int32>({{0, -5}, {-1, 5}}); 645 auto b = builder.ConstantR2<int32>({{1, -6}, {4, 5}}); 646 auto out = builder.And(a, b); 647 648 Array2D<int32> expected_array({{0, -6}, {4, 5}}); 649 ComputeAndCompareR2<int32>(&builder, expected_array, {}); 650 } 651 652 XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { 653 ComputationBuilder builder(client_, TestName()); 654 auto a = builder.ConstantR1<int32>({}); 655 auto b = builder.ConstantR1<int32>({}); 656 auto out = builder.And(a, b); 657 658 ComputeAndCompareR1<int32>(&builder, {}, {}); 659 } 660 661 XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { 662 ComputationBuilder builder(client_, TestName()); 663 auto a = builder.ConstantR1<int32>({0, 1, 8}); 664 auto b = builder.ConstantR1<int32>({5, 7, 12}); 665 auto out = builder.And(a, b); 666 667 ComputeAndCompareR1<int32>(&builder, {0, 1, 8}, {}); 668 } 669 670 XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { 671 ComputationBuilder builder(client_, TestName()); 672 auto a = builder.ConstantR2<uint32>({{0, 1}, {3, 8}}); 673 auto b = builder.ConstantR2<uint32>({{1, 0}, {7, 6}}); 674 auto out = builder.And(a, b); 675 676 Array2D<uint32> expected_array({{0, 0}, {3, 0}}); 677 
ComputeAndCompareR2<uint32>(&builder, expected_array, {}); 678 } 679 680 XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { 681 ComputationBuilder builder(client_, TestName()); 682 auto a = builder.ConstantR1<uint32>({}); 683 auto b = builder.ConstantR1<uint32>({}); 684 auto out = builder.And(a, b); 685 686 ComputeAndCompareR1<uint32>(&builder, {}, {}); 687 } 688 689 XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { 690 ComputationBuilder builder(client_, TestName()); 691 auto a = builder.ConstantR1<bool>({false, false, true, true}); 692 auto b = builder.ConstantR1<bool>({false, true, false, true}); 693 auto out = builder.Or(a, b); 694 695 ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {}); 696 } 697 698 XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { 699 ComputationBuilder builder(client_, TestName()); 700 auto a = builder.ConstantR2<bool>({{false, false}, {true, true}}); 701 auto b = builder.ConstantR2<bool>({{false, true}, {false, true}}); 702 auto out = builder.Or(a, b); 703 704 Array2D<bool> expected_array({{false, true}, {true, true}}); 705 ComputeAndCompareR2<bool>(&builder, expected_array, {}); 706 } 707 708 XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { 709 ComputationBuilder builder(client_, TestName()); 710 auto a = builder.ConstantR1<bool>({}); 711 auto b = builder.ConstantR1<bool>({}); 712 auto out = builder.Or(a, b); 713 714 ComputeAndCompareR1<bool>(&builder, {}, {}); 715 } 716 717 XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { 718 ComputationBuilder builder(client_, TestName()); 719 auto a = builder.ConstantR1<int32>({0, -1, 8}); 720 auto b = builder.ConstantR1<int32>({5, -7, 4}); 721 auto out = builder.Or(a, b); 722 723 ComputeAndCompareR1<int32>(&builder, {5, -1, 12}, {}); 724 } 725 726 XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { 727 ComputationBuilder builder(client_, TestName()); 728 auto a = builder.ConstantR2<int32>({{0, -1}, {8, 8}}); 729 auto b = builder.ConstantR2<int32>({{5, -7}, {4, 1}}); 730 auto out = 
builder.Or(a, b); 731 732 Array2D<int32> expected_array({{5, -1}, {12, 9}}); 733 ComputeAndCompareR2<int32>(&builder, expected_array, {}); 734 } 735 736 XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { 737 ComputationBuilder builder(client_, TestName()); 738 auto a = builder.ConstantR1<int32>({}); 739 auto b = builder.ConstantR1<int32>({}); 740 auto out = builder.Or(a, b); 741 742 ComputeAndCompareR1<int32>(&builder, {}, {}); 743 } 744 745 XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { 746 ComputationBuilder builder(client_, TestName()); 747 auto a = builder.ConstantR1<uint32>({0, 1, 8}); 748 auto b = builder.ConstantR1<uint32>({5, 7, 4}); 749 auto out = builder.Or(a, b); 750 751 ComputeAndCompareR1<uint32>(&builder, {5, 7, 12}, {}); 752 } 753 754 XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { 755 ComputationBuilder builder(client_, TestName()); 756 auto a = builder.ConstantR2<uint32>({{0, 1}, {8, 8}}); 757 auto b = builder.ConstantR2<uint32>({{5, 7}, {4, 1}}); 758 auto out = builder.Or(a, b); 759 760 Array2D<uint32> expected_array({{5, 7}, {12, 9}}); 761 ComputeAndCompareR2<uint32>(&builder, expected_array, {}); 762 } 763 764 XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { 765 ComputationBuilder builder(client_, TestName()); 766 auto a = builder.ConstantR1<uint32>({}); 767 auto b = builder.ConstantR1<uint32>({}); 768 auto out = builder.Or(a, b); 769 770 ComputeAndCompareR1<uint32>(&builder, {}, {}); 771 } 772 773 XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { 774 ComputationBuilder builder(client_, TestName()); 775 auto a = builder.ConstantR1<bool>({false, true, true, false}); 776 auto out = builder.Not(a); 777 778 ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {}); 779 } 780 781 XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { 782 ComputationBuilder builder(client_, TestName()); 783 auto a = builder.ConstantR2<bool>({{false, true}, {true, false}}); 784 auto out = builder.Not(a); 785 786 Array2D<bool> expected_array({{true, false}, 
{false, true}}); 787 ComputeAndCompareR2<bool>(&builder, expected_array, {}); 788 } 789 790 XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { 791 ComputationBuilder builder(client_, TestName()); 792 auto a = builder.ConstantR1<bool>({}); 793 auto out = builder.Not(a); 794 795 ComputeAndCompareR1<bool>(&builder, {}, {}); 796 } 797 798 XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { 799 ComputationBuilder builder(client_, TestName()); 800 auto a = builder.ConstantR1<int32>({-1, 0, 1}); 801 auto out = builder.Not(a); 802 803 ComputeAndCompareR1<int32>(&builder, {0, -1, -2}, {}); 804 } 805 806 XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { 807 ComputationBuilder builder(client_, TestName()); 808 auto a = builder.ConstantR2<int32>({{-1, 0}, {1, 8}}); 809 auto out = builder.Not(a); 810 811 Array2D<int32> expected_array({{0, -1}, {-2, -9}}); 812 ComputeAndCompareR2<int32>(&builder, expected_array, {}); 813 } 814 815 XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { 816 ComputationBuilder builder(client_, TestName()); 817 auto a = builder.ConstantR1<int32>({}); 818 auto out = builder.Not(a); 819 820 ComputeAndCompareR1<int32>(&builder, {}, {}); 821 } 822 823 XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { 824 ComputationBuilder builder(client_, TestName()); 825 auto a = builder.ConstantR1<uint32>({0, 4294967295}); 826 auto out = builder.Not(a); 827 828 ComputeAndCompareR1<uint32>(&builder, {4294967295, 0}, {}); 829 } 830 831 XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { 832 ComputationBuilder builder(client_, TestName()); 833 auto a = builder.ConstantR2<uint32>({{0, 4294967295}, {1, 4294967294}}); 834 auto out = builder.Not(a); 835 836 Array2D<uint32> expected_array({{4294967295, 0}, {4294967294, 1}}); 837 ComputeAndCompareR2<uint32>(&builder, expected_array, {}); 838 } 839 840 XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { 841 ComputationBuilder builder(client_, TestName()); 842 auto a = builder.ConstantR1<uint32>({}); 843 auto out = 
builder.Not(a); 844 845 ComputeAndCompareR1<uint32>(&builder, {}, {}); 846 } 847 848 XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { 849 ComputationBuilder builder(client_, TestName()); 850 auto a = 851 builder.ConstantR1<int32>({static_cast<int32>(0x12345678), 852 static_cast<int32>(0xF0001000), 1, 3, 77}); 853 auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 15}); 854 auto out = builder.ShiftLeft(a, b); 855 856 ComputeAndCompareR1<int32>( 857 &builder, 858 {static_cast<int32>(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {}); 859 } 860 861 XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { 862 ComputationBuilder builder(client_, TestName()); 863 auto a = 864 builder.ConstantR1<int32>({static_cast<int32>(0x92345678), 865 static_cast<int32>(0x10001000), 1, 3, 77}); 866 auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 2}); 867 auto out = builder.ShiftRightArithmetic(a, b); 868 869 ComputeAndCompareR1<int32>(&builder, 870 {static_cast<int32>(0xF9234567), 871 static_cast<int32>(0x00100010), 0, 0, 19}, 872 {}); 873 } 874 875 XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { 876 ComputationBuilder builder(client_, TestName()); 877 auto a = 878 builder.ConstantR1<int32>({static_cast<int32>(0x92345678), 879 static_cast<int32>(0x10001000), 1, 3, 77}); 880 auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 5}); 881 auto out = builder.ShiftRightLogical(a, b); 882 883 ComputeAndCompareR1<int32>(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); 884 } 885 886 XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { 887 ComputationBuilder builder(client_, TestName()); 888 auto a = builder.ConstantR1<uint32>({0x12345678, 0xF0001000, 1, 3, 77}); 889 auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 15}); 890 auto out = builder.ShiftLeft(a, b); 891 892 ComputeAndCompareR1<uint32>( 893 &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {}); 894 } 895 896 XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { 897 ComputationBuilder builder(client_, 
TestName()); 898 auto a = builder.ConstantR1<uint32>({0x92345678, 0x10001000, 1, 3, 77}); 899 auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 2}); 900 auto out = builder.ShiftRightArithmetic(a, b); 901 902 ComputeAndCompareR1<uint32>(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {}); 903 } 904 905 XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { 906 ComputationBuilder builder(client_, TestName()); 907 auto a = builder.ConstantR1<uint32>({0x92345678, 0x10001000, 1, 3, 77}); 908 auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 5}); 909 auto out = builder.ShiftRightLogical(a, b); 910 911 ComputeAndCompareR1<uint32>(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); 912 } 913 914 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { 915 SetFastMathDisabled(true); 916 ComputationBuilder builder(client_, TestName()); 917 auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); 918 auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 2.25f, 10.0f, NAN}); 919 auto compare = builder.Eq(lhs, rhs); 920 921 ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {}); 922 } 923 924 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { 925 ComputationBuilder builder(client_, TestName()); 926 auto lhs = builder.ConstantR1<float>({}); 927 auto rhs = builder.ConstantR1<float>({}); 928 auto compare = builder.Eq(lhs, rhs); 929 930 ComputeAndCompareR1<bool>(&builder, {}, {}); 931 } 932 933 XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { 934 SetFastMathDisabled(true); 935 ComputationBuilder builder(client_, TestName()); 936 auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); 937 auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN}); 938 auto compare = builder.Ge(lhs, rhs); 939 940 ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {}); 941 } 942 943 XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { 944 SetFastMathDisabled(true); 945 ComputationBuilder builder(client_, 
TestName()); 946 auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); 947 auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN}); 948 auto compare = builder.Gt(lhs, rhs); 949 950 ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {}); 951 } 952 953 XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { 954 SetFastMathDisabled(true); 955 ComputationBuilder builder(client_, TestName()); 956 auto lhs = builder.ConstantR1<float>({-2.5f, 5.0f, 2.25f, NAN, 6.0f}); 957 auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN}); 958 auto compare = builder.Le(lhs, rhs); 959 960 ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {}); 961 } 962 963 XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { 964 SetFastMathDisabled(true); 965 ComputationBuilder builder(client_, TestName()); 966 auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); 967 auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN}); 968 auto compare = builder.Lt(lhs, rhs); 969 970 ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {}); 971 } 972 973 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { 974 const int32 min = std::numeric_limits<int32>::min(); 975 const int32 max = std::numeric_limits<int32>::max(); 976 ComputationBuilder builder(client_, TestName()); 977 auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max}); 978 auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max}); 979 auto compare = builder.Eq(lhs, rhs); 980 981 ComputeAndCompareR1<bool>( 982 &builder, {true, false, false, false, true, false, false, false, true}, 983 {}); 984 } 985 986 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { 987 ComputationBuilder builder(client_, TestName()); 988 auto lhs = builder.ConstantR1<int32>({}); 989 auto rhs = builder.ConstantR1<int32>({}); 990 auto compare = builder.Eq(lhs, rhs); 991 992 
ComputeAndCompareR1<bool>(&builder, {}, {}); 993 } 994 995 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { 996 SetFastMathDisabled(true); 997 ComputationBuilder builder(client_, TestName()); 998 auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f}, 999 {1.0f, 25.5f}, 1000 {2.25f, -3.0f}, 1001 {NAN, 0.0f}, 1002 {1.0f, 6.0f}}); 1003 auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f}, 1004 {1.0f, 5.0f}, 1005 {2.25f, -3.0f}, 1006 {10.0f, 0.0f}, 1007 {1.0f, NAN}}); 1008 auto compare = builder.Eq(lhs, rhs); 1009 1010 ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {}); 1011 } 1012 1013 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { 1014 ComputationBuilder builder(client_, TestName()); 1015 auto lhs = builder.ConstantR1<complex64>({}); 1016 auto rhs = builder.ConstantR1<complex64>({}); 1017 auto compare = builder.Eq(lhs, rhs); 1018 1019 ComputeAndCompareR1<bool>(&builder, {}, {}); 1020 } 1021 1022 XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { 1023 // Disable fast-math because we're operating on NaNs. 1024 SetFastMathDisabled(true); 1025 1026 ComputationBuilder builder(client_, TestName()); 1027 auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f}, 1028 {1.0f, 25.5f}, 1029 {2.25f, -3.0f}, 1030 {NAN, 0.0f}, 1031 {1.0f, 6.0f}}); 1032 auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f}, 1033 {1.0f, 5.0f}, 1034 {2.25f, -3.0f}, 1035 {10.0f, 0.0f}, 1036 {1.0f, NAN}}); 1037 auto compare = builder.Ne(lhs, rhs); 1038 1039 ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {}); 1040 } 1041 1042 XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { 1043 // Disable fast-math because we're operating on NaNs. 
1044 SetFastMathDisabled(true); 1045 1046 ComputationBuilder builder(client_, TestName()); 1047 auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); 1048 auto rhs = builder.ConstantR1<float>({10.0f, 25.5f, 1.0f, 10.0f, NAN}); 1049 auto compare = builder.Ne(lhs, rhs); 1050 1051 ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {}); 1052 } 1053 1054 XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { 1055 const int32 min = std::numeric_limits<int32>::min(); 1056 const int32 max = std::numeric_limits<int32>::max(); 1057 ComputationBuilder builder(client_, TestName()); 1058 auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max}); 1059 auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max}); 1060 auto compare = builder.Ne(lhs, rhs); 1061 1062 ComputeAndCompareR1<bool>( 1063 &builder, {false, true, true, true, false, true, true, true, false}, {}); 1064 } 1065 1066 XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { 1067 const int32 min = std::numeric_limits<int32>::min(); 1068 const int32 max = std::numeric_limits<int32>::max(); 1069 ComputationBuilder builder(client_, TestName()); 1070 auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max}); 1071 auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max}); 1072 auto compare = builder.Ge(lhs, rhs); 1073 1074 ComputeAndCompareR1<bool>( 1075 &builder, {true, false, false, true, true, false, true, true, true}, {}); 1076 } 1077 1078 XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { 1079 const int32 min = std::numeric_limits<int32>::min(); 1080 const int32 max = std::numeric_limits<int32>::max(); 1081 ComputationBuilder builder(client_, TestName()); 1082 auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max}); 1083 auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max}); 1084 auto compare = builder.Gt(lhs, rhs); 1085 1086 ComputeAndCompareR1<bool>( 1087 &builder, 
{false, false, false, true, false, false, true, true, false}, 1088 {}); 1089 } 1090 1091 XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { 1092 const int32 min = std::numeric_limits<int32>::min(); 1093 const int32 max = std::numeric_limits<int32>::max(); 1094 ComputationBuilder builder(client_, TestName()); 1095 auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max}); 1096 auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max}); 1097 auto compare = builder.Le(lhs, rhs); 1098 1099 ComputeAndCompareR1<bool>( 1100 &builder, {true, true, true, false, true, true, false, false, true}, {}); 1101 } 1102 1103 XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { 1104 const int32 min = std::numeric_limits<int32>::min(); 1105 const int32 max = std::numeric_limits<int32>::max(); 1106 ComputationBuilder builder(client_, TestName()); 1107 auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max}); 1108 auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max}); 1109 auto compare = builder.Lt(lhs, rhs); 1110 1111 ComputeAndCompareR1<bool>( 1112 &builder, {false, true, true, false, false, true, false, false, false}, 1113 {}); 1114 } 1115 1116 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { 1117 const uint32 max = std::numeric_limits<uint32>::max(); 1118 ComputationBuilder builder(client_, TestName()); 1119 auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); 1120 auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max}); 1121 auto compare = builder.Eq(lhs, rhs); 1122 1123 ComputeAndCompareR1<bool>( 1124 &builder, {true, false, false, false, true, false, false, false, true}, 1125 {}); 1126 } 1127 1128 XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { 1129 const uint32 max = std::numeric_limits<uint32>::max(); 1130 ComputationBuilder builder(client_, TestName()); 1131 auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); 1132 auto rhs = 
builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max}); 1133 auto compare = builder.Ne(lhs, rhs); 1134 1135 ComputeAndCompareR1<bool>( 1136 &builder, {false, true, true, true, false, true, true, true, false}, {}); 1137 } 1138 1139 XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { 1140 const uint32 max = std::numeric_limits<uint32>::max(); 1141 ComputationBuilder builder(client_, TestName()); 1142 auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); 1143 auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max}); 1144 auto compare = builder.Ge(lhs, rhs); 1145 1146 ComputeAndCompareR1<bool>( 1147 &builder, {true, false, false, true, true, false, true, true, true}, {}); 1148 } 1149 1150 XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { 1151 const uint32 max = std::numeric_limits<uint32>::max(); 1152 ComputationBuilder builder(client_, TestName()); 1153 auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); 1154 auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max}); 1155 auto compare = builder.Gt(lhs, rhs); 1156 1157 ComputeAndCompareR1<bool>( 1158 &builder, {false, false, false, true, false, false, true, true, false}, 1159 {}); 1160 } 1161 1162 XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { 1163 const uint32 max = std::numeric_limits<uint32>::max(); 1164 ComputationBuilder builder(client_, TestName()); 1165 auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); 1166 auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max}); 1167 auto compare = builder.Le(lhs, rhs); 1168 1169 ComputeAndCompareR1<bool>( 1170 &builder, {true, true, true, false, true, true, false, false, true}, {}); 1171 } 1172 1173 XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { 1174 const uint32 max = std::numeric_limits<uint32>::max(); 1175 ComputationBuilder builder(client_, TestName()); 1176 auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); 1177 auto rhs = 
builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max}); 1178 auto compare = builder.Lt(lhs, rhs); 1179 1180 ComputeAndCompareR1<bool>( 1181 &builder, {false, true, true, false, false, true, false, false, false}, 1182 {}); 1183 } 1184 1185 XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { 1186 SetFastMathDisabled(true); 1187 ComputationBuilder builder(client_, TestName()); 1188 auto lhs = 1189 builder.ConstantR1<float>({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); 1190 auto rhs = 1191 builder.ConstantR1<float>({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); 1192 auto minimum = builder.Pow(lhs, rhs); 1193 1194 ComputeAndCompareR1<float>( 1195 &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); 1196 } 1197 1198 XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { 1199 SetFastMathDisabled(true); 1200 ComputationBuilder builder(client_, TestName()); 1201 auto lhs = builder.ConstantR1<float>({-2.0f, -0.6f, -0.6f, 0.0f}); 1202 auto rhs = builder.ConstantR1<float>({0.5f, 0.6f, -0.6f, -0.6f}); 1203 auto minimum = builder.Pow(lhs, rhs); 1204 1205 ComputeAndCompareR1<float>(&builder, {NAN, NAN, NAN, INFINITY}, {}, 1206 error_spec_); 1207 } 1208 1209 XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { 1210 ComputationBuilder builder(client_, TestName()); 1211 auto lhs = builder.ConstantR1<float>({}); 1212 auto rhs = builder.ConstantR1<float>({}); 1213 auto minimum = builder.Pow(lhs, rhs); 1214 1215 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 1216 } 1217 1218 // Some Pow cases that can be implemented more efficiently. 
1219 XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { 1220 ComputationBuilder b(client_, TestName()); 1221 1222 std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f}; 1223 std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1224 1225 std::unique_ptr<Literal> param_literal = Literal::CreateR1<float>(values); 1226 std::unique_ptr<GlobalData> param_data = 1227 client_->TransferToServer(*param_literal).ConsumeValueOrDie(); 1228 1229 auto sum = b.ConstantR0<float>(0.0f); 1230 auto param = b.Parameter(0, param_literal->shape(), "param"); 1231 for (float exponent : exponents) { 1232 sum = b.Add(sum, b.Pow(param, b.ConstantR0<float>(exponent))); 1233 } 1234 1235 std::vector<float> expected; 1236 for (auto value : values) { 1237 float sum = 0.0f; 1238 for (float exponent : exponents) { 1239 sum += std::pow(value, exponent); 1240 } 1241 expected.push_back(sum); 1242 } 1243 1244 ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_); 1245 } 1246 1247 XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { 1248 ComputationBuilder b(client_, TestName()); 1249 1250 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; 1251 std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1252 1253 std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0); 1254 std::unique_ptr<GlobalData> data0 = 1255 client_->TransferToServer(*literal0).ConsumeValueOrDie(); 1256 std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1); 1257 std::unique_ptr<GlobalData> data1 = 1258 client_->TransferToServer(*literal1).ConsumeValueOrDie(); 1259 auto param0 = b.Parameter(0, literal0->shape(), "param0"); 1260 auto param1 = b.Parameter(1, literal1->shape(), "param1"); 1261 b.Pow(b.Exp(param0), param1); 1262 1263 std::vector<float> expected(values0.size()); 1264 for (int64 i = 0; i < values0.size(); ++i) { 1265 expected[i] = std::pow(std::exp(values0[i]), values1[i]); 1266 } 1267 1268 ComputeAndCompareR1<float>(&b, expected, 
{data0.get(), data1.get()}, 1269 error_spec_); 1270 } 1271 1272 XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { 1273 ComputationBuilder b(client_, TestName()); 1274 1275 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; 1276 std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1277 1278 std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0); 1279 std::unique_ptr<GlobalData> data0 = 1280 client_->TransferToServer(*literal0).ConsumeValueOrDie(); 1281 std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1); 1282 std::unique_ptr<GlobalData> data1 = 1283 client_->TransferToServer(*literal1).ConsumeValueOrDie(); 1284 auto param0 = b.Parameter(0, literal0->shape(), "param0"); 1285 auto param1 = b.Parameter(1, literal1->shape(), "param1"); 1286 b.Log(b.Pow(param0, param1)); 1287 1288 std::vector<float> expected(values0.size()); 1289 for (int64 i = 0; i < values0.size(); ++i) { 1290 expected[i] = std::log(std::pow(values0[i], values1[i])); 1291 } 1292 1293 ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()}, 1294 error_spec_); 1295 } 1296 1297 XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { 1298 ComputationBuilder b(client_, TestName()); 1299 1300 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; 1301 std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1302 1303 std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0); 1304 std::unique_ptr<GlobalData> data0 = 1305 client_->TransferToServer(*literal0).ConsumeValueOrDie(); 1306 std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1); 1307 std::unique_ptr<GlobalData> data1 = 1308 client_->TransferToServer(*literal1).ConsumeValueOrDie(); 1309 auto param0 = b.Parameter(0, literal0->shape(), "param0"); 1310 auto param1 = b.Parameter(1, literal1->shape(), "param1"); 1311 b.Mul(b.Exp(param0), b.Exp(param1)); 1312 1313 std::vector<float> expected(values0.size()); 1314 for (int64 i = 0; i < 
values0.size(); ++i) { 1315 expected[i] = std::exp(values0[i]) * std::exp(values1[i]); 1316 } 1317 1318 ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()}, 1319 error_spec_); 1320 } 1321 1322 XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { 1323 ComputationBuilder b(client_, TestName()); 1324 1325 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; 1326 std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1327 1328 std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0); 1329 std::unique_ptr<GlobalData> data0 = 1330 client_->TransferToServer(*literal0).ConsumeValueOrDie(); 1331 std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1); 1332 std::unique_ptr<GlobalData> data1 = 1333 client_->TransferToServer(*literal1).ConsumeValueOrDie(); 1334 auto param0 = b.Parameter(0, literal0->shape(), "param0"); 1335 auto param1 = b.Parameter(1, literal1->shape(), "param1"); 1336 b.Div(param0, b.Exp(param1)); 1337 1338 std::vector<float> expected(values0.size()); 1339 for (int64 i = 0; i < values0.size(); ++i) { 1340 expected[i] = values0[i] / std::exp(values1[i]); 1341 } 1342 1343 ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()}, 1344 error_spec_); 1345 } 1346 1347 XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { 1348 ComputationBuilder b(client_, TestName()); 1349 1350 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; 1351 std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1352 std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; 1353 1354 std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0); 1355 std::unique_ptr<GlobalData> data0 = 1356 client_->TransferToServer(*literal0).ConsumeValueOrDie(); 1357 1358 std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1); 1359 std::unique_ptr<GlobalData> data1 = 1360 client_->TransferToServer(*literal1).ConsumeValueOrDie(); 1361 1362 
std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2); 1363 std::unique_ptr<GlobalData> data2 = 1364 client_->TransferToServer(*literal2).ConsumeValueOrDie(); 1365 auto param0 = b.Parameter(0, literal0->shape(), "param0"); 1366 auto param1 = b.Parameter(1, literal1->shape(), "param1"); 1367 auto param2 = b.Parameter(2, literal2->shape(), "param2"); 1368 b.Div(b.Div(param0, param1), param2); 1369 1370 std::vector<float> expected(values0.size()); 1371 for (int64 i = 0; i < values0.size(); ++i) { 1372 expected[i] = (values0[i] / values1[i]) / values2[i]; 1373 } 1374 1375 ComputeAndCompareR1<float>( 1376 &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); 1377 } 1378 1379 XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { 1380 ComputationBuilder b(client_, TestName()); 1381 1382 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; 1383 std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1384 std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; 1385 1386 std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0); 1387 std::unique_ptr<GlobalData> data0 = 1388 client_->TransferToServer(*literal0).ConsumeValueOrDie(); 1389 1390 std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1); 1391 std::unique_ptr<GlobalData> data1 = 1392 client_->TransferToServer(*literal1).ConsumeValueOrDie(); 1393 1394 std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2); 1395 std::unique_ptr<GlobalData> data2 = 1396 client_->TransferToServer(*literal2).ConsumeValueOrDie(); 1397 1398 auto param0 = b.Parameter(0, literal0->shape(), "param0"); 1399 auto param1 = b.Parameter(1, literal1->shape(), "param1"); 1400 auto param2 = b.Parameter(2, literal2->shape(), "param2"); 1401 b.Div(param0, b.Div(param1, param2)); 1402 1403 std::vector<float> expected(values0.size()); 1404 for (int64 i = 0; i < values0.size(); ++i) { 1405 expected[i] = values0[i] / (values1[i] / values2[i]); 
1406 } 1407 1408 ComputeAndCompareR1<float>( 1409 &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); 1410 } 1411 1412 XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { 1413 ComputationBuilder b(client_, TestName()); 1414 1415 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; 1416 std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; 1417 std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; 1418 1419 std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0); 1420 std::unique_ptr<GlobalData> data0 = 1421 client_->TransferToServer(*literal0).ConsumeValueOrDie(); 1422 1423 std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1); 1424 std::unique_ptr<GlobalData> data1 = 1425 client_->TransferToServer(*literal1).ConsumeValueOrDie(); 1426 1427 std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2); 1428 std::unique_ptr<GlobalData> data2 = 1429 client_->TransferToServer(*literal2).ConsumeValueOrDie(); 1430 1431 auto param0 = b.Parameter(0, literal0->shape(), "param0"); 1432 auto param1 = b.Parameter(1, literal1->shape(), "param1"); 1433 auto param2 = b.Parameter(2, literal2->shape(), "param2"); 1434 b.Div(param0, b.Pow(param1, param2)); 1435 1436 std::vector<float> expected(values0.size()); 1437 for (int64 i = 0; i < values0.size(); ++i) { 1438 expected[i] = values0[i] / std::pow(values1[i], values2[i]); 1439 } 1440 1441 ComputeAndCompareR1<float>( 1442 &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); 1443 } 1444 1445 XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { 1446 ComputationBuilder b(client_, TestName()); 1447 1448 std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; 1449 std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; 1450 std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; 1451 std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; 1452 1453 std::unique_ptr<Literal> 
literal0 = Literal::CreateR1<float>(values0); 1454 std::unique_ptr<GlobalData> data0 = 1455 client_->TransferToServer(*literal0).ConsumeValueOrDie(); 1456 1457 std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1); 1458 std::unique_ptr<GlobalData> data1 = 1459 client_->TransferToServer(*literal1).ConsumeValueOrDie(); 1460 1461 std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2); 1462 std::unique_ptr<GlobalData> data2 = 1463 client_->TransferToServer(*literal2).ConsumeValueOrDie(); 1464 1465 std::unique_ptr<Literal> literal3 = Literal::CreateR1<float>(values3); 1466 std::unique_ptr<GlobalData> data3 = 1467 client_->TransferToServer(*literal3).ConsumeValueOrDie(); 1468 1469 auto param0 = b.Parameter(0, literal0->shape(), "param0"); 1470 auto param1 = b.Parameter(1, literal1->shape(), "param1"); 1471 auto param2 = b.Parameter(2, literal2->shape(), "param2"); 1472 auto param3 = b.Parameter(3, literal3->shape(), "param2"); 1473 b.Div(b.Div(param0, param1), b.Div(param2, param3)); 1474 1475 std::vector<float> expected(values0.size()); 1476 for (int64 i = 0; i < values0.size(); ++i) { 1477 expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]); 1478 } 1479 1480 ComputeAndCompareR1<float>( 1481 &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()}, 1482 error_spec_); 1483 } 1484 1485 TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { 1486 const int count = GetParam(); 1487 ComputationBuilder builder(client_, TestName()); 1488 std::vector<float> values; 1489 values.reserve(count); 1490 for (int i = 0; i < count; ++i) { 1491 values.push_back(i / static_cast<float>(count)); 1492 } 1493 auto x = builder.ConstantR1<float>(values); 1494 auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f)); 1495 1496 std::vector<float> expected; 1497 expected.reserve(values.size()); 1498 for (float value : values) { 1499 expected.push_back(value * value); 1500 } 1501 1502 ComputeAndCompareR1<float>(&builder, expected, {}, 
error_spec_); 1503 } 1504 1505 XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { 1506 ComputationBuilder builder(client_, TestName()); 1507 Array4D<float> values(2, 2, 2, 2); 1508 1509 std::vector<float> values_vector; 1510 std::vector<float> expected_vector; 1511 for (int i = 0; i < values.num_elements(); ++i) { 1512 values_vector.push_back(static_cast<float>(i) / values.num_elements()); 1513 expected_vector.push_back(values_vector.back() * values_vector.back()); 1514 } 1515 values.SetValues(values_vector); 1516 1517 Array4D<float> expected(2, 2, 2, 2, expected_vector); 1518 1519 auto x = builder.ConstantR4FromArray4D<float>(values); 1520 auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f)); 1521 1522 ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 1523 } 1524 1525 XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { 1526 ComputationBuilder builder(client_, TestName()); 1527 Array4D<float> values(2, 2, 0, 2); 1528 Array4D<float> expected(2, 2, 0, 2); 1529 1530 auto x = builder.ConstantR4FromArray4D<float>(values); 1531 auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f)); 1532 1533 ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 1534 } 1535 1536 // GPU backend emits nvvm intrinsic for fmin and fmax, whose semantics is NOT 1537 // such 1538 // * fmin(NaN, x) = x 1539 // * fmax(NaN, x) = x 1540 // so we only test NAN on CPU. 1541 // 1542 // TODO(b/28180546): Make this compile in a way that is consistent 1543 // among backends. 
1544 XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { 1545 ComputationBuilder builder(client_, TestName()); 1546 #if !defined(XLA_TEST_BACKEND_CPU) 1547 auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f}); 1548 auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f}); 1549 #else 1550 SetFastMathDisabled(true); 1551 auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f}); 1552 auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN}); 1553 #endif 1554 auto minimum = builder.Min(lhs, rhs); 1555 1556 ComputeAndCompareR1<float>(&builder, 1557 #if !defined(XLA_TEST_BACKEND_CPU) 1558 {1.0f, -5.0f, 1.0f}, 1559 #else 1560 {1.0f, -5.0f, 1.0f, 10.0f, 6.0f}, 1561 #endif 1562 {}, error_spec_); 1563 } 1564 1565 XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { 1566 ComputationBuilder builder(client_, TestName()); 1567 auto lhs = builder.ConstantR1<float>({}); 1568 auto rhs = builder.ConstantR1<float>({}); 1569 auto minimum = builder.Min(lhs, rhs); 1570 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 1571 } 1572 1573 // TODO(b/28180546): Make this compile in a way that is consistent 1574 // among backends. See comment on MinF32s test above. 
1575 XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { 1576 ComputationBuilder builder(client_, TestName()); 1577 #if !defined(XLA_TEST_BACKEND_CPU) 1578 auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25}); 1579 auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0}); 1580 #else 1581 SetFastMathDisabled(true); 1582 auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0}); 1583 auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN}); 1584 #endif 1585 auto minimum = builder.Min(lhs, rhs); 1586 1587 ComputeAndCompareR1<double>(&builder, 1588 #if !defined(XLA_TEST_BACKEND_CPU) 1589 {1.0, -5.0, 1.0}, 1590 #else 1591 {1.0, -5.0, 1.0, 10.0, 6.0}, 1592 #endif 1593 {}, error_spec_); 1594 } 1595 1596 // TODO(b/28180546): Make this compile in a way that is consistent 1597 // among backends. See comment on MinF32s test above. 1598 XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { 1599 ComputationBuilder builder(client_, TestName()); 1600 #if !defined(XLA_TEST_BACKEND_CPU) 1601 auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f}); 1602 auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f}); 1603 #else 1604 SetFastMathDisabled(true); 1605 auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f}); 1606 auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN}); 1607 #endif 1608 auto maximum = builder.Max(lhs, rhs); 1609 1610 ComputeAndCompareR1<float>(&builder, 1611 #if !defined(XLA_TEST_BACKEND_CPU) 1612 {2.0f, 1.0f, 2.25f}, 1613 #else 1614 {2.0f, 1.0f, 2.25f, 10.0f, 6.0f}, 1615 #endif 1616 {}, error_spec_); 1617 } 1618 1619 XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { 1620 ComputationBuilder builder(client_, TestName()); 1621 auto lhs = builder.ConstantR1<float>({}); 1622 auto rhs = builder.ConstantR1<float>({}); 1623 auto minimum = builder.Max(lhs, rhs); 1624 ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); 1625 } 1626 1627 // TODO(b/28180546): Make this compile in a way that is consistent 1628 // among 
backends. See comment on MinF32s test above. 1629 XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { 1630 ComputationBuilder builder(client_, TestName()); 1631 #if !defined(XLA_TEST_BACKEND_CPU) 1632 auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25}); 1633 auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0}); 1634 #else 1635 SetFastMathDisabled(true); 1636 auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0}); 1637 auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN}); 1638 #endif 1639 auto maximum = builder.Max(lhs, rhs); 1640 1641 ComputeAndCompareR1<double>(&builder, 1642 #if !defined(XLA_TEST_BACKEND_CPU) 1643 {2.0, 1.0, 2.25}, 1644 #else 1645 {2.0, 1.0, 2.25, 10.0, 6.0}, 1646 #endif 1647 {}, error_spec_); 1648 } 1649 1650 XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { 1651 const int32 min = std::numeric_limits<int32>::min(); 1652 const int32 max = std::numeric_limits<int32>::max(); 1653 ComputationBuilder builder(client_, TestName()); 1654 auto x = builder.ConstantR1<int32>( 1655 {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); 1656 auto y = builder.ConstantR1<int32>( 1657 {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); 1658 builder.Max(x, y); 1659 1660 std::vector<int32> expected = {min, max, 0, -1, 0, 0, 0, 1661 1, 1, 10, max, max, max}; 1662 ComputeAndCompareR1<int32>(&builder, expected, {}); 1663 } 1664 1665 XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { 1666 const int32 min = std::numeric_limits<int32>::min(); 1667 const int32 max = std::numeric_limits<int32>::max(); 1668 ComputationBuilder builder(client_, TestName()); 1669 auto x = builder.ConstantR1<int32>( 1670 {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); 1671 auto y = builder.ConstantR1<int32>( 1672 {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); 1673 builder.Min(x, y); 1674 1675 std::vector<int32> expected = {min, min, min, -10, -1, -1, 0, 1676 0, 0, 1, 0, max, min}; 1677 ComputeAndCompareR1<int32>(&builder, expected, {}); 1678 } 1679 1680 
// Elementwise max over uint32 pairs, including the type's maximum value.
XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
  const uint32 max = std::numeric_limits<uint32>::max();
  ComputationBuilder builder(client_, TestName());
  auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
  auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
  builder.Max(x, y);

  std::vector<uint32> expected = {0, 1, 1, 1, 10, max, max, max};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}

// Elementwise min over the same uint32 pairs as MaxU32s.
XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
  const uint32 max = std::numeric_limits<uint32>::max();
  ComputationBuilder builder(client_, TestName());
  auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
  auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
  builder.Min(x, y);

  std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, max};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}

XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
  ComputationBuilder builder(client_, TestName());
  auto x = builder.ConstantR1<float>(
      {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
  auto y = builder.ConstantR1<float>(
      {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
  builder.Max(x, y);

  std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
                                 5.0,  6.0, 7.0, 8.0, 9.0};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Combining a one-element vector with an empty vector yields an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
  ComputationBuilder builder(client_, TestName());
  auto u = builder.ConstantR1<float>({3.5});
  auto v = builder.ConstantR1<float>({});
  builder.Max(u, v);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}

// NOTE: despite "R1S0" in the name, `u` holds one element. It is broadcast
// along each dimension of an empty 0x2 matrix in turn, and the result stays
// empty.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
  for (int broadcast_dim : {0, 1}) {
    ComputationBuilder builder(client_, TestName());
    auto u = builder.ConstantR1<float>({3.5});
    auto v = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
    builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim});

    ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
  }
}

// The 3-element vector is broadcast along dimension 1 of the 2x3 matrix, so
// it is compared against every row.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
  auto m =
      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  builder.Max(v, m, /*broadcast_dimensions=*/{1});

  Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<float>({});
  auto m = builder.ConstantR2<float>({{}, {}});
  builder.Max(v, m, /*broadcast_dimensions=*/{1});

  Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// A scalar (empty broadcast_dimensions) is compared against every element.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
  ComputationBuilder builder(client_, TestName());
  auto scalar = builder.ConstantR0<int32>(2);
  Array3D<int32> a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
  auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
  builder.Max(array, scalar, /*broadcast_dimensions=*/{});

  Array3D<int32> expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
  ComputeAndCompareR3<int32>(&builder, expected, {});
}

XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
  ComputationBuilder builder(client_, TestName());
  auto scalar = builder.ConstantR0<int32>(2);
  Array3D<int32> a_3d(2, 0, 3);
  auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
  builder.Max(array, scalar, /*broadcast_dimensions=*/{});

  Array3D<int32> expected(2, 0, 3);
  ComputeAndCompareR3<int32>(&builder, expected, {});
}

// The 2-element vector is broadcast along dimension 0 of the 2x3 matrix, so
// v[i] is compared against every element of row i.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
  ComputationBuilder builder(client_, TestName());
  auto m =
      builder.ConstantR2<float>({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
  auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
  builder.Min(m, v, /*broadcast_dimensions=*/{0});

  Array2D<float> expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
  ComputationBuilder builder(client_, TestName());
  auto m = builder.ConstantR2<float>({{}, {}});
  auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
  builder.Min(m, v, /*broadcast_dimensions=*/{0});

  Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// The 2x3 array's dimensions are mapped onto dimensions 1 and 3 of the
// 2x2x1x3 array.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
  ComputationBuilder builder(client_, TestName());
  auto array2d =
      builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  auto array4d = builder.ConstantR4FromArray4D<float>(
      {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
       {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
  builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});

  Array4D<float> expected(
      {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
       {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
  ComputationBuilder builder(client_, TestName());
  auto array2d =
      builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  Array4D<float> arg(2, 2, 0, 3);
  auto array4d = builder.ConstantR4FromArray4D<float>(arg);
  builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});

  Array4D<float> expected(2, 2, 0, 3);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
  ComputationBuilder builder(client_, TestName());
  auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
  builder.Min(x, y);

  std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
  ComputationBuilder builder(client_, TestName());
  auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
  builder.Max(x, y);

  std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// The expected values show the remainder taking the sign of the dividend
// (e.g. -3 rem 10 == -3, 1 rem -10 == 1).
XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
  ComputationBuilder builder(client_, TestName());
  auto a = builder.ConstantR1<int32>({-3, 26, 2, -1, 1});
  auto b = builder.ConstantR1<int32>({10, 5, 1, 10, -10});
  auto rem = builder.Rem(a, b);

  ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {});
}

// Clamp with per-element bounds; every bound and argument is a non-NaN value.
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
  ComputationBuilder builder(client_, TestName());
  auto minimum = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  auto maximum = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
  auto clamp = builder.Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {},
                             error_spec_);
}

// Clamp a vector between two scalar bounds.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
  ComputationBuilder builder(client_, TestName());
  auto minimum = builder.ConstantR0<float>(0.0f);
  auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto maximum = builder.ConstantR0<float>(5.0f);
  auto clamp = builder.Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
  ComputationBuilder builder(client_, TestName());
  auto min_scalar = builder.ConstantR0<float>(0.0f);
  auto min_vector = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto arg_vector = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto max_scalar = builder.ConstantR0<float>(3.0f);
  auto max_vector =
      builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
  // Perform clamp with broadcasted scalar and vector. All four combinations of
  // scalar/vector bounds are summed, so each one contributes to the result.
  auto clamp = builder.Add(
      builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
                  builder.Clamp(min_scalar, arg_vector, max_vector)),
      builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
                  builder.Clamp(min_scalar, arg_vector, max_scalar)));

  ComputeAndCompareR1<float>(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
  ComputationBuilder builder(client_, TestName());
  auto min_vector = builder.ConstantR1<int32>({1, -6, 1, 2, 0, -5});
  auto arg_vector = builder.ConstantR1<int32>({2, 10, -5, 1, 4, 10});
  auto max_vector = builder.ConstantR1<int32>({3, 0, 25, 5, 123, -1});
  auto clamp = builder.Clamp(min_vector, arg_vector, max_vector);

  ComputeAndCompareR1<int32>(&builder, {2, 0, 1, 2, 4, -1}, {});
}

XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
  ComputationBuilder builder(client_, TestName());
  auto min_scalar = builder.ConstantR0<int32>(0);
  auto min_vector = builder.ConstantR1<int32>({1, -6, 1, 2, 0});
  auto arg_vector = builder.ConstantR1<int32>({2, 10, -5, 1, 4});
  auto max_scalar = builder.ConstantR0<int32>(3);
  auto max_vector = builder.ConstantR1<int32>({3, 1, 25, 5, 123});
  // Perform clamp with broadcasted scalar and vector. All four combinations of
  // scalar/vector bounds are summed, so each one contributes to the result.
  auto clamp = builder.Add(
      builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
                  builder.Clamp(min_scalar, arg_vector, max_vector)),
      builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
                  builder.Clamp(min_scalar, arg_vector, max_scalar)));

  ComputeAndCompareR1<int32>(&builder, {8, 8, 2, 6, 14}, {});
}

XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
  ComputationBuilder builder(client_, TestName());
  auto min_vector = builder.ConstantR1<uint32>({1, 2, 1, 2, 0, ~0u - 4});
  auto arg_vector = builder.ConstantR1<uint32>({2, 10, 5, 1, 4, 10});
  auto max_vector = builder.ConstantR1<uint32>({3, 5, 25, 5, 123, ~0u});
  auto clamp = builder.Clamp(min_vector, arg_vector, max_vector);

  ComputeAndCompareR1<uint32>(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {});
}

XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
  ComputationBuilder builder(client_, TestName());
  auto min_scalar = builder.ConstantR0<uint32>(0);
  auto min_vector = builder.ConstantR1<uint32>({1, 0, 1, 2, 0});
  auto arg_vector = builder.ConstantR1<uint32>({2, 10, 0, 1, 4});
  auto max_scalar = builder.ConstantR0<uint32>(3);
  auto max_vector = builder.ConstantR1<uint32>({3, 1, 25, 5, 123});
  // Perform clamp with broadcasted scalar and vector. All four combinations of
  // scalar/vector bounds are summed, so each one contributes to the result.
  auto clamp = builder.Add(
      builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
                  builder.Clamp(min_scalar, arg_vector, max_vector)),
      builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
                  builder.Clamp(min_scalar, arg_vector, max_scalar)));

  ComputeAndCompareR1<uint32>(&builder, {8, 8, 2, 6, 14}, {});
}

// Adds two runtime parameters transferred to the server beforehand.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
  ComputationBuilder builder(client_, TestName());

  std::unique_ptr<Literal> param0_literal =
      Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();

  std::unique_ptr<Literal> param1_literal =
      Literal::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();

  auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
  auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
  auto add = builder.Add(p0, p1);

  ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
                             {param0_data.get(), param1_data.get()},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
  ComputationBuilder builder(client_, TestName());

  std::unique_ptr<Literal> param0_literal =
      Literal::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();

  std::unique_ptr<Literal> param1_literal =
      Literal::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();

  auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
  auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
  auto add = builder.Add(p0, p1);

  Array3D<float> expected(0, 7, 0);
  ComputeAndCompareR3<float>(
      &builder, expected, {param0_data.get(), param1_data.get()}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
  ComputationBuilder builder(client_, TestName());

  std::unique_ptr<Literal> param0_literal =
      Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();

  auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
  auto p = builder.Parameter(0, param0_literal->shape(), "param0");
  auto add = builder.Add(a, p);

  ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
                             {param0_data.get()}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
  ComputationBuilder builder(client_, TestName());
  auto a = builder.ConstantR1<float>({3.14159f, 0.0f, 1.570796f, -0.78539f});
  auto result = builder.Cos(a);

  ComputeAndCompareR1<float>(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
  ComputationBuilder builder(client_, TestName());
  auto a = builder.ConstantR1<float>({3.14159f, 0.0f, 1.570796f, -0.78539f});
  auto result = builder.Sin(a);

  ComputeAndCompareR1<float>(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {},
                             error_spec_);
}

// Atan2(a, b) treats `a` as the y (numerator) operand, e.g. (5, 0) -> pi/2.
XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
  ComputationBuilder builder(client_, TestName());
  auto a = builder.ConstantR1<float>({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f});
  auto b = builder.ConstantR1<float>({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f});
  auto atan = builder.Atan2(a, b);

  ComputeAndCompareR1<float>(
      &builder,
      {0.0f, 1.57079633f, 3.14159265f, -1.57079633f, 0.78539816f, -0.78539816f},
      {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
  ComputationBuilder builder(client_, TestName());
  auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f});
  auto result = builder.Tanh(a);

  ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
  // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that
  // the input tensor is large enough to exercise the vectorized tanh
  // implementation on XLA CPU.
  ComputationBuilder builder(client_, TestName());
  auto input_literal = Literal::CreateR1<float>(
      {1.02,  -0.32, 0.85,  0.90,  1.23,  -0.91, -0.49, 0.80,  -0.67, 0.16,
       -0.07, 0.39,  -0.41, 0.04,  1.36,  1.25,  0.41,  0.65,  -1.08, 0.32,
       -1.45, -0.77, -1.09, 0.91,  -1.03, -0.30, -1.11, -1.17, 1.50,  -0.85,
       0.04,  1.02,  0.34,  -0.61, 0.41,  0.07,  -0.02, 1.42,  -0.62, 0.81,
       0.08,  0.81,  -0.30, 1.17,  -0.65, -0.44, 0.92,  1.26,  -1.29, 1.35,
       0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
       -0.79, 1.41,  1.21,  1.05});
  TF_ASSERT_OK_AND_ASSIGN(auto input_data,
                          client_->TransferToServer(*input_literal));

  auto input = builder.Parameter(0, input_literal->shape(), "input");
  builder.Tanh(input);

  ComputeAndCompareR1<float>(
      &builder,
      {0.77009583,  -0.30665702, 0.69070244,  0.71401149,  0.84400684,
       -0.71985596, -0.45764771, 0.66664988,  -0.58278900, 0.16050975,
       -0.06770509, 0.36843640,  -0.38476998, 0.04018109,  0.87562293,
       0.84788644,  0.38603750,  0.57294142,  -0.79140943, 0.31032649,
       -0.89590985, -0.64770776, -0.79625875, 0.72234446,  -0.77389336,
       -0.28871772, -0.80428445, -0.82541436, 0.90456349,  -0.68856895,
       0.03877772,  0.76877952,  0.32561871,  -0.54546672, 0.39072621,
       0.07273290,  -0.01924866, 0.88924897,  -0.55283129, 0.67183107,
       0.08006320,  0.66944766,  -0.29068485, 0.82573754,  -0.57170743,
       -0.41581789, 0.72739530,  0.85025692,  -0.85931867, 0.87357593,
       0.07782833,  -0.84597743, -0.72748238, 0.45396307,  0.82449573,
       -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558,
       -0.65565848, 0.88789743,  0.83566397,  0.78287679},
      {input_data.get()},
      // The error spec is unusually high here to account for the fact that we
      // use a rational interpolant to approximate tanh.
      ErrorSpec(0.004, 0.004));
}

XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
  // The input tensor is large enough to exercise the vectorized exp
  // implementation on XLA CPU.
  ComputationBuilder builder(client_, TestName());

  // Just to help make sense of the scales here -- exp(x) overflows float32 for
  // x above roughly 88.7 (the largest input below is 87.9), and exp(-10) is
  // already smaller than our error spec.
  std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
      {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49, 0.8,    -1.31,
       -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,  -195.6, -194.5,
       -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5,  -17.4,
       -16.3,  -15.2,  -14.1,  -13.0,  -11.9,  -10.8,  -9.7,  -8.6,   -7.5,
       -6.4,   -5.3,   -4.2,   -3.1,   -2.0,   -0.9,   0.2,   1.3,    2.4,
       3.5,    4.6,    5.7,    6.8,    7.9,    9.0,    10.1,  11.2,   12.3,
       13.4,   14.5,   15.6,   16.7,   17.8,   18.9,   20.0,  21.1,   22.2,
       23.3,   24.4,   25.5,   26.6,   27.7,   28.8,   29.9,  31.0,   32.1,
       68.4,   69.5,   70.6,   71.7,   72.8,   73.9,   75.0,  76.1,   77.2,
       78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,  85.2,   86.3,
       86.4,   86.5,   87.6,   87.7,   87.8,   87.9});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(*input_literal));

  auto input = builder.Parameter(0, input_literal->shape(), "input");
  builder.Exp(input);

  // Reference values come from the host's std::exp on the same inputs.
  std::vector<float> expected_result;
  int64 input_size = input_literal->shape().dimensions(0);
  expected_result.reserve(input_size);
  for (int64 i = 0; i < input_size; i++) {
    expected_result.push_back(std::exp(input_literal->Get<float>({i})));
  }

  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
  // The input tensor is large enough to exercise the vectorized log
  // implementation on XLA CPU.
  ComputationBuilder builder(client_, TestName());

  // NOTE(review): the first eight inputs are negative, so std::log yields NaN
  // for them; this presumably relies on the comparator treating matching NaNs
  // as equal.
  std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
      {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
       -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
       198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
       1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
       1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
       1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
       1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
       1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
       2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
       1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
       1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
       1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
       1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
       1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
       1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(*input_literal));

  auto input = builder.Parameter(0, input_literal->shape(), "input");
  builder.Log(input);

  // Reference values come from the host's std::log on the same inputs.
  std::vector<float> expected_result;
  int64 input_size = input_literal->shape().dimensions(0);
  expected_result.reserve(input_size);
  for (int64 i = 0; i < input_size; i++) {
    expected_result.push_back(std::log(input_literal->Get<float>({i})));
  }

  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
  // a ------ (add) --------- (add)
  //         /               /
  // b -----/               /
  // c---------------------/
  ComputationBuilder builder(client_, TestName());

  auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
  auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
  auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});

  auto add = builder.Add(a, b);
  auto add2 = builder.Add(add, c);

  ComputeAndCompareR1<float>(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
  // b ------ (add) --------- (add)
  //         /               /
  // c -----/               /
  // a---------------------/
  ComputationBuilder builder(client_, TestName());

  auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
  auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
  auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});

  auto add = builder.Add(b, c);
  auto add2 = builder.Add(a, add);

  ComputeAndCompareR1<float>(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
  // a ----- (neg) ----- (add)
  //                    /
  // b ----- (neg) ----/
  ComputationBuilder builder(client_, TestName());

  auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
  auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});

  auto neg_a = builder.Neg(a);
  auto neg_b = builder.Neg(b);
  auto result = builder.Add(neg_a, neg_b);

  ComputeAndCompareR1<float>(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
  // Computes (a + b) + (c + d). (The diagram deliberately avoids ending a
  // comment line with a backslash, which would splice the next line into it.)
  //
  //   a ------ (add) -------+
  //           /             |
  //   b -----/              +--- (add)
  //                         |
  //   c ------ (add) -------+
  //           /
  //   d -----/
  ComputationBuilder builder(client_, TestName());

  auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
  auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
  auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
  auto d = builder.ConstantR1<float>({-19.0f, 10.0f, -40.0f, 20.2f});

  auto add_ab = builder.Add(a, b);
  auto add_cd = builder.Add(c, d);
  auto add_all = builder.Add(add_ab, add_cd);

  ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
                             error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
  ComputationBuilder builder(client_, TestName());
  auto a =
      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto b =
      builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  auto add = builder.Add(a, b);

  Array2D<float> expected_array(
      {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
  // Add a scalar + matrix.
  ComputationBuilder builder(client_, TestName());
  auto a =
      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto scalar = builder.ConstantR0<float>(3.0f);
  auto add = builder.Add(scalar, a);

  Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
  // Add a matrix + scalar.
  ComputationBuilder builder(client_, TestName());
  auto a =
      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto scalar = builder.ConstantR0<float>(3.0f);
  auto add = builder.Add(a, scalar);

  Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
  // Test simple broadcasting of a R1F32 over R2F32. The vector's size (3)
  // matches only dim 1 of the 2x3 matrix, so it is added to every row.
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<float>({20.0f, 40.0f, 60.0f});
  // clang-format off
  auto m = builder.ConstantR2<float>({
    {-2.5f, 3.14f, 1.0f},
    {2.25f, -10.0f, 3.33f}});
  // clang-format on
  auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1});
  Array2D<float> expected_array(
      {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
  // Test broadcasting in Eq comparison.
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<int32>({42, 73});
  auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});

  // This test exercises both possible broadcast dimensions for a vector/matrix
  // comparison. Broadcasting along dimension 1 compares v with each row of m;
  // broadcasting along dimension 0 compares v with each column of m.
  auto cmp_bcast_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{1});
  auto cmp_bcast_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{0});
  auto result = builder.Tuple({cmp_bcast_dim_1, cmp_bcast_dim_0});

  auto expected = Literal::MakeTuple(
      {Literal::CreateR2<bool>({{true, true}, {true, false}}).get(),
       Literal::CreateR2<bool>({{true, false}, {false, false}}).get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
  // Test broadcasting in Ne comparison.
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<int32>({42, 73});
  auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
  auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,2] {
  { 00 },
  { 01 }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
  // Test broadcasting in Ge comparison.
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
  auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
  auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 1100 },
  { 0001 }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
  // Test broadcasting in Gt comparison.
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
  auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
  auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 0100 },
  { 0000 }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
  // Test broadcasting in Le comparison.
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
  auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
  auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 1011 },
  { 1111 }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
  // Test broadcasting in Lt comparison.
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
  auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
  auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 0011 },
  { 1110 }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
  // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op
  // arguments is reversed.
  ComputationBuilder builder(client_, TestName());
  auto m = builder.ConstantR2<float>({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
  auto v = builder.ConstantR1<float>({2.0f, 4.0f, 6.0f});
  auto mul = builder.Mul(m, v, /*broadcast_dimensions=*/{1});
  Array2D<float> expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  ComputationBuilder builder(client_, TestName());
  // m has shape [2, 3] and md has shape [1, 3] (a single row). md's size-1
  // dimension is broadcast, so its row is added to every row of m, giving a
  // [2, 3] result. (The test name counts dimensions in reverse order.)
  auto m =
      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto md = builder.ConstantR2<float>({{10.0f, 20.0f, 30.0f}});
  auto add = builder.Add(m, md);
  Array2D<float> expected_array(
      {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  ComputationBuilder builder(client_, TestName());
  // m has shape [2, 3] and md has shape [2, 1] (a single column). md's size-1
  // dimension is broadcast, so md[i][0] is added to every element of row i,
  // giving a [2, 3] result. (The test name counts dimensions in reverse order.)
  auto m =
      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto md = builder.ConstantR2<float>({{10.0f}, {20.0f}});
  auto add = builder.Add(m, md);
  Array2D<float> expected_array(
      {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
  // Tests broadcasting for two degenerate arrays. This kind of broadcasting
  // effectively creates an "outer product" operation.
  // This is taken from the Numpy docs example at:
  // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
  ComputationBuilder builder(client_, TestName());
  // a has shape [4, 1] (a column) and b has shape [1, 3] (a row).
  // The result has shape [4, 3], with result[i][j] = a[i][0] + b[0][j].
  auto a = builder.ConstantR2<float>({{0.0f}, {10.0f}, {20.0f}, {30.0f}});
  auto b = builder.ConstantR2<float>({{1.0f, 2.0f, 3.0f}});
  auto add = builder.Add(a, b);
  Array2D<float> expected_array({{1.0f, 2.0f, 3.0f},
                                 {11.0f, 12.0f, 13.0f},
                                 {21.0f, 22.0f, 23.0f},
                                 {31.0f, 32.0f, 33.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
  // Add together a (2,2) array and a (2) array, using dimension 1 for
  // broadcasting (though there are two ways to broadcast these shapes).
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<float>({20.0f, 40.0f});
  auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
  auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1});
  Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
  // Add together a (2,2) array and a (2) array, using dimension 0 for
  // broadcasting (though there are two ways to broadcast these shapes).
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<float>({20.0f, 40.0f});
  auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
  auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0});
  Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
  // Binary add of two R3s together
  ComputationBuilder builder(client_, TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = builder.ConstantR3FromArray3D<float>(a_3d);

  Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
                       {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
  auto b = builder.ConstantR3FromArray3D<float>(b_3d);
  auto add = builder.Add(a, b);

  Array3D<float> expected_3d(
      {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
       {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
  // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for
  // broadcasting (though there are two ways to broadcast these shapes).
  ComputationBuilder builder(client_, TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto a = builder.ConstantR3FromArray3D<float>(a_3d);
  auto v = builder.ConstantR1<float>({10.0f, 20.0f});
  auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2});

  Array3D<float> expected_3d(
      {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
       {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
  // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for
  // broadcasting (though there are two ways to broadcast these shapes).
  ComputationBuilder builder(client_, TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto a = builder.ConstantR3FromArray3D<float>(a_3d);
  auto v = builder.ConstantR1<float>({10.0f, 20.0f});
  auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0});

  // clang-format off
  Array3D<float> expected_3d({
    {{11.0f, 12.0f},
     {13.0f, 14.0f},
     {15.0f, 16.0f}},
    {{27.0f, 28.0f},
     {29.0f, 30.0f},
     {31.0f, 32.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
  // Add together a (2, 3, 2) array with a (2, 3) array, using dimensions
  // {0, 1} for broadcasting.
  ComputationBuilder builder(client_, TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  auto a = builder.ConstantR3FromArray3D<float>(a_3d);
  auto m = builder.ConstantR2<float>({
    {10.0f, 20.0f, 30.0f},
    {40.0f, 50.0f, 60.0f},
  });
  auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1});

  Array3D<float> expected_3d({
    {{11.0f, 12.0f},
     {23.0f, 24.0f},
     {35.0f, 36.0f}},
    {{47.0f, 48.0f},
     {59.0f, 60.0f},
     {71.0f, 72.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
  // Comparison between two 3D arrays of compatible shapes:
  // (2, 3, 2) and (1, 3, 2): expected to produce a (2, 3, 2) shape of PREDs.
  // (The test name counts dimensions in reverse order.)
  ComputationBuilder builder(client_, TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = builder.ConstantR3FromArray3D<float>(a_3d);

  Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
  auto b = builder.ConstantR3FromArray3D<float>(b_3d);

  auto compare = builder.Gt(a, b);

  const string expected = R"(pred[2,3,2] {
{ { 01 },
  { 00 },
  { 00 } },
{ { 01 },
  { 10 },
  { 01 } }
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
  ComputationBuilder builder(client_, TestName());

  std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
  std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5)); 2598 float value = 0.0; 2599 for (int64 p = 0; p < 2; ++p) { 2600 for (int64 z = 0; z < 3; ++z) { 2601 for (int64 y = 0; y < 4; ++y) { 2602 for (int64 x = 0; x < 5; ++x) { 2603 (*operand_a_4d)(p, z, y, x) = value; 2604 (*operand_b_4d)(p, z, y, x) = 2.0 * value; 2605 (*expected_4d)(p, z, y, x) = 3.0 * value; 2606 value += 0.1; 2607 } 2608 } 2609 } 2610 } 2611 2612 auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d); 2613 auto b = builder.ConstantR4FromArray4D<float>(*operand_b_4d); 2614 auto add = builder.Add(a, b); 2615 2616 ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_); 2617 } 2618 2619 XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { 2620 ComputationBuilder builder(client_, TestName()); 2621 2622 std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5)); 2623 std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5)); 2624 std::vector<float> operand_b_1d(3); 2625 std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0); 2626 2627 float value = 0.0; 2628 for (int64 p = 0; p < 2; ++p) { 2629 for (int64 z = 0; z < 3; ++z) { 2630 for (int64 y = 0; y < 4; ++y) { 2631 for (int64 x = 0; x < 5; ++x) { 2632 (*operand_a_4d)(p, z, y, x) = value; 2633 (*expected_4d)(p, z, y, x) = value + operand_b_1d[z]; 2634 value += 0.1; 2635 } 2636 } 2637 } 2638 } 2639 2640 auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d); 2641 auto b = builder.ConstantR1<float>(operand_b_1d); 2642 auto add = builder.Add(a, b, {1}); 2643 2644 ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_); 2645 } 2646 2647 XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { 2648 constexpr int d0 = 16; 2649 constexpr int d1 = 16; 2650 constexpr int d2 = 2; 2651 constexpr int d3 = 2; 2652 Array4D<float> r4(d0, d1, d2, d3); 2653 r4.Fill(1.0); 2654 std::vector<float> r1(d1); 2655 std::iota(r1.begin(), r1.end(), 1.0); 2656 2657 
ComputationBuilder builder(client_, TestName()); 2658 std::unique_ptr<Literal> a_literal = Literal::CreateR4FromArray4DWithLayout( 2659 r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); 2660 auto a = builder.ConstantLiteral(*a_literal); 2661 auto b = builder.ConstantR1<float>(r1); 2662 builder.Add(a, b, {1}); 2663 2664 for (int i0 = 0; i0 < d0; ++i0) { 2665 for (int i1 = 0; i1 < d1; ++i1) { 2666 for (int i2 = 0; i2 < d2; ++i2) { 2667 for (int i3 = 0; i3 < d3; ++i3) { 2668 r4(i0, i1, i2, i3) += r1[i1]; 2669 } 2670 } 2671 } 2672 } 2673 ComputeAndCompareR4<float>(&builder, r4, {}, error_spec_); 2674 } 2675 2676 // Show that we can't add two opaques. 2677 XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { 2678 ComputationBuilder builder(client_, TestName()); 2679 auto shape = ShapeUtil::MakeOpaqueShape(); 2680 auto x = builder.Parameter(0, shape, "x"); 2681 auto concatenated = builder.Add(x, x); 2682 StatusOr<Computation> computation_status = builder.Build(); 2683 ASSERT_FALSE(computation_status.ok()); 2684 EXPECT_THAT(computation_status.status().ToString(), 2685 ::testing::ContainsRegex( 2686 "Expected non-opaque argument for lhs of binary operation")); 2687 } 2688 2689 XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { 2690 ComputationBuilder builder(client_, TestName()); 2691 auto a = 2692 builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); 2693 auto b = 2694 builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); 2695 auto add = builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); 2696 2697 Array2D<float> expected_array( 2698 {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); 2699 ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); 2700 } 2701 2702 XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { 2703 ComputationBuilder builder(client_, TestName()); 2704 auto a = 2705 builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); 2706 auto b = 
2707 builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); 2708 auto add = builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); 2709 2710 StatusOr<Computation> computation_status = builder.Build(); 2711 ASSERT_FALSE(computation_status.ok()); 2712 EXPECT_THAT(computation_status.status().error_message(), 2713 ::testing::ContainsRegex("must.*be the identity")); 2714 } 2715 2716 // Regression test for b/31927799. "slice - y" is fused and requires implicit 2717 // broadcast. 2718 XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { 2719 ComputationBuilder builder(client_, TestName()); 2720 auto x_literal = Literal::CreateR1<float>({1, 2, 3}); 2721 auto y_literal = Literal::CreateR1<float>({4, 5}); 2722 auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); 2723 auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); 2724 2725 auto x = builder.Parameter(0, x_literal->shape(), "x"); 2726 auto y = builder.Parameter(1, y_literal->shape(), "y"); 2727 auto slice = builder.Slice(x, {1}, {2}, {1}); 2728 builder.Sub(slice, y); 2729 2730 ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()}, 2731 error_spec_); 2732 } 2733 2734 INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount, 2735 ArrayElementwiseOpTestParamCount, 2736 ::testing::Values(127, 128, 129, 17 * 4096)); 2737 2738 } // namespace 2739 } // namespace xla 2740