/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include "absl/base/casts.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

using Eigen::half;

// Evaluates the polynomial with coefficients `coeffs` at the point `x` using
// Horner's rule.  Coefficients are ordered from the highest degree down, i.e.
// the result is coeffs[0] * x^(N-1) + ... + coeffs[N-2] * x + coeffs[N-1].
template <typename T, size_t N>
T EvaluatePolynomial(T x, const std::array<T, N>& coeffs) {
  T result = 0;
  for (T c : coeffs) {
    result = result * x + c;
  }
  return result;
}

// There's no std::erfinv, so we have to implement it ourselves.  This follows
// Wichura 1998, https://www.jstor.org/stable/2347330 which, notably, is a
// different implementation from that in math.cc.
// Host-side reference implementation of erfinv, used to check the XLA op.
// Returns NaN for |x| > 1 or NaN input, and +/-inf for x == +/-1.  Internally
// evaluates one of three rational approximations (kPolyA/B, kPolyC/D,
// kPolyE/F) depending on how close |x| is to 1.
float HostErfInv(float x) {
  // Coefficient tables for the rational approximations below (Wichura's
  // algorithm; see the comment above this function).  Each regime computes
  // numerator/denominator polynomials of a transformed argument r.
  std::array<double, 8> kPolyA = {
      8.8709406962545514830200e2, 1.1819493347062294404278e4,
      2.3782041382114385731252e4, 1.6235862515167575384252e4,
      4.8548868893843886794648e3, 6.9706266534389598238465e2,
      4.7072688112383978012285e1, 1.1975323115670912564578e0,
  };
  std::array<double, 8> kPolyB = {
      5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4,
      2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2,
      4.2313330701600911252e1, 1.0000000000000000000e0,
  };
  std::array<double, 8> kPolyC = {
      7.74545014278341407640e-4, 2.27238449892691845833e-2,
      2.41780725177450611770e-1, 1.27045825245236838258e0,
      3.64784832476320460504e0, 5.76949722146069140550e0,
      4.63033784615654529590e0, 1.42343711074968357734e0,
  };
  std::array<double, 8> kPolyD = {
      1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4,
      2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1,
      9.7547832001787427186894837e-1, 2.3707661626024532365971225e0,
      2.9036514445419946173133295e0, 1.4142135623730950488016887e0,
  };
  std::array<double, 8> kPolyE = {
      2.01033439929228813265e-7, 2.71155556874348757815e-5,
      1.24266094738807843860e-3, 2.65321895265761230930e-2,
      2.96560571828504891230e-1, 1.78482653991729133580e0,
      5.46378491116411436990e0, 6.65790464350110377720e0,
  };
  std::array<double, 8> kPolyF = {
      2.891024605872965461538222e-15, 2.010321207683943062279931e-7,
      2.611088405080593625138020e-5, 1.112800997078859844711555e-3,
      2.103693768272068968719679e-2, 1.936480946950659106176712e-1,
      8.482908416595164588112026e-1, 1.414213562373095048801689e0,
  };

  // erfinv is only defined on [-1, 1]; out-of-domain and NaN inputs map to
  // NaN.
  if (std::abs(x) > 1 || std::isnan(x)) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  // erfinv(+/-1) = +/-inf.
  if (std::abs(x) == 1) {
    return std::copysign(std::numeric_limits<float>::infinity(), x);
  }

  // Compute the result for |x|; the sign is restored at the end since erfinv
  // is an odd function.
  float unsigned_result = [&] {
    float y = std::abs(x);
    if (y <= 0.85) {
      // Central region.
      double r = 0.180625 - 0.25 * y * y;
      return (y * EvaluatePolynomial(r, kPolyA)) /
             EvaluatePolynomial(r, kPolyB);
    } else {
      // Tail regions: r grows as y approaches 1.
      double r = std::sqrt(std::log(2.0) - std::log1p(-y));
      if (r <= 5.0) {
        r -= 1.6;
        return EvaluatePolynomial(r, kPolyC) / EvaluatePolynomial(r, kPolyD);
      } else {
        r -= 5;
        return EvaluatePolynomial(r, kPolyE) / EvaluatePolynomial(r, kPolyF);
      }
    }
  }();
  return std::copysign(unsigned_result, x);
}

// Digamma implementation using a polynomial from Cephes.  Notably this is a
// different implementation from the one in math.cc.
float HostDigamma(float x) {
  // Euler-Mascheroni constant
  float kGamma = 0.57721566490153286061;
  float kPi = M_PI;

  // Coefficients of the asymptotic-series tail, evaluated in z = 1/x^2 below
  // (highest degree first, per EvaluatePolynomial's convention).
  std::array<float, 4> kPoly = {
      -4.16666666666666666667E-3,
      3.96825396825396825397E-3,
      -8.33333333333333333333E-3,
      8.33333333333333333333E-2,
  };

  float reflection = 0;
  if (x <= 0) {
    float floor = std::floor(x);
    // Digamma has poles at the non-positive integers.
    if (x == floor) {
      return std::numeric_limits<float>::quiet_NaN();
    }
    // Compute reflection term, pi * cot(pi * x).
    reflection = x - floor;
    if (reflection == 0.5) {
      // cot(pi/2) == 0, so the reflection term vanishes exactly.
      reflection = 0;
    } else {
      if (reflection > 0.5) {
        reflection = x - (floor + 1.0f);
      }
      reflection = kPi / std::tan(kPi * reflection);
    }
    // Continue with the positive mirror argument; `reflection` is subtracted
    // from the result at the end.
    x = 1 - x;
  }

  float result = 0;
  if (x <= 10 && x == std::floor(x)) {
    // Special case for integers <= 10: digamma(n) = -gamma + sum_{i<n} 1/i.
    for (int i = 1; i < x; ++i) {
      result += 1.0f / i;
    }
    result -= kGamma;
  } else {
    // Shift x up to >= 10 via the recurrence digamma(x+1) = digamma(x) + 1/x,
    // accumulating the 1/x terms in `w` (subtracted back below).
    float w = 0;
    for (; x < 10; ++x) {
      w += 1.0f / x;
    }
    if (x < 1e8) {
      float z = 1.0f / (x * x);
      result = z * EvaluatePolynomial(z, kPoly);
    }
    result = std::log(x) - 0.5f / x - result - w;
  }

  // Compute the final, reflected value.
  return result - reflection;
}

// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be
// guaranteed that we're printing the full number.
//
// (The general formula is, given a floating-point number with S significand
// bits, the number of decimal digits needed to print it to full precision is
//
//   ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103).
//
// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)

// Stringifies `x` to full decimal precision plus its raw bit pattern, so
// mismatch messages identify the exact input/output values involved.
string StringifyNum(float x) {
  return absl::StrFormat("%0.9g (0x%08x)", x, absl::bit_cast<uint32>(x));
}

string StringifyNum(half x) {
  return absl::StrFormat("%0.5g (0x%04x)", static_cast<float>(x),
                         absl::bit_cast<uint16>(x));
}

string StringifyNum(bfloat16 x) {
  return absl::StrFormat("%0.4g (0x%04x)", static_cast<float>(x),
                         absl::bit_cast<uint16>(x));
}

// Test parameter is a tuple containing
//   - primitive type under test,
//   - (begin, end) range under test, as zero-extended int64s bitcast to the
//     primitive type under test.
class ExhaustiveOpTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<
          std::tuple<PrimitiveType, std::pair<int64, int64>>> {
 public:
  ExhaustiveOpTest()
      : ty_(std::get<0>(GetParam())), platform_(client_->platform()->Name()) {}

  // Runs `enqueue_op` on every bit pattern in this shard's range and compares
  // each result against the host-side reference `evaluate_op`.
  void Run(std::function<XlaOp(XlaOp)> enqueue_op,
           float (*evaluate_op)(float)) {
    SetFastMathDisabled(true);

    // Run all HLO passes.  In particular, constant folding is disabled by
    // default for tests, but we need to run it in order to tickle some bugs.
    mutable_debug_options()->clear_xla_disable_hlo_passes();

    PrimitiveType ty;
    std::tie(ty, std::ignore) = GetParam();

    // Dispatch to the bit-width-appropriate implementation, installing a
    // default error spec for the type (testcases may have set their own).
    switch (ty) {
      case F32:
        SetDefaultErrSpec(0.0001, 0.0001);
        RunImpl<float, uint32>(enqueue_op, evaluate_op);
        break;
      case F16:
        SetDefaultErrSpec(0.001, 0.001);
        RunImpl<half, uint16>(enqueue_op, evaluate_op);
        break;
      case BF16:
        SetDefaultErrSpec(0.001, 0.01);
        RunImpl<bfloat16, uint16>(enqueue_op, evaluate_op);
        break;
      default:
        LOG(FATAL) << "Unhandled type.";
    }
  }

  // Installs default absolute/relative error bounds, but only where the
  // testcase has not already set its own.
  void SetDefaultErrSpec(float abs_err, float rel_err) {
    if (!abs_err_.has_value()) {
      abs_err_ = abs_err;
    }
    if (!rel_err_.has_value()) {
      rel_err_ = rel_err;
    }
  }

  // Type-specific test body: builds an input literal containing every bit
  // pattern in [begin, end), runs the op, and checks results with ExpectNear.
  // T is the floating-point type under test and IntegralT the same-width
  // unsigned integer used for bitcasting.
  template <typename T, typename IntegralT>
  void RunImpl(std::function<XlaOp(XlaOp)> enqueue_op,
               float (*evaluate_op)(float)) {
    static_assert(
        sizeof(T) == sizeof(IntegralT),
        "IntegralT must be an unsigned integer type of the same width as T.");

    PrimitiveType ty;
    std::pair<int64, int64> test_range;
    std::tie(ty, test_range) = GetParam();
    int64 begin, end;
    std::tie(begin, end) = test_range;

    if (begin >= known_incorrect_begin_ && end <= known_incorrect_end_) {
      LOG(INFO) << absl::StreamFormat(
          "Skipping this shard, as the range under test, [%d, %d), falls "
          "entirely within the known-incorrect range [%d, %d).",
          begin, end, known_incorrect_begin_, known_incorrect_end_);
      return;
    }

    LOG(INFO) << "Checking range [" << begin << ", " << end << ")";

    int64 input_size = end - begin;
    Literal input_literal = LiteralUtil::CreateFromDimensions(ty, {input_size});
    absl::Span<T> input_arr = input_literal.data<T>();
    for (int64 i = 0; i < input_size; i++) {
      IntegralT input_val = i + begin;
      // If the operation is known to be buggy on a specific input clamp that
      // input to 0 under the assumption that the op is at least correct on 0.
      if (input_val >= known_incorrect_begin_ &&
          input_val < known_incorrect_end_) {
        input_arr[i] = T{0};
      } else {
        input_arr[i] = absl::bit_cast<T>(input_val);
      }
    }

    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
                            BuildAndRunComputation(enqueue_op, input_literal));
    ExpectNear<T>(input_literal, result_literal, evaluate_op);
  }

  // Compiles a computation applying `enqueue_op` to a single parameter and
  // executes it on `input_literal`, returning the result literal.
  StatusOr<Literal> BuildAndRunComputation(
      const std::function<XlaOp(XlaOp)>& enqueue_op,
      const Literal& input_literal) {
    XlaBuilder builder(TestName());
    auto input = Parameter(&builder, 0, input_literal.shape(), "input");
    enqueue_op(input);
    TF_ASSIGN_OR_RETURN(XlaComputation comp, builder.Build());

    // Build and run the computation using the LocalClient API, rather than the
    // plain Client API, which is used by ClientLibraryTestBase.  This is
    // because the plain Client API results does more memcpys to/from Literals,
    // and that's slow given that we're touching a lot of data here.
    //
    // Copy debug options from ClientLibraryTestBase.  In particular, we're
    // interested in disabling constant folding.
    ExecutableBuildOptions build_opts;
    *build_opts.mutable_debug_options() = *mutable_debug_options();
    TF_ASSIGN_OR_RETURN(
        auto executable,
        client_->Compile(comp, {&input_literal.shape()}, build_opts));

    TF_ASSIGN_OR_RETURN(
        ScopedShapedBuffer input_data,
        client_->LiteralToShapedBuffer(input_literal, /*device_ordinal=*/0));

    ExecutableRunOptions run_opts;
    run_opts.set_allocator(client_->backend().memory_allocator());
    run_opts.set_intra_op_thread_pool(
        client_->backend().eigen_intra_op_thread_pool_device());
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                        executable->Run({&input_data}, run_opts));

    TF_ASSIGN_OR_RETURN(Literal result_literal,
                        client_->ShapedBufferToLiteral(result));
    return std::move(result_literal);
  }

  // Returns true if `actual` is within the configured absolute OR relative
  // error of `expected`.  NaN matches NaN, and infinities match when their
  // signs agree.  (When expected == 0, rel_err is inf/NaN and the comparison
  // falls through to the absolute-error check.)
  template <typename T>
  bool IsClose(T expected, T actual) {
    float expected_f32 = static_cast<float>(expected);
    float actual_f32 = static_cast<float>(actual);
    float abs_err = std::abs(expected_f32 - actual_f32);
    float rel_err = abs_err / std::abs(expected_f32);
    if (strict_signed_zeros_ && actual == T{0} && expected == T{0}) {
      // Check sign of zero.
      return std::signbit(actual_f32) == std::signbit(expected_f32);
    }
    return abs_err < *abs_err_ || rel_err < *rel_err_ ||
           (std::isnan(expected_f32) && std::isnan(actual_f32)) ||
           (std::isinf(expected_f32) && std::isinf(actual_f32) &&
            (expected_f32 > 0) == (actual_f32 > 0));
  }

  // Compares every element of `result_literal` against `evaluate_op` applied
  // to the corresponding element of `input_literal`, reporting mismatches.
  template <typename T>
  void ExpectNear(const Literal& input_literal, const Literal& result_literal,
                  float (*evaluate_op)(float)) {
    // We essentially reimplement LiteralTestUtil::Near here because
    //  a) this streamlined implementation is much faster, and
    //  b) we can print out better error messages (namely, we can print out
    //     which floating-point value input failed, while LiteralTestUtil::Near
    //     can only print out the input index that failed).
    //  c) we need special handling of certain inputs.  For example, we say that
    //     a denormal input has multiple correct outputs (namely, f(x) and f(0))
    //     and just needs to be close to one of them.
    absl::Span<const T> input_arr = input_literal.data<T>();
    absl::Span<const T> result_arr = result_literal.data<T>();
    ASSERT_EQ(result_arr.size(), input_arr.size());
    int64 mismatches = 0;
    // Hoisting these out of the loop is a nice speedup on shards that have many
    // denormals.
    const T expected_at_pos_zero = static_cast<T>(evaluate_op(0));
    const T expected_at_neg_zero = static_cast<T>(evaluate_op(-0.0));
    for (int64 i = 0; i < input_arr.size(); ++i) {
      T input = input_arr[i];
      float input_f32 = static_cast<float>(input);
      T actual = result_arr[i];
      T expected = static_cast<T>(evaluate_op(input_f32));

      if (IsClose(expected, actual)) {
        continue;
      }

      // Easy case: If `input` is not denormal and !IsClose(expected, actual),
      // print an error.
      //
      // (This doesn't correctly detect f16 and bfloat16 denormals!  This seems
      // to be OK for now, but at some point we may need to implement fpclassify
      // for half and bfloat.)
      if (std::fpclassify(input_f32) != FP_SUBNORMAL) {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
                                 StringifyNum(input), StringifyNum(expected),
                                 StringifyNum(actual));
        });
        continue;
      }

      // Otherwise, `input` is denormal.  For denormal inputs, we accept answers
      // that are close to any of:
      //
      //   - evaluate_op(input)
      //   - evaluate_op(+/-0), where the sign of 0 is equal to the sign of
      //     `input`,
      //   - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
      //     0 is the opposite of `input`.
      T sign_preserving_ftz_expected =
          std::signbit(input_f32) ? expected_at_neg_zero : expected_at_pos_zero;
      T sign_nonpreserving_ftz_expected =
          std::signbit(input_f32) ? expected_at_pos_zero : expected_at_neg_zero;
      if (IsClose(sign_preserving_ftz_expected, actual) ||
          (relaxed_denormal_signs_ &&
           IsClose(sign_nonpreserving_ftz_expected, actual))) {
        continue;
      }

      if (relaxed_denormal_signs_) {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat(
              "Mismatch on denormal value %s. Expected one of:\n"
              " %10s (evaluated at full-precision value)\n"
              " %10s (evaluated after flushing to sign-preserving zero)\n"
              " %10s (evaluated after flushing to non-sign-preserving "
              "zero)\n"
              "but got %s.",
              StringifyNum(input), StringifyNum(expected),
              StringifyNum(sign_preserving_ftz_expected),
              StringifyNum(sign_nonpreserving_ftz_expected),
              StringifyNum(actual));
        });
      } else {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat(
              "Mismatch on denormal value %s. Expected one of:\n"
              " %10s (evaluated at full-precision value)\n"
              " %10s (evaluated after flushing to sign-preserving zero)\n"
              "but got %s.",
              StringifyNum(input), StringifyNum(expected),
              StringifyNum(sign_preserving_ftz_expected), StringifyNum(actual));
        });
      }
    }
    EXPECT_EQ(mismatches, 0);
  }

  // Records one mismatch and routes its message to gunit, LOG(ERROR), or
  // nowhere, depending on how many mismatches have already been reported.
  template <typename ErrorGenerator>
  void PrintMismatch(int64* mismatches, const ErrorGenerator& err_generator) {
    // We send a few mismatches to gunit so they show up nicely in test logs.
    // Then we send more to LOG(ERROR).  The remainder we squelch unless we're
    // at vlog level 2.
    constexpr int64 kMaxMismatchesLoggedToGunit = 10;
    constexpr int64 kMaxMismatchesLoggedToErr = 1000;

    (*mismatches)++;
    if (*mismatches < kMaxMismatchesLoggedToGunit) {
      FAIL() << err_generator();
    } else if (*mismatches < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) {
      LOG(ERROR) << err_generator();
    } else if (*mismatches == kMaxMismatchesLoggedToErr) {
      // Only log this suppression notice once, exactly when the cap is hit.
      LOG(ERROR) << "Not printing any more mismatches; pass "
                    "--vmodule=exhaustive_f32__op_test=2 to see "
                    "all of them.";
    }
  }

  // The following members are set during construction so testcases can read
  // these values and use them e.g. to influence the values given to the mutable
  // members below.

  // The primitive type under test.
  const PrimitiveType ty_;

  // The platform under test.
  const string platform_;

  // Tests can set the following variables for control over execution.  This is
  // safe because each XLA_TEST_P instantiates a new instance of this class.

  // Testing will ignore the given range (encoded as bitwise representations of
  // the type under test zero-extended to int64).
  int64 known_incorrect_begin_ = 0;
  int64 known_incorrect_end_ = 0;

  // If unset, reasonable defaults will be used depending on the type under
  // test.
  absl::optional<float> abs_err_;
  absl::optional<float> rel_err_;

  // If true, will consider -0 not near to +0 and vice versa.  Note that
  // +epsilon may still be considered close to -0, depending on the error spec;
  // this only covers the case when both `expected` and `actual` are equal to 0.
  bool strict_signed_zeros_ = false;

  // If true, allows denormals to be flushed to non-sign-preserving 0.
  //
  // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt of
  // a negative number) or -inf (flush the denormal to sign-preserving zero,
  // then sqrt(-0)).  But with this as true, we'll also accept 0 (sqrt(0)).
  //
  // XLA:GPU preserves denormal signs, but other backends don't.
  bool relaxed_denormal_signs_ = platform_ != "CUDA";
};

XLA_TEST_P(ExhaustiveOpTest, Log) {
  // Loosen the F32 error spec on backends other than CPU ("Host") and GPU
  // ("CUDA").
  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
    abs_err_ = 0.001;
    rel_err_ = 0.001;
  }

  Run(Log, std::log);
}

XLA_TEST_P(ExhaustiveOpTest, Log1p) {
  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
    abs_err_ = 0.001;
    rel_err_ = 0.001;
  }

  Run(Log1p, std::log1p);
}

XLA_TEST_P(ExhaustiveOpTest, Exp) {
  if (platform_ == "Host" && ty_ == F32) {
    // TODO(b/73142289): The vectorized Exp implementation gives results outside
    // our error spec in this range.
    known_incorrect_begin_ = 1107296256 + 11583654;
    known_incorrect_end_ = 1107296256 + 11629080;
  } else if (platform_ == "Host" && ty_ == BF16) {
    // TODO(jlebar): Is this a rounding error?  Why doesn't it occur on XLA:GPU?
    //
    // Mismatch on 88.5 (0x42b1).
    //   Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80).
    known_incorrect_begin_ = 0x42b1;
    known_incorrect_end_ = 0x42b2;
  }

  Run(Exp, std::exp);
}

XLA_TEST_P(ExhaustiveOpTest, Expm1) {
  // Expm1 has the same erroneous behavior on CPU as Exp.
  if (platform_ == "Host" && ty_ == F32) {
    // TODO(b/73142289): The vectorized Exp implementation gives results outside
    // our error spec in this range.
    known_incorrect_begin_ = 1107296256 + 11583654;
    known_incorrect_end_ = 1107296256 + 11629080;
  } else if (platform_ == "Host" && ty_ == BF16) {
    // TODO(jlebar): Is this a rounding error?  Why doesn't it occur on XLA:GPU?
    //
    // Mismatch on 88.5 (0x42b1).
    //   Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80).
    known_incorrect_begin_ = 0x42b1;
    known_incorrect_end_ = 0x42b2;
  }

  Run(Expm1, std::expm1);
}

// It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but
// this *did* find a bug, namely that some backends were assuming sqrt(x) ==
// pow(x, 0.5), but this is not true for x == -inf.
XLA_TEST_P(ExhaustiveOpTest, PowOneHalf) {
  Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
      +[](float x) { return std::pow(x, 0.5f); });
}

XLA_TEST_P(ExhaustiveOpTest, Rsqrt) {
  Run(
      Rsqrt, +[](float x) { return 1 / std::sqrt(x); });
}

XLA_TEST_P(ExhaustiveOpTest, Sqrt) {
  // CPU and GPU are known to preserve the sign of zero, so check it strictly
  // there.
  if (platform_ == "Host" || platform_ == "CUDA") {
    strict_signed_zeros_ = true;
  }

  Run(Sqrt, std::sqrt);
}

// TODO(jlebar): Add remaining trig functions.  Don't forget Atan2!
// TODO(jlebar): Test trig functions over complex inputs.
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }

XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); }
XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); }
XLA_TEST_P(ExhaustiveOpTest, Digamma) {
  if (platform_ != "Host" && platform_ != "CUDA") {
    // TODO(b/123956399): This is a fairly high error, significantly higher than
    // we see on CPU/GPU.
    rel_err_ = 0.01;
    abs_err_ = 0.01;
  }

  if (platform_ == "CUDA") {
    // On GPU we get a wrong answer for the denormal inputs +/-2.93873588e-39
    // (0x00200000 and 0x80200000).  These should return -/+inf (at least
    // according to our reference implementation!) but XLA:GPU returns
    // -/+3.40282326e+38 (0xff7ffffe and 0x7f7ffffe).
    //
    // I deem this an acceptable result, as XLA:GPU flushes denormals, and as
    // the results we get here are very close to MAX_FLOAT.  We just hardcode
    // these results, as this is better than ignoring these inputs altogether.
    // Wrap the reference implementation so it reproduces XLA:GPU's hardcoded
    // FTZ behavior on exactly those two denormal bit patterns.
    auto host_digamma_with_gpu_ftz_errors = +[](float x) {
      if (absl::bit_cast<uint32>(x) == 0x00200000 ||
          absl::bit_cast<uint32>(x) == 0x80200000) {
        return std::copysign(std::numeric_limits<float>::max(), -x);
      }
      return HostDigamma(x);
    };
    Run(Digamma, host_digamma_with_gpu_ftz_errors);
  } else {
    Run(Digamma, HostDigamma);
  }
}

XLA_TEST_P(ExhaustiveOpTest, Lgamma) {
  // Our implementation gets within 0.0001 rel error except for ~20 denormal
  // inputs on GPU.  Anyway 0.001 rel error should be good enough for lgamma.
  if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) {
    rel_err_ = 0.001;
  }
  if (platform_ != "Host" && platform_ != "CUDA") {
    // TODO(b/123956399): This is a fairly high error, significantly higher than
    // we see on CPU/GPU.
    rel_err_ = 0.01;
    abs_err_ = 0.01;

    // Overflows to inf for input 4.08500343e+36 (0x7c44af8e).
    if (ty_ == F32) {
      known_incorrect_begin_ = 0x7c44af8e;
      known_incorrect_end_ = 0x7c44af8e + 1;
    }
  }
  Run(Lgamma, std::lgamma);
}

XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); }

// Returns the list of (begin, end) shards that together cover all 2^32 f32
// bit patterns.
std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
  // We break up the 2^32-element space into small'ish chunks to keep peak
  // memory usage low.
618 std::vector<std::pair<int64, int64>> result; 619 const int64 step = 1 << 25; 620 for (int64 i = 0; i < (1l << 32); i += step) { 621 result.push_back({i, i + step}); 622 } 623 return result; 624 } 625 626 INSTANTIATE_TEST_SUITE_P( 627 F32, ExhaustiveOpTest, 628 ::testing::Combine(::testing::Values(F32), 629 ::testing::ValuesIn(CreateExhaustiveF32Ranges()))); 630 631 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) 632 INSTANTIATE_TEST_SUITE_P( 633 F16, ExhaustiveOpTest, 634 ::testing::Combine(::testing::Values(F16), 635 ::testing::Values(std::make_pair(0, 1 << 16)))); 636 #endif 637 638 #if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) 639 INSTANTIATE_TEST_SUITE_P( 640 BF16, ExhaustiveOpTest, 641 ::testing::Combine(::testing::Values(BF16), 642 ::testing::Values(std::make_pair(0, 1 << 16)))); 643 #endif 644 645 } // namespace 646 } // namespace xla 647