/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <cmath>

#include "absl/base/casts.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

using Eigen::half;

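// Evaluates the polynomial whose coefficients are given in `coeffs`, ordered
// from the highest-degree term down to the constant term, using Horner's
// scheme: e.g. for coeffs = {2, 3, 5}, it computes 2*x^2 + 3*x + 5 as
// (2*x + 3)*x + 5.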
template <typename T, size_t N>
T EvaluatePolynomial(T x, const std::array<T, N>& coeffs) {
  T result = 0;
  for (T c : coeffs) {
    result = result * x + c;
  }
  return result;
}

// There's no std::erfinv, so we have to implement it ourselves.  This follows
// Wichura 1988, https://www.jstor.org/stable/2347330, which, notably, is a
// different implementation from the one in math.cc.
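//
// Roughly speaking, this evaluates Wichura's rational approximations to the
// inverse normal CDF at (x + 1) / 2, with the identity
// erfinv(x) = ndtri((x + 1) / 2) / sqrt(2) folded into the coefficients:
// one central region in terms of r = 0.180625 - (x / 2)^2, and two tail
// regions in terms of r = sqrt(-log((1 - |x|) / 2)).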
float HostErfInv(float x) {
  std::array<double, 8> kPolyA = {
      8.8709406962545514830200e2, 1.1819493347062294404278e4,
      2.3782041382114385731252e4, 1.6235862515167575384252e4,
      4.8548868893843886794648e3, 6.9706266534389598238465e2,
      4.7072688112383978012285e1, 1.1975323115670912564578e0,
  };
  std::array<double, 8> kPolyB = {
      5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4,
      2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2,
      4.2313330701600911252e1, 1.0000000000000000000e0,
  };
  std::array<double, 8> kPolyC = {
      7.74545014278341407640e-4, 2.27238449892691845833e-2,
      2.41780725177450611770e-1, 1.27045825245236838258e0,
      3.64784832476320460504e0,  5.76949722146069140550e0,
      4.63033784615654529590e0,  1.42343711074968357734e0,
  };
  std::array<double, 8> kPolyD = {
      1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4,
      2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1,
      9.7547832001787427186894837e-1, 2.3707661626024532365971225e0,
      2.9036514445419946173133295e0,  1.4142135623730950488016887e0,
  };
  std::array<double, 8> kPolyE = {
      2.01033439929228813265e-7, 2.71155556874348757815e-5,
      1.24266094738807843860e-3, 2.65321895265761230930e-2,
      2.96560571828504891230e-1, 1.78482653991729133580e0,
      5.46378491116411436990e0,  6.65790464350110377720e0,
  };
  std::array<double, 8> kPolyF = {
      2.891024605872965461538222e-15, 2.010321207683943062279931e-7,
      2.611088405080593625138020e-5,  1.112800997078859844711555e-3,
      2.103693768272068968719679e-2,  1.936480946950659106176712e-1,
      8.482908416595164588112026e-1,  1.414213562373095048801689e0,
  };

  if (std::abs(x) > 1 || std::isnan(x)) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  if (std::abs(x) == 1) {
    return std::copysign(std::numeric_limits<float>::infinity(), x);
  }

  float unsigned_result = [&] {
    float y = std::abs(x);
    if (y <= 0.85) {
      double r = 0.180625 - 0.25 * y * y;
      return (y * EvaluatePolynomial(r, kPolyA)) /
             EvaluatePolynomial(r, kPolyB);
    } else {
      double r = std::sqrt(std::log(2.0) - std::log1p(-y));
      if (r <= 5.0) {
        r -= 1.6;
        return EvaluatePolynomial(r, kPolyC) / EvaluatePolynomial(r, kPolyD);
      } else {
        r -= 5;
        return EvaluatePolynomial(r, kPolyE) / EvaluatePolynomial(r, kPolyF);
      }
    }
  }();
  return std::copysign(unsigned_result, x);
}

// Digamma implementation using a polynomial from Cephes.  Notably this is a
// different implementation from the one in math.cc.
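//
// For x <= 0 it uses the reflection formula
//
//   digamma(x) = digamma(1 - x) - pi * cot(pi * x),
//
// computing the cotangent term from the fractional part of x for accuracy.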
float HostDigamma(float x) {
  // Euler-Mascheroni constant
  float kGamma = 0.57721566490153286061;
  float kPi = M_PI;

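  // Coefficients B_{2n} / (2n) (B_{2n} being the Bernoulli numbers), highest-
  // order term first, for the asymptotic expansion used below:
  //
  //   digamma(x) ~ log(x) - 1/(2x) - sum_{n>=1} B_{2n} / (2n * x^{2n}).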
  std::array<float, 4> kPoly = {
      -4.16666666666666666667E-3,
      3.96825396825396825397E-3,
      -8.33333333333333333333E-3,
      8.33333333333333333333E-2,
  };

  float reflection = 0;
  if (x <= 0) {
    float floor = std::floor(x);
    if (x == floor) {
      return std::numeric_limits<float>::quiet_NaN();
    }
    // Compute reflection term, pi * cot(pi * x).
    reflection = x - floor;
    if (reflection == 0.5) {
      reflection = 0;
    } else {
      if (reflection > 0.5) {
        reflection = x - (floor + 1.0f);
      }
      reflection = kPi / std::tan(kPi * reflection);
    }
    x = 1 - x;
  }

  float result = 0;
  if (x <= 10 && x == std::floor(x)) {
    // Special case for integers <= 10.
    for (int i = 1; i < x; ++i) {
      result += 1.0f / i;
    }
    result -= kGamma;
  } else {
    float w = 0;
    for (; x < 10; ++x) {
      w += 1.0f / x;
    }
    if (x < 1e8) {
      float z = 1.0f / (x * x);
      result = z * EvaluatePolynomial(z, kPoly);
    }
    result = std::log(x) - 0.5f / x - result - w;
  }

  // Compute the final, reflected value.
  return result - reflection;
}

// For f32, f16, and bf16, we need 9, 5, and 4 significant decimal digits,
// respectively, to guarantee that we're printing the full number.
//
// (The general formula is: given a floating-point number with S significand
// bits, the number of decimal digits needed to print it to full precision is
//
//   ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103).
//
// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
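//
// For example, f32 has a 24-bit significand (counting the implicit leading
// bit), so it needs ceil(1 + 24 * 0.30103) = ceil(8.22) = 9 digits; f16
// (S = 11) needs ceil(4.31) = 5, and bf16 (S = 8) needs ceil(3.41) = 4.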
string StringifyNum(float x) {
  return absl::StrFormat("%0.9g (0x%08x)", x, absl::bit_cast<uint32>(x));
}

string StringifyNum(half x) {
  return absl::StrFormat("%0.5g (0x%04x)", static_cast<float>(x),
                         absl::bit_cast<uint16>(x));
}

string StringifyNum(bfloat16 x) {
  return absl::StrFormat("%0.4g (0x%04x)", static_cast<float>(x),
                         absl::bit_cast<uint16>(x));
}

// Test parameter is a tuple containing
//   - the primitive type under test, and
//   - the (begin, end) range under test, as zero-extended int64s bitcast to
//     the primitive type under test.
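//
// For example, with the F16 instantiation below, the parameter
// (F16, {0, 1 << 16}) covers all 2^16 bit patterns of the type.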
class ExhaustiveOpTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<
          std::tuple<PrimitiveType, std::pair<int64, int64>>> {
 public:
  ExhaustiveOpTest()
      : ty_(std::get<0>(GetParam())), platform_(client_->platform()->Name()) {}

  void Run(std::function<XlaOp(XlaOp)> enqueue_op,
           float (*evaluate_op)(float)) {
    SetFastMathDisabled(true);

    // Run all HLO passes.  In particular, constant folding is disabled by
    // default for tests, but we need to run it in order to tickle some bugs.
    mutable_debug_options()->clear_xla_disable_hlo_passes();

    PrimitiveType ty;
    std::tie(ty, std::ignore) = GetParam();

    switch (ty) {
      case F32:
        SetDefaultErrSpec(0.0001, 0.0001);
        RunImpl<float, uint32>(enqueue_op, evaluate_op);
        break;
      case F16:
        SetDefaultErrSpec(0.001, 0.001);
        RunImpl<half, uint16>(enqueue_op, evaluate_op);
        break;
      case BF16:
        SetDefaultErrSpec(0.001, 0.01);
        RunImpl<bfloat16, uint16>(enqueue_op, evaluate_op);
        break;
      default:
        LOG(FATAL) << "Unhandled type.";
    }
  }

  void SetDefaultErrSpec(float abs_err, float rel_err) {
    if (!abs_err_.has_value()) {
      abs_err_ = abs_err;
    }
    if (!rel_err_.has_value()) {
      rel_err_ = rel_err;
    }
  }

  template <typename T, typename IntegralT>
  void RunImpl(std::function<XlaOp(XlaOp)> enqueue_op,
               float (*evaluate_op)(float)) {
    static_assert(
        sizeof(T) == sizeof(IntegralT),
        "IntegralT must be an unsigned integer type of the same width as T.");

    PrimitiveType ty;
    std::pair<int64, int64> test_range;
    std::tie(ty, test_range) = GetParam();
    int64 begin, end;
    std::tie(begin, end) = test_range;

    if (begin >= known_incorrect_begin_ && end <= known_incorrect_end_) {
      LOG(INFO) << absl::StreamFormat(
          "Skipping this shard, as the range under test, [%d, %d), falls "
          "entirely within the known-incorrect range [%d, %d).",
          begin, end, known_incorrect_begin_, known_incorrect_end_);
      return;
    }

    LOG(INFO) << "Checking range [" << begin << ", " << end << ")";

    int64 input_size = end - begin;
    Literal input_literal = LiteralUtil::CreateFromDimensions(ty, {input_size});
    absl::Span<T> input_arr = input_literal.data<T>();
    for (int64 i = 0; i < input_size; i++) {
      IntegralT input_val = i + begin;
      // If the operation is known to be buggy on a specific input, clamp that
      // input to 0, under the assumption that the op is at least correct on 0.
      if (input_val >= known_incorrect_begin_ &&
          input_val < known_incorrect_end_) {
        input_arr[i] = T{0};
      } else {
        input_arr[i] = absl::bit_cast<T>(input_val);
      }
    }

    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
                            BuildAndRunComputation(enqueue_op, input_literal));
    ExpectNear<T>(input_literal, result_literal, evaluate_op);
  }

  StatusOr<Literal> BuildAndRunComputation(
      const std::function<XlaOp(XlaOp)>& enqueue_op,
      const Literal& input_literal) {
    XlaBuilder builder(TestName());
    auto input = Parameter(&builder, 0, input_literal.shape(), "input");
    enqueue_op(input);
    TF_ASSIGN_OR_RETURN(XlaComputation comp, builder.Build());

    // Build and run the computation using the LocalClient API, rather than the
    // plain Client API, which is used by ClientLibraryTestBase.  This is
    // because the plain Client API does more memcpys to/from Literals, and
    // that's slow given that we're touching a lot of data here.
    //
    // Copy debug options from ClientLibraryTestBase.  In particular, we're
    // interested in disabling constant folding.
    ExecutableBuildOptions build_opts;
    *build_opts.mutable_debug_options() = *mutable_debug_options();
    TF_ASSIGN_OR_RETURN(
        auto executable,
        client_->Compile(comp, {&input_literal.shape()}, build_opts));

    TF_ASSIGN_OR_RETURN(
        ScopedShapedBuffer input_data,
        client_->LiteralToShapedBuffer(input_literal, /*device_ordinal=*/0));

    ExecutableRunOptions run_opts;
    run_opts.set_allocator(client_->backend().memory_allocator());
    run_opts.set_intra_op_thread_pool(
        client_->backend().eigen_intra_op_thread_pool_device());
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                        executable->Run({&input_data}, run_opts));

    TF_ASSIGN_OR_RETURN(Literal result_literal,
                        client_->ShapedBufferToLiteral(result));
    return std::move(result_literal);
  }

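  // Returns true iff `actual` is close enough to `expected`: within the
  // absolute or relative error spec, both NaN, or both infinite with the same
  // sign.  If strict_signed_zeros_ is set, +0 and -0 are not considered close
  // to each other.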
  template <typename T>
  bool IsClose(T expected, T actual) {
    float expected_f32 = static_cast<float>(expected);
    float actual_f32 = static_cast<float>(actual);
    float abs_err = std::abs(expected_f32 - actual_f32);
    float rel_err = abs_err / std::abs(expected_f32);
    if (strict_signed_zeros_ && actual == T{0} && expected == T{0}) {
      // Check sign of zero.
      return std::signbit(actual_f32) == std::signbit(expected_f32);
    }
    return abs_err < *abs_err_ || rel_err < *rel_err_ ||
           (std::isnan(expected_f32) && std::isnan(actual_f32)) ||
           (std::isinf(expected_f32) && std::isinf(actual_f32) &&
            (expected_f32 > 0) == (actual_f32 > 0));
  }

  template <typename T>
  void ExpectNear(const Literal& input_literal, const Literal& result_literal,
                  float (*evaluate_op)(float)) {
    // We essentially reimplement LiteralTestUtil::Near here because
    //  a) this streamlined implementation is much faster;
    //  b) we can print out better error messages (namely, we can print out
    //     which floating-point input value failed, while LiteralTestUtil::Near
    //     can only print out the input index that failed); and
    //  c) we need special handling of certain inputs.  For example, we say that
    //     a denormal input has multiple correct outputs (namely, f(x) and f(0))
    //     and just needs to be close to one of them.
    absl::Span<const T> input_arr = input_literal.data<T>();
    absl::Span<const T> result_arr = result_literal.data<T>();
    ASSERT_EQ(result_arr.size(), input_arr.size());
    int64 mismatches = 0;
    // Hoisting these out of the loop is a nice speedup on shards that have many
    // denormals.
    const T expected_at_pos_zero = static_cast<T>(evaluate_op(0));
    const T expected_at_neg_zero = static_cast<T>(evaluate_op(-0.0));
    for (int64 i = 0; i < input_arr.size(); ++i) {
      T input = input_arr[i];
      float input_f32 = static_cast<float>(input);
      T actual = result_arr[i];
      T expected = static_cast<T>(evaluate_op(input_f32));

      if (IsClose(expected, actual)) {
        continue;
      }

      // Easy case: If `input` is not denormal and !IsClose(expected, actual),
      // print an error.
      //
      // (This doesn't correctly detect f16 and bfloat16 denormals!  This seems
      // to be OK for now, but at some point we may need to implement fpclassify
      // for half and bfloat.)
      if (std::fpclassify(input_f32) != FP_SUBNORMAL) {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
                                 StringifyNum(input), StringifyNum(expected),
                                 StringifyNum(actual));
        });
        continue;
      }

      // Otherwise, `input` is denormal.  For denormal inputs, we accept answers
      // that are close to any of:
      //
      //   - evaluate_op(input),
      //   - evaluate_op(+/-0), where the sign of 0 is equal to the sign of
      //     `input`, or
      //   - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
      //     0 is the opposite of `input`'s.
      T sign_preserving_ftz_expected =
          std::signbit(input_f32) ? expected_at_neg_zero : expected_at_pos_zero;
      T sign_nonpreserving_ftz_expected =
          std::signbit(input_f32) ? expected_at_pos_zero : expected_at_neg_zero;
      if (IsClose(sign_preserving_ftz_expected, actual) ||
          (relaxed_denormal_signs_ &&
           IsClose(sign_nonpreserving_ftz_expected, actual))) {
        continue;
      }

      if (relaxed_denormal_signs_) {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat(
              "Mismatch on denormal value %s.  Expected one of:\n"
              "  %10s (evaluated at full-precision value)\n"
              "  %10s (evaluated after flushing to sign-preserving zero)\n"
              "  %10s (evaluated after flushing to non-sign-preserving "
              "zero)\n"
              "but got %s.",
              StringifyNum(input), StringifyNum(expected),
              StringifyNum(sign_preserving_ftz_expected),
              StringifyNum(sign_nonpreserving_ftz_expected),
              StringifyNum(actual));
        });
      } else {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat(
              "Mismatch on denormal value %s.  Expected one of:\n"
              "  %10s (evaluated at full-precision value)\n"
              "  %10s (evaluated after flushing to sign-preserving zero)\n"
              "but got %s.",
              StringifyNum(input), StringifyNum(expected),
              StringifyNum(sign_preserving_ftz_expected), StringifyNum(actual));
        });
      }
    }
    EXPECT_EQ(mismatches, 0);
  }

  template <typename ErrorGenerator>
  void PrintMismatch(int64* mismatches, const ErrorGenerator& err_generator) {
    // We send a few mismatches to gunit so they show up nicely in test logs.
    // Then we send more to LOG(ERROR).  The remainder we squelch unless we're
    // at vlog level 2.
    constexpr int64 kMaxMismatchesLoggedToGunit = 10;
    constexpr int64 kMaxMismatchesLoggedToErr = 1000;

    (*mismatches)++;
    if (*mismatches < kMaxMismatchesLoggedToGunit) {
      FAIL() << err_generator();
    } else if (*mismatches < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) {
      LOG(ERROR) << err_generator();
    } else if (*mismatches == kMaxMismatchesLoggedToErr) {
      LOG(ERROR) << "Not printing any more mismatches; pass "
                    "--vmodule=exhaustive_op_test=2 to see "
                    "all of them.";
    }
  }

  // The following members are set during construction so test cases can read
  // these values and use them, e.g., to influence the values given to the
  // mutable members below.

  // The primitive type under test.
  const PrimitiveType ty_;

  // The platform under test.
  const string platform_;

  // Tests can set the following variables for control over execution.  This is
  // safe because each XLA_TEST_P instantiates a new instance of this class.

  // Testing will ignore the given range (encoded as bitwise representations of
  // the type under test zero-extended to int64).
  int64 known_incorrect_begin_ = 0;
  int64 known_incorrect_end_ = 0;

  // If unset, reasonable defaults will be used depending on the type under
  // test.
  absl::optional<float> abs_err_;
  absl::optional<float> rel_err_;

  // If true, will consider -0 not near to +0 and vice versa.  Note that
  // +epsilon may still be considered close to -0, depending on the error spec;
  // this only covers the case when both `expected` and `actual` are equal to 0.
  bool strict_signed_zeros_ = false;

  // If true, allows denormals to be flushed to non-sign-preserving 0.
  //
  // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt of
  // a negative number) or -inf (flush the denormal to sign-preserving zero,
  // then sqrt(-0)).  But with this as true, we'll also accept 0 (sqrt(0)).
  //
  // XLA:GPU preserves denormal signs, but other backends don't.
  bool relaxed_denormal_signs_ = platform_ != "CUDA";
};

XLA_TEST_P(ExhaustiveOpTest, Log) {
  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
    abs_err_ = 0.001;
    rel_err_ = 0.001;
  }

  Run(Log, std::log);
}

XLA_TEST_P(ExhaustiveOpTest, Log1p) {
  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
    abs_err_ = 0.001;
    rel_err_ = 0.001;
  }

  Run(Log1p, std::log1p);
}

XLA_TEST_P(ExhaustiveOpTest, Exp) {
  if (platform_ == "Host" && ty_ == F32) {
    // TODO(b/73142289): The vectorized Exp implementation gives results outside
    // our error spec in this range.
    known_incorrect_begin_ = 1107296256 + 11583654;
    known_incorrect_end_ = 1107296256 + 11629080;
  } else if (platform_ == "Host" && ty_ == BF16) {
    // TODO(jlebar): Is this a rounding error?  Why doesn't it occur on XLA:GPU?
    //
    // Mismatch on 88.5 (0x42b1).
    //   Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80).
    known_incorrect_begin_ = 0x42b1;
    known_incorrect_end_ = 0x42b2;
  }

  Run(Exp, std::exp);
}

XLA_TEST_P(ExhaustiveOpTest, Expm1) {
  // Expm1 has the same erroneous behavior on CPU as Exp.
  if (platform_ == "Host" && ty_ == F32) {
    // TODO(b/73142289): The vectorized Exp implementation gives results outside
    // our error spec in this range.
    known_incorrect_begin_ = 1107296256 + 11583654;
    known_incorrect_end_ = 1107296256 + 11629080;
  } else if (platform_ == "Host" && ty_ == BF16) {
    // TODO(jlebar): Is this a rounding error?  Why doesn't it occur on XLA:GPU?
    //
    // Mismatch on 88.5 (0x42b1).
    //   Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80).
    known_incorrect_begin_ = 0x42b1;
    known_incorrect_end_ = 0x42b2;
  }

  Run(Expm1, std::expm1);
}

// It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but
// this *did* find a bug, namely that some backends were assuming sqrt(x) ==
// pow(x, 0.5), but this is not true for x == -inf.
XLA_TEST_P(ExhaustiveOpTest, PowOneHalf) {
  Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
      +[](float x) { return std::pow(x, 0.5f); });
}

XLA_TEST_P(ExhaustiveOpTest, Rsqrt) {
  Run(Rsqrt, +[](float x) { return 1 / std::sqrt(x); });
}

XLA_TEST_P(ExhaustiveOpTest, Sqrt) {
  if (platform_ == "Host" || platform_ == "CUDA") {
    strict_signed_zeros_ = true;
  }

  Run(Sqrt, std::sqrt);
}

// TODO(jlebar): Add remaining trig functions.  Don't forget Atan2!
// TODO(jlebar): Test trig functions over complex inputs.
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }

XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); }
XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); }
XLA_TEST_P(ExhaustiveOpTest, Digamma) {
  if (platform_ != "Host" && platform_ != "CUDA") {
    // TODO(b/123956399): This is a fairly high error, significantly higher than
    // we see on CPU/GPU.
    rel_err_ = 0.01;
    abs_err_ = 0.01;
  }

  if (platform_ == "CUDA") {
    // On GPU we get a wrong answer for the denormal inputs +/-2.93873588e-39
    // (0x00200000 and 0x80200000).  These should return -/+inf (at least
    // according to our reference implementation!) but XLA:GPU returns
    // -/+3.40282326e+38 (0xff7ffffe and 0x7f7ffffe).
    //
    // I deem this an acceptable result, as XLA:GPU flushes denormals, and as
    // the results we get here are very close to MAX_FLOAT.  We just hardcode
    // these results, as this is better than ignoring these inputs altogether.
    auto host_digamma_with_gpu_ftz_errors = +[](float x) {
      if (absl::bit_cast<uint32>(x) == 0x00200000 ||
          absl::bit_cast<uint32>(x) == 0x80200000) {
        return std::copysign(std::numeric_limits<float>::max(), -x);
      }
      return HostDigamma(x);
    };
    Run(Digamma, host_digamma_with_gpu_ftz_errors);
  } else {
    Run(Digamma, HostDigamma);
  }
}

XLA_TEST_P(ExhaustiveOpTest, Lgamma) {
  // Our implementation gets within 0.0001 rel error except for ~20 denormal
  // inputs on GPU.  Anyway 0.001 rel error should be good enough for lgamma.
  if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) {
    rel_err_ = 0.001;
  }
  if (platform_ != "Host" && platform_ != "CUDA") {
    // TODO(b/123956399): This is a fairly high error, significantly higher than
    // we see on CPU/GPU.
    rel_err_ = 0.01;
    abs_err_ = 0.01;

    // Overflows to inf for input 4.08500343e+36 (0x7c44af8e).
    if (ty_ == F32) {
      known_incorrect_begin_ = 0x7c44af8e;
      known_incorrect_end_ = 0x7c44af8e + 1;
    }
  }
  Run(Lgamma, std::lgamma);
}

XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); }

std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
  // We break up the 2^32-element space into smallish chunks to keep peak
  // memory usage low.
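  // With step = 2^25, that's 2^32 / 2^25 = 128 chunks of 2^25 = 33,554,432
  // values each.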
  std::vector<std::pair<int64, int64>> result;
  const int64 step = 1 << 25;
  for (int64 i = 0; i < (int64{1} << 32); i += step) {
    result.push_back({i, i + step});
  }
  return result;
}

    625 
    626 INSTANTIATE_TEST_SUITE_P(
    627     F32, ExhaustiveOpTest,
    628     ::testing::Combine(::testing::Values(F32),
    629                        ::testing::ValuesIn(CreateExhaustiveF32Ranges())));
    630 
    631 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
    632 INSTANTIATE_TEST_SUITE_P(
    633     F16, ExhaustiveOpTest,
    634     ::testing::Combine(::testing::Values(F16),
    635                        ::testing::Values(std::make_pair(0, 1 << 16))));
    636 #endif
    637 
    638 #if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
    639 INSTANTIATE_TEST_SUITE_P(
    640     BF16, ExhaustiveOpTest,
    641     ::testing::Combine(::testing::Values(BF16),
    642                        ::testing::Values(std::make_pair(0, 1 << 16))));
    643 #endif
    644 
    645 }  // namespace
    646 }  // namespace xla
    647