Home | History | Annotate | Download | only in xla
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/literal_comparison.h"
     17 
     18 #include <unistd.h>
     19 #include <cmath>
     20 #include <vector>
     21 
     22 #include "absl/base/casts.h"
     23 #include "absl/strings/str_cat.h"
     24 #include "absl/strings/str_format.h"
     25 #include "tensorflow/compiler/xla/literal_util.h"
     26 #include "tensorflow/compiler/xla/util.h"
     27 #include "tensorflow/core/platform/env.h"
     28 
     29 using absl::StrAppend;
     30 using absl::StrAppendFormat;
     31 using absl::StrCat;
     32 
     33 namespace xla {
     34 namespace literal_comparison {
     35 namespace {
     36 
     37 // Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
     38 // able to transparently access the raw 16-bit value contained within.
     39 template <typename T>
     40 T GetRawValue(T val) {
     41   return val;
     42 }
     43 uint16 GetRawValue(Eigen::half val) { return val.x; }
     44 
     45 // Helper function for comparing a floating point type, FloatT, bitwise equal
     46 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
     47 // -- on miscompare, a nice error message is given in the AssertionFailure.
     48 template <typename FloatT, typename UnsignedT>
     49 bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
     50                                absl::Span<const int64> multi_index) {
     51   auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
     52   auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
     53   return ulhs == urhs;
     54 }
     55 
     56 // Templated comparator that specializes for float equality comparison with the
     57 // bitwise helper above (this is the un-specialized fallback, to just use the
     58 // default gunit implementation).
     59 template <typename NativeT>
     60 bool CompareEqual(NativeT lhs, NativeT rhs,
     61                   absl::Span<const int64> multi_index) {
     62   return lhs == rhs;
     63 }
     64 
     65 // Specializations for floating types that do bitwise comparisons when equality
     66 // comparison is requested.
     67 template <>
     68 bool CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
     69                             absl::Span<const int64> multi_index) {
     70   return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
     71 }
     72 template <>
     73 bool CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
     74                                absl::Span<const int64> multi_index) {
     75   return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
     76 }
     77 template <>
     78 bool CompareEqual<float>(float lhs, float rhs,
     79                          absl::Span<const int64> multi_index) {
     80   return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
     81 }
     82 template <>
     83 bool CompareEqual<double>(double lhs, double rhs,
     84                           absl::Span<const int64> multi_index) {
     85   return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
     86 }
     87 template <>
     88 bool CompareEqual<complex64>(complex64 lhs, complex64 rhs,
     89                              absl::Span<const int64> multi_index) {
     90   return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) &&
     91          CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
     92 }
     93 template <>
     94 bool CompareEqual<complex128>(complex128 lhs, complex128 rhs,
     95                               absl::Span<const int64> multi_index) {
     96   return CompareEqual<double>(lhs.real(), rhs.real(), multi_index) &&
     97          CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index);
     98 }
     99 
    100 template <typename NativeT, typename UnsignedT>
    101 Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs,
    102                               absl::Span<const int64> multi_index) {
    103   auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
    104   auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
    105   auto lhs_double = static_cast<double>(lhs);
    106   auto rhs_double = static_cast<double>(rhs);
    107     return InvalidArgument(
    108         "floating values are not bitwise-equal; and equality testing "
    109         "was requested: %s=%g=%a vs %s=%g=%a at array index %s",
    110         StrCat(absl::Hex(ulhs)), lhs_double, lhs_double,
    111         StrCat(absl::Hex(urhs)), rhs_double, rhs_double,
    112         LiteralUtil::MultiIndexAsString(multi_index));
    113 }
    114 
    115 template <typename NativeT>
    116 Status MakeErrorStatus(NativeT lhs, NativeT rhs,
    117                        absl::Span<const int64> multi_index) {
    118   return InvalidArgument(
    119       "first mismatch at array index %s:\n  expected value: %s\n  actual "
    120       "value:   %s",
    121       LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs));
    122 }
    123 
    124 template <>
    125 Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs,
    126                        absl::Span<const int64> multi_index) {
    127   return MakeBitwiseErrorStatus<bfloat16, uint16>(lhs, rhs, multi_index);
    128 }
    129 template <>
    130 Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs,
    131                        absl::Span<const int64> multi_index) {
    132   return MakeBitwiseErrorStatus<Eigen::half, uint16>(lhs, rhs, multi_index);
    133 }
    134 template <>
    135 Status MakeErrorStatus(float lhs, float rhs,
    136                        absl::Span<const int64> multi_index) {
    137   return MakeBitwiseErrorStatus<float, uint32>(lhs, rhs, multi_index);
    138 }
    139 template <>
    140 Status MakeErrorStatus(double lhs, double rhs,
    141                        absl::Span<const int64> multi_index) {
    142   return MakeBitwiseErrorStatus<double, uint64>(lhs, rhs, multi_index);
    143 }
    144 template <>
    145 Status MakeErrorStatus(complex64 lhs, complex64 rhs,
    146                        absl::Span<const int64> multi_index) {
    147   if (!CompareEqual<float>(lhs.real(), rhs.real(), multi_index)) {
    148     return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
    149   }
    150   return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
    151 }
    152 template <>
    153 Status MakeErrorStatus(complex128 lhs, complex128 rhs,
    154                        absl::Span<const int64> multi_index) {
    155   if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) {
    156     return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
    157   }
    158   return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
    159 }
    160 
    161 // A recursive function which iterates through every index of expected and
    162 // actual literal and compares their values elementwise. Returns true if all
    163 // elements are equal.
    164 template <typename NativeT>
    165 Status Equal(LiteralSlice expected, LiteralSlice actual,
    166              absl::Span<int64> multi_index, int64 dimension) {
    167   if (dimension == expected.shape().dimensions_size()) {
    168     NativeT expected_value = expected.Get<NativeT>(multi_index);
    169     NativeT actual_value = actual.Get<NativeT>(multi_index);
    170     bool result =
    171         CompareEqual<NativeT>(expected_value, actual_value, multi_index);
    172     return result ? Status::OK()
    173                   : MakeErrorStatus<NativeT>(expected_value, actual_value,
    174                                              multi_index);
    175   }
    176 
    177   Status result;
    178   for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
    179     multi_index[dimension] = i;
    180     TF_RETURN_IF_ERROR(
    181         Equal<NativeT>(expected, actual, multi_index, dimension + 1));
    182   }
    183   return result;
    184 }
    185 
    186 // Gets the total element count.  For tuples, this is not the count of tuple
    187 // elements, but the sum of elements of each tuple element.
    188 int64 RecursiveElementCount(const Shape& shape) {
    189   if (shape.IsTuple()) {
    190     const int64 tuple_elements = ShapeUtil::TupleElementCount(shape);
    191     int64 total = 0;
    192     for (int64 i = 0; i < tuple_elements; ++i) {
    193       total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
    194     }
    195     return total;
    196   } else if (shape.IsArray()) {
    197     return ShapeUtil::ElementsIn(shape);
    198   } else {
    199     return 0;
    200   }
    201 }
    202 
    203 // Returns whether the given value is infinity.
    204 template <typename NativeT>
    205 bool IsInf(NativeT val) {
    206   return std::isinf(val);
    207 }
    208 
    209 template <>
    210 bool IsInf<half>(half val) {
    211   return std::isinf(static_cast<float>(val));
    212 }
    213 
    214 // Returns whether the given value is nan.
    215 template <typename NativeT>
    216 float IsNan(NativeT value) {
    217   return std::isnan(value);
    218 }
    219 
    220 template <>
    221 float IsNan(half value) {
    222   return IsNan<float>(static_cast<float>(value));
    223 }
    224 
    225 // Converts the given floating-point value to a string.
    226 template <typename NativeT>
    227 string FpValueToString(NativeT value) {
    228   return absl::StrFormat("%8.4g", static_cast<double>(value));
    229 }
    230 
    231 template <>
    232 string FpValueToString<complex64>(complex64 value) {
    233   return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag());
    234 }
    235 
    236 template <>
    237 string FpValueToString<complex128>(complex128 value) {
    238   return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag());
    239 }
    240 
    241 // Returns the absolute value of the given floating point value. This function
    242 // is used instead of std::abs directly in order to allow type-dependent
    243 // implementations for NearComparator.
    244 template <typename NativeT>
    245 float FpAbsoluteValue(NativeT value) {
    246   return std::abs(value);
    247 }
    248 
    249 template <>
    250 float FpAbsoluteValue(bfloat16 value) {
    251   return FpAbsoluteValue<float>(static_cast<float>(value));
    252 }
    253 
    254 template <>
    255 float FpAbsoluteValue(half value) {
    256   return FpAbsoluteValue<float>(static_cast<float>(value));
    257 }
    258 
    259 // Helper class for comparing floating-point literals within an error bound.
    260 template <typename NativeT>
    261 class NearComparator {
    262  public:
    263   // Compares the two array literals elementwise and returns a comparison
    264   // result. The comparison is ok() if all actual and expected elements are
    265   // within the given error bound. In case of error, the status contains a
    266   // detailed message about the discrepancy.
    267   static Status Compare(const LiteralSlice& expected,
    268                         const LiteralSlice& actual, ErrorSpec error,
    269                         bool detailed_message,
    270                         const MiscompareCallback& miscompare_callback) {
    271     NearComparator<NativeT> comparator(expected, actual, error,
    272                                        detailed_message, miscompare_callback);
    273     return comparator.Run();
    274   }
    275 
    276  private:
    277   // Data structure encapsulating metadata about a single element mismatch.
    278   struct Mismatch {
    279     NativeT actual;
    280     NativeT expected;
    281     float rel_error;
    282     float abs_error;
    283 
    284     // The linear index of the failure within the shape. This linear index is
    285     // from the 'actual' literal.
    286     int64 linear_index;
    287 
    288     bool operator<(const Mismatch& other) const {
    289       return rel_error < other.rel_error;
    290     }
    291 
    292     string ToString(const Shape& shape) const {
    293       return absl::StrFormat(
    294           "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
    295           FpValueToString(actual), FpValueToString(expected),
    296           LiteralUtil::MultiIndexAsString(
    297               IndexUtil::LinearIndexToMultidimensionalIndex(shape,
    298                                                             linear_index)),
    299           rel_error, abs_error);
    300     }
    301   };
    302 
    303   NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
    304                  ErrorSpec error, bool detailed_message,
    305                  const MiscompareCallback& miscompare_callback)
    306       : expected_(expected),
    307         actual_(actual),
    308         error_(error),
    309         detailed_message_(detailed_message),
    310         miscompare_callback_(miscompare_callback),
    311         abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}),
    312         abs_error_buckets_(kErrorBucketBounds.size(), 0),
    313         rel_error_buckets_(kErrorBucketBounds.size(), 0) {}
    314 
    315   // Runs the comparison between expected and actual literals.
    316   Status Run() {
    317     // If the shapes mismatch, we simply fail the expectation instead of
    318     // printing out data, as it's a type error rather than a value error.
    319     TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
    320     if (!expected_.shape().IsArray()) {
    321       return InvalidArgument("Expected array shape; got %s.",
    322                              ShapeUtil::HumanString(expected_.shape()));
    323     }
    324 
    325     mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
    326     mismatches_.PopulateWithValue(false);
    327 
    328     CompareLiterals();
    329 
    330     if (num_mismatches_ == 0) {
    331       return Status::OK();
    332     } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
    333       miscompare_callback_(expected_, actual_, mismatches_);
    334     }
    335     return InvalidArgument("%s", ErrorMessage());
    336   }
    337 
    338   // Insert the given absolute value into the absolute value bucket vector. The
    339   // bounds of the buckets are given by kAbsValueBucketBounds.
    340   void UpdateAbsValueBucket(NativeT value, bool is_mismatch) {
    341     // Adjust the bucket containing the absolute values of the 'actual'
    342     // elements.
    343     const float abs_value = FpAbsoluteValue(value);
    344     for (int i = 0; i < abs_value_buckets_.size(); ++i) {
    345       if (i == abs_value_buckets_.size() - 1 ||
    346           (abs_value >= kAbsValueBucketBounds[i] &&
    347            abs_value < kAbsValueBucketBounds[i + 1])) {
    348         // The first value of the pair is the count of elements in the bucket,
    349         // the second is the count of mismatches in the bucket.
    350         abs_value_buckets_[i].first++;
    351         if (is_mismatch) {
    352           abs_value_buckets_[i].second++;
    353         }
    354         return;
    355       }
    356     }
    357   }
    358 
    359   // Insert the given error into the given error bucket vector.
    360   void UpdateErrorBucket(float error, absl::Span<int64> error_buckets) {
    361     CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
    362     for (int i = 0; i < error_buckets.size(); ++i) {
    363       if (error >= kErrorBucketBounds[i]) {
    364         error_buckets[i]++;
    365       }
    366     }
    367   }
    368 
    369   // Compares the two given elements from the expected and actual literals at
    370   // the given literal_index and keeps track of various mismatch statistics.
    371   template <typename T>
    372   void CompareValues(T expected, T actual, int64 linear_index) {
    373     float abs_error;
    374     float rel_error;
    375     if (CompareEqual<T>(expected, actual, {linear_index})) {
    376       abs_error = 0;
    377       rel_error = 0;
    378     } else if (IsNan(expected) || IsNan(actual)) {
    379       if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) ||
    380           (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) {
    381         num_nan_mismatches_++;
    382         // A nan mismatch is considered to have infinite error. rel_error is
    383         // used for sorting a std::set of the top mismatchs, and a nan value
    384         // here will result in undefined behavior because nan's do not satisfy
    385         // the strict weak ordering requirement of std containers.
    386         abs_error = std::numeric_limits<float>::infinity();
    387         rel_error = std::numeric_limits<float>::infinity();
    388       } else {
    389         abs_error = 0;
    390         rel_error = 0;
    391       }
    392     } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) {
    393       // `fewer_infs_ok` gives us the option of comparing as though `actual`
    394       // were float_max/min rather than inf.
    395       T actual_finite = actual > T{0} ? std::numeric_limits<T>::max()
    396                                       : std::numeric_limits<T>::lowest();
    397       abs_error = FpAbsoluteValue(actual_finite - expected);
    398 
    399       // Avoid division by 0 even though it's well-defined because ubsan can be
    400       // configured to treat this as a fatal error.
    401       if (expected != T{0}) {
    402         rel_error = abs_error / FpAbsoluteValue(expected);
    403       } else {
    404         rel_error = std::numeric_limits<float>::infinity();
    405       }
    406     } else if (IsInf(expected) || IsInf(actual)) {
    407       // If either the expected or actual value is infinity but not both,
    408       // then both absolute and relative error are regarded as inifity.
    409       CHECK(!CompareEqual(expected, actual, {linear_index}));
    410       abs_error = std::numeric_limits<float>::infinity();
    411       rel_error = std::numeric_limits<float>::infinity();
    412     } else {
    413       abs_error = FpAbsoluteValue(actual - expected);
    414 
    415       // Avoid division by 0 even though it's well-defined because ubsan can be
    416       // configured to treat this as a fatal error.
    417       if (expected != T{0}) {
    418         rel_error = abs_error / FpAbsoluteValue(expected);
    419       } else {
    420         rel_error = std::numeric_limits<float>::infinity();
    421       }
    422     }
    423     const bool is_abs_mismatch = abs_error > error_.abs;
    424     const bool is_rel_mismatch = rel_error > error_.rel;
    425     const bool is_mismatch = is_abs_mismatch && is_rel_mismatch;
    426 
    427     // Update the error of the relative bucket only if the *absolute* error
    428     // bound is exceeded and vice versa.
    429     if (is_abs_mismatch) {
    430       num_abs_mismatches_++;
    431       UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
    432     }
    433     if (is_rel_mismatch) {
    434       num_rel_mismatches_++;
    435       UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
    436     }
    437 
    438     UpdateAbsValueBucket(actual, is_mismatch);
    439 
    440     if (!is_mismatch) {
    441       return;
    442     }
    443 
    444     num_mismatches_++;
    445 
    446     // Keep track of the kTopRelativeErrorCount relative error mismatches.
    447     if (top_rel_mismatches_.size() < kTopRelativeErrorCount ||
    448         rel_error > top_rel_mismatches_.begin()->rel_error) {
    449       Mismatch mismatch = {actual, expected, rel_error, abs_error,
    450                            linear_index};
    451       top_rel_mismatches_.insert(mismatch);
    452       if (top_rel_mismatches_.size() > kTopRelativeErrorCount) {
    453         top_rel_mismatches_.erase(top_rel_mismatches_.begin());
    454       }
    455     }
    456 
    457     mismatches_.data<bool>()[linear_index] = true;
    458   }
    459 
    460   // For complex types, we compare real and imaginary parts individually.
    461   void CompareValues(complex64 expected, complex64 actual, int64 linear_index) {
    462     bool mismatch = false;
    463     CompareValues<float>(expected.real(), actual.real(), linear_index);
    464     if (mismatches_.data<bool>()[linear_index] == true) {
    465       mismatch = true;
    466       // Delay the mismatch count increase for real part, instead increase
    467       // mismatch by 1 for the entire complex number.
    468       num_mismatches_--;
    469     }
    470     CompareValues<float>(expected.imag(), actual.imag(), linear_index);
    471     if (mismatches_.data<bool>()[linear_index] == true) {
    472       mismatch = true;
    473       // Delay the mismatch count increase for imag part, instead increase
    474       // mismatch by 1 for the entire complex number.
    475       num_mismatches_--;
    476     }
    477     if (mismatch == true) {
    478       num_mismatches_++;
    479     }
    480     mismatches_.data<bool>()[linear_index] = mismatch;
    481   }
    482 
    483   void CompareValues(complex128 expected, complex128 actual,
    484                      int64 linear_index) {
    485     bool mismatch = false;
    486     CompareValues<double>(expected.real(), actual.real(), linear_index);
    487     if (mismatches_.data<bool>()[linear_index] == true) {
    488       mismatch = true;
    489       // Delay the mismatch count increase for real part, instead increase
    490       // mismatch by 1 for the entire complex number.
    491       num_mismatches_--;
    492     }
    493     CompareValues<double>(expected.imag(), actual.imag(), linear_index);
    494     if (mismatches_.data<bool>()[linear_index] == true) {
    495       mismatch = true;
    496       // Delay the mismatch count increase for imag part, instead increase
    497       // mismatch by 1 for the entire complex number.
    498       num_mismatches_--;
    499     }
    500     if (mismatch == true) {
    501       num_mismatches_++;
    502     }
    503     mismatches_.data<bool>()[linear_index] = mismatch;
    504   }
    505 
    506   // Compares the two literals elementwise.
    507   void CompareLiterals() {
    508     // Fast path optimization for the case were layouts match.
    509     if (LayoutUtil::Equal(actual_.shape().layout(),
    510                           expected_.shape().layout())) {
    511       absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
    512       absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
    513       const int64 len = expected_data.size();
    514       for (int64 i = 0; i < len; ++i) {
    515         CompareValues(expected_data[i], actual_data[i], i);
    516       }
    517       return;
    518     }
    519     std::vector<int64> multi_index(actual_.shape().rank(), 0);
    520     CompareLiteralsSlow(0, &multi_index);
    521   }
    522 
    523   // Slow path for CompareLiterals when 'actual' and 'expected' literals have
    524   // different layouts. In this case, multidimensional indices are constructed
    525   // and indexed for each element.
    526   void CompareLiteralsSlow(int64 dimension, std::vector<int64>* multi_index) {
    527     if (dimension == multi_index->size()) {
    528       CompareValues(expected_.Get<NativeT>(*multi_index),
    529                     actual_.Get<NativeT>(*multi_index),
    530                     IndexUtil::MultidimensionalIndexToLinearIndex(
    531                         actual_.shape(), *multi_index));
    532     } else {
    533       for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) {
    534         (*multi_index)[dimension] = i;
    535         CompareLiteralsSlow(dimension + 1, multi_index);
    536       }
    537     }
    538   }
    539 
    540   // Returns an error message string with a detailed breakdown of the
    541   // mismatches. Called after calling Run().
    542   string ErrorMessage() {
    543     string out;
    544     int64 element_count = ShapeUtil::ElementsIn(actual_.shape());
    545 
    546     auto percent_string = [](float a, float b) {
    547       float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
    548       return absl::StrFormat("%0.4f%%", pct);
    549     };
    550 
    551     StrAppendFormat(
    552         &out,
    553         "\nMismatch count %d (%s) in shape %s (%d elements), abs bound "
    554         "%g, rel bound %g\n",
    555         num_mismatches_, percent_string(num_mismatches_, element_count),
    556         ShapeUtil::HumanString(actual_.shape()),
    557         ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
    558     if (num_nan_mismatches_ > 0) {
    559       StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
    560     }
    561     StrAppendFormat(&out, "Top relative error mismatches:\n");
    562     for (auto it = top_rel_mismatches_.rbegin();
    563          it != top_rel_mismatches_.rend(); ++it) {
    564       StrAppend(&out, "  ", it->ToString(actual_.shape()), "\n");
    565     }
    566 
    567     if (!detailed_message_) {
    568       return out;
    569     }
    570 
    571     StrAppend(&out, "Absolute magnitude breakdown of actual values:\n");
    572     CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size());
    573     for (int i = 0; i < abs_value_buckets_.size(); ++i) {
    574       const int64 bucket_size = abs_value_buckets_[i].first;
    575       const int64 bucket_mismatches = abs_value_buckets_[i].second;
    576       string mismatch_str =
    577           bucket_mismatches > 0
    578               ? absl::StrFormat(", mismatches %d", bucket_mismatches)
    579               : "";
    580       StrAppendFormat(&out, "  %-6g <= x < %-6g : %7d (%9s)%s\n",
    581                       kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
    582                       bucket_size, percent_string(bucket_size, element_count),
    583                       mismatch_str);
    584     }
    585 
    586     auto print_accum_buckets = [&](const string& header, int64 total,
    587                                    absl::Span<const int64> buckets) {
    588       StrAppend(&out, header, ":\n");
    589       StrAppendFormat(&out, "  <  %-6g : %7d (%s)\n", kErrorBucketBounds[0],
    590                       total - buckets[0],
    591                       percent_string(total - buckets[0], total));
    592       CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
    593       for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
    594         StrAppendFormat(&out, "  >= %-6g : %7d (%s)\n", kErrorBucketBounds[i],
    595                         buckets[i], percent_string(buckets[i], total));
    596       }
    597     };
    598     StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n",
    599                     error_.abs, num_abs_mismatches_,
    600                     percent_string(num_abs_mismatches_, element_count));
    601     print_accum_buckets(
    602         "Relative error breakdown of elements exceeding abs error bound",
    603         num_abs_mismatches_, rel_error_buckets_);
    604     StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n",
    605                     error_.rel, num_rel_mismatches_,
    606                     percent_string(num_rel_mismatches_, element_count));
    607     print_accum_buckets(
    608         "Absolute error breakdown of elements exceeding rel error bound",
    609         num_rel_mismatches_, abs_error_buckets_);
    610     return out;
    611   }
    612 
    613   // 'actual' and 'expected' literals being compared.
    614   LiteralSlice expected_;
    615   LiteralSlice actual_;
    616 
    617   // The error bounds of the comparison.
    618   ErrorSpec error_;
    619 
    620   // Whether to include detailed breakdown of mismatches in the error message.
    621   bool detailed_message_;
    622 
    623   // Callback to invoke on miscompare.
    624   MiscompareCallback miscompare_callback_;
    625 
    626   // Number of element element mismatches encountered so far.
    627   int64 num_mismatches_ = 0;
    628 
    629   // Number of elements with a nan mismatch.
    630   int64 num_nan_mismatches_ = 0;
    631 
    632   // Number of elements which exceed the absolute/relative error bound.
    633   int64 num_abs_mismatches_ = 0;
    634   int64 num_rel_mismatches_ = 0;
    635 
    636   // A Literal containing which elements did not match in the expected and
    637   // actual literals. mismatches_ contains PREDs and is of the same sizes as
    638   // the comparison literals.
    639   Literal mismatches_;
    640 
    641   // The number of mismatches to report in the output, sorted by relative error
    642   // magnitude.
    643   static constexpr int64 kTopRelativeErrorCount = 5;
    644 
    645   // The set of mismatches with the largest relative error. The size of this set
    646   // is bounded by kTopRelativeErrorCount.
    647   std::multiset<Mismatch> top_rel_mismatches_;
    648 
    649   // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the
    650   // bounds of these buckets. abs_value_buckets_ contains a pair for each
    651   // bucket: the element count and failure count.
    652   static constexpr std::array<float, 7> kAbsValueBucketBounds = {
    653       0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()};
    654   std::vector<std::pair<int64, int64>> abs_value_buckets_;
    655 
    656   // Buckets for relative and absolute errors. The relative error buckets only
    657   // contains those elements which exceed the *absolute* error bound, and vice
    658   // versa. This makes it easy to see the effect of adjusting the relative (or
    659   // absolute) error bound on the success of the comparison. kErrorBucketBounds
    660   // are the lower bounds of the buckets in both vectors. The error buckets are
    661   // a cumulative distribution so an error value may appear in more than one
    662   // bucket. For example an error value of 0.003 may appear in the buckets
    663   // bounded by 0.01, 0.1, and 1.0.
    664   static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001,
    665                                                               0.01, 0.1, 1};
    666   std::vector<int64> abs_error_buckets_;
    667   std::vector<int64> rel_error_buckets_;
    668 };
    669 
    670 template <typename NativeT>
    671 constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
    672 template <typename NativeT>
    673 constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
    674 
    675 Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
    676   TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
    677   std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
    678   auto index = absl::MakeSpan(multi_index);
    679   Status result;
    680   switch (expected.shape().element_type()) {
    681     case PRED:
    682       result = Equal<bool>(expected, actual, index, 0);
    683       break;
    684     case U8:
    685       result = Equal<uint8>(expected, actual, index, 0);
    686       break;
    687     case S32:
    688       result = Equal<int32>(expected, actual, index, 0);
    689       break;
    690     case S64:
    691       result = Equal<int64>(expected, actual, index, 0);
    692       break;
    693     case U32:
    694       result = Equal<uint32>(expected, actual, index, 0);
    695       break;
    696     case U64:
    697       result = Equal<uint64>(expected, actual, index, 0);
    698       break;
    699     case BF16:
    700       result = Equal<bfloat16>(expected, actual, index, 0);
    701       break;
    702     case F16:
    703       result = Equal<half>(expected, actual, index, 0);
    704       break;
    705     case F32:
    706       result = Equal<float>(expected, actual, index, 0);
    707       break;
    708     case F64:
    709       result = Equal<double>(expected, actual, index, 0);
    710       break;
    711     case C64:
    712       result = Equal<complex64>(expected, actual, index, 0);
    713       break;
    714     case C128:
    715       result = Equal<complex128>(expected, actual, index, 0);
    716       break;
    717     case TUPLE: {
    718       for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
    719         result.Update(EqualHelper(LiteralSlice(expected, {i}),
    720                                   LiteralSlice(actual, {i})));
    721       }
    722       break;
    723     }
    724     case TOKEN:
    725       // Tokens have no on-device representation and are trivially equal.
    726       return Status::OK();
    727     default:
    728       LOG(FATAL) << "Unsupported primitive type: "
    729                  << PrimitiveType_Name(expected.shape().element_type());
    730   }
    731 
    732   return result;
    733 }
    734 
    735 // Helper function for comparing two literals for nearness. Handles tuple-shapes
    736 // via recursion. shape_index is the ShapeIndex of expected (or actual)
    737 // currently being compared.
    738 Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
    739                   const ErrorSpec& error, absl::optional<bool> detailed_message,
    740                   const MiscompareCallback& miscompare_callback,
    741                   const ShapeIndex& shape_index) {
    742   TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
    743 
    744   if (expected.shape().IsTuple()) {
    745     Status return_status;
    746     for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
    747       const auto expected_element = LiteralSlice(expected, {i});
    748       const auto actual_element = LiteralSlice(actual, {i});
    749       ShapeIndex element_index = shape_index;
    750       element_index.push_back(i);
    751       Status element_result =
    752           NearHelper(expected_element, actual_element, error, detailed_message,
    753                      miscompare_callback, element_index);
    754       if (!element_result.ok()) {
    755         element_result = InvalidArgument("Array at shape index %s, %s",
    756                                          element_index.ToString(),
    757                                          element_result.error_message());
    758         if (return_status.ok()) {
    759           return_status = element_result;
    760         } else {
    761           return_status =
    762               AppendStatus(return_status, element_result.error_message());
    763         }
    764       }
    765     }
    766     if (!return_status.ok() && shape_index.empty()) {
    767       // Emit a top-level error message containing the top-level shape in case
    768       // of mismatch.
    769       int64 total_elements = RecursiveElementCount(actual.shape());
    770       return_status =
    771           InvalidArgument("\nMismatches in shape %s (%d elements):\n%s",
    772                           ShapeUtil::HumanString(actual.shape()),
    773                           total_elements, return_status.error_message());
    774     }
    775     return return_status;
    776   }
    777 
    778   if (ShapeUtil::ElementIsFloating(expected.shape()) ||
    779       ShapeUtil::ElementIsComplex(expected.shape())) {
    780     bool use_detailed_message = detailed_message.value_or(
    781         ShapeUtil::ElementsIn(expected.shape()) >= 64);
    782     switch (expected.shape().element_type()) {
    783       case BF16:
    784         return NearComparator<bfloat16>::Compare(
    785             expected, actual, error, use_detailed_message, miscompare_callback);
    786         break;
    787       case F16:
    788         return NearComparator<half>::Compare(
    789             expected, actual, error, use_detailed_message, miscompare_callback);
    790         break;
    791       case F32:
    792         return NearComparator<float>::Compare(
    793             expected, actual, error, use_detailed_message, miscompare_callback);
    794         break;
    795       case F64:
    796         return NearComparator<double>::Compare(
    797             expected, actual, error, use_detailed_message, miscompare_callback);
    798         break;
    799       case C64:
    800         return NearComparator<complex64>::Compare(
    801             expected, actual, error, use_detailed_message, miscompare_callback);
    802         break;
    803       case C128:
    804         return NearComparator<complex128>::Compare(
    805             expected, actual, error, use_detailed_message, miscompare_callback);
    806         break;
    807       default:
    808         LOG(FATAL) << "Unsupported primitive type in near comparator: "
    809                    << PrimitiveType_Name(expected.shape().element_type())
    810                    << ". Must be floating-point type.";
    811     }
    812   }
    813 
    814   // Non-floating point, non-tuple literal.
    815   return EqualHelper(expected, actual);
    816 }
    817 
    818 }  // namespace
    819 
    820 Status EqualShapes(const Shape& expected, const Shape& actual) {
    821   if (expected.element_type() != actual.element_type()) {
    822     return InvalidArgument("element type mismatch, want: %s got %s",
    823                            ShapeUtil::HumanString(expected),
    824                            ShapeUtil::HumanString(actual));
    825   }
    826   if (expected.IsTuple()) {
    827     if (ShapeUtil::TupleElementCount(expected) !=
    828         ShapeUtil::TupleElementCount(actual)) {
    829       return InvalidArgument(
    830           "want tuple element count: %d got tuple element count: %d",
    831           ShapeUtil::TupleElementCount(expected),
    832           ShapeUtil::TupleElementCount(actual));
    833     }
    834     for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
    835       Status result =
    836           EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
    837       if (!result.ok()) {
    838         return AppendStatus(result, StrCat("mismatch in tuple index", i));
    839       }
    840     }
    841   } else if (expected.IsArray()) {
    842     if (expected.rank() != actual.rank()) {
    843       return InvalidArgument("want rank of %s got rank of %s",
    844                              ShapeUtil::HumanString(expected),
    845                              ShapeUtil::HumanString(actual));
    846     }
    847     if (expected.element_type() != actual.element_type()) {
    848       return InvalidArgument("mismatch in primitive type %s vs %s",
    849                              PrimitiveType_Name(expected.element_type()),
    850                              PrimitiveType_Name(actual.element_type()));
    851     }
    852     if (expected.dimensions_size() != actual.dimensions_size()) {
    853       return InvalidArgument("want dimensions_size %d got dimensions_size %d",
    854                              expected.dimensions_size(),
    855                              actual.dimensions_size());
    856     }
    857     for (int i = 0; i < expected.dimensions_size(); ++i) {
    858       if (expected.dimensions(i) != actual.dimensions(i)) {
    859         return InvalidArgument(
    860             "mismatch in dimension #%d expected: %s actual: %s", i,
    861             ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual));
    862       }
    863     }
    864   }
    865   // Non-array, non-tuple shapes are trivially equivalent.
    866   return Status::OK();
    867 }
    868 
    869 namespace {
    870 
    871 // If result is an error, extend the error message with the expected and actual
    872 // literals.
    873 Status EmitLiteralsInErrorMessage(const Status& result,
    874                                   const LiteralSlice& expected,
    875                                   const LiteralSlice& actual) {
    876   if (result.ok()) {
    877     return result;
    878   }
    879   return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
    880                          result.error_message(), ToStringTruncated(expected),
    881                          ToStringTruncated(actual));
    882 }
    883 
    884 }  // namespace
    885 
    886 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
    887   VLOG(1) << "expected:";
    888   XLA_VLOG_LINES(1, expected.ToString());
    889   VLOG(1) << "actual:";
    890   XLA_VLOG_LINES(1, actual.ToString());
    891   Status result = EqualHelper(expected, actual);
    892   return EmitLiteralsInErrorMessage(result, expected, actual);
    893 }
    894 
    895 Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
    896             const ErrorSpec& error, absl::optional<bool> detailed_message,
    897             const MiscompareCallback& miscompare_callback) {
    898   VLOG(1) << "Expected literal:";
    899   XLA_VLOG_LINES(1, expected.ToString());
    900   VLOG(1) << "Actual literal:";
    901   XLA_VLOG_LINES(1, actual.ToString());
    902   Status result =
    903       NearHelper(expected, actual, error, detailed_message, miscompare_callback,
    904                  /*shape_index=*/{});
    905   return EmitLiteralsInErrorMessage(result, expected, actual);
    906 }
    907 
    908 string ToStringTruncated(const LiteralSlice& literal) {
    909   return RecursiveElementCount(literal.shape()) < 1000
    910              ? literal.ToString()
    911              : "[TRUNCATED, Literal with more than 1000 values]";
    912 }
    913 
    914 }  // namespace literal_comparison
    915 }  // namespace xla
    916