1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/literal_comparison.h" 17 18 #include <unistd.h> 19 #include <cmath> 20 #include <vector> 21 22 #include "absl/base/casts.h" 23 #include "absl/strings/str_cat.h" 24 #include "absl/strings/str_format.h" 25 #include "tensorflow/compiler/xla/literal_util.h" 26 #include "tensorflow/compiler/xla/util.h" 27 #include "tensorflow/core/platform/env.h" 28 29 using absl::StrAppend; 30 using absl::StrAppendFormat; 31 using absl::StrCat; 32 33 namespace xla { 34 namespace literal_comparison { 35 namespace { 36 37 // Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be 38 // able to transparently access the raw 16-bit value contained within. 39 template <typename T> 40 T GetRawValue(T val) { 41 return val; 42 } 43 uint16 GetRawValue(Eigen::half val) { return val.x; } 44 45 // Helper function for comparing a floating point type, FloatT, bitwise equal 46 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT 47 // -- on miscompare, a nice error message is given in the AssertionFailure. 48 template <typename FloatT, typename UnsignedT> 49 bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, 50 absl::Span<const int64> multi_index) { 51 auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs)); 52 auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs)); 53 return ulhs == urhs; 54 } 55 56 // Templated comparator that specializes for float equality comparison with the 57 // bitwise helper above (this is the un-specialized fallback, to just use the 58 // default gunit implementation). 59 template <typename NativeT> 60 bool CompareEqual(NativeT lhs, NativeT rhs, 61 absl::Span<const int64> multi_index) { 62 return lhs == rhs; 63 } 64 65 // Specializations for floating types that do bitwise comparisons when equality 66 // comparison is requested. 67 template <> 68 bool CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs, 69 absl::Span<const int64> multi_index) { 70 return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index); 71 } 72 template <> 73 bool CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs, 74 absl::Span<const int64> multi_index) { 75 return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index); 76 } 77 template <> 78 bool CompareEqual<float>(float lhs, float rhs, 79 absl::Span<const int64> multi_index) { 80 return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index); 81 } 82 template <> 83 bool CompareEqual<double>(double lhs, double rhs, 84 absl::Span<const int64> multi_index) { 85 return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index); 86 } 87 template <> 88 bool CompareEqual<complex64>(complex64 lhs, complex64 rhs, 89 absl::Span<const int64> multi_index) { 90 return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) && 91 CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index); 92 } 93 template <> 94 bool CompareEqual<complex128>(complex128 lhs, complex128 rhs, 95 absl::Span<const int64> multi_index) { 96 return CompareEqual<double>(lhs.real(), rhs.real(), multi_index) && 97 CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index); 98 } 99 100 template <typename NativeT, typename UnsignedT> 101 Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs, 102 absl::Span<const int64> multi_index) { 103 auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs)); 104 auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs)); 105 auto lhs_double = static_cast<double>(lhs); 106 auto rhs_double = static_cast<double>(rhs); 107 return InvalidArgument( 108 "floating values are not bitwise-equal; and equality testing " 109 "was requested: %s=%g=%a vs %s=%g=%a at array index %s", 110 StrCat(absl::Hex(ulhs)), lhs_double, lhs_double, 111 StrCat(absl::Hex(urhs)), rhs_double, rhs_double, 112 LiteralUtil::MultiIndexAsString(multi_index)); 113 } 114 115 template <typename NativeT> 116 Status MakeErrorStatus(NativeT lhs, NativeT rhs, 117 absl::Span<const int64> multi_index) { 118 return InvalidArgument( 119 "first mismatch at array index %s:\n expected value: %s\n actual " 120 "value: %s", 121 LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); 122 } 123 124 template <> 125 Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs, 126 absl::Span<const int64> multi_index) { 127 return MakeBitwiseErrorStatus<bfloat16, uint16>(lhs, rhs, multi_index); 128 } 129 template <> 130 Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs, 131 absl::Span<const int64> multi_index) { 132 return MakeBitwiseErrorStatus<Eigen::half, uint16>(lhs, rhs, multi_index); 133 } 134 template <> 135 Status MakeErrorStatus(float lhs, float rhs, 136 absl::Span<const int64> multi_index) { 137 return MakeBitwiseErrorStatus<float, uint32>(lhs, rhs, multi_index); 138 } 139 template <> 140 Status MakeErrorStatus(double lhs, double rhs, 141 absl::Span<const int64> multi_index) { 142 return MakeBitwiseErrorStatus<double, uint64>(lhs, rhs, multi_index); 143 } 144 template <> 145 Status MakeErrorStatus(complex64 lhs, complex64 rhs, 146 absl::Span<const int64> multi_index) { 147 if (!CompareEqual<float>(lhs.real(), rhs.real(), multi_index)) { 148 return MakeErrorStatus(lhs.real(), rhs.real(), multi_index); 149 } 150 return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); 151 } 152 template <> 153 Status MakeErrorStatus(complex128 lhs, complex128 rhs, 154 absl::Span<const int64> multi_index) { 155 if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) { 156 return MakeErrorStatus(lhs.real(), rhs.real(), multi_index); 157 } 158 return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); 159 } 160 161 // A recursive function which iterates through every index of expected and 162 // actual literal and compares their values elementwise. Returns true if all 163 // elements are equal. 164 template <typename NativeT> 165 Status Equal(LiteralSlice expected, LiteralSlice actual, 166 absl::Span<int64> multi_index, int64 dimension) { 167 if (dimension == expected.shape().dimensions_size()) { 168 NativeT expected_value = expected.Get<NativeT>(multi_index); 169 NativeT actual_value = actual.Get<NativeT>(multi_index); 170 bool result = 171 CompareEqual<NativeT>(expected_value, actual_value, multi_index); 172 return result ? Status::OK() 173 : MakeErrorStatus<NativeT>(expected_value, actual_value, 174 multi_index); 175 } 176 177 Status result; 178 for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { 179 multi_index[dimension] = i; 180 TF_RETURN_IF_ERROR( 181 Equal<NativeT>(expected, actual, multi_index, dimension + 1)); 182 } 183 return result; 184 } 185 186 // Gets the total element count. For tuples, this is not the count of tuple 187 // elements, but the sum of elements of each tuple element. 188 int64 RecursiveElementCount(const Shape& shape) { 189 if (shape.IsTuple()) { 190 const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); 191 int64 total = 0; 192 for (int64 i = 0; i < tuple_elements; ++i) { 193 total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); 194 } 195 return total; 196 } else if (shape.IsArray()) { 197 return ShapeUtil::ElementsIn(shape); 198 } else { 199 return 0; 200 } 201 } 202 203 // Returns whether the given value is infinity. 204 template <typename NativeT> 205 bool IsInf(NativeT val) { 206 return std::isinf(val); 207 } 208 209 template <> 210 bool IsInf<half>(half val) { 211 return std::isinf(static_cast<float>(val)); 212 } 213 214 // Returns whether the given value is nan. 215 template <typename NativeT> 216 float IsNan(NativeT value) { 217 return std::isnan(value); 218 } 219 220 template <> 221 float IsNan(half value) { 222 return IsNan<float>(static_cast<float>(value)); 223 } 224 225 // Converts the given floating-point value to a string. 226 template <typename NativeT> 227 string FpValueToString(NativeT value) { 228 return absl::StrFormat("%8.4g", static_cast<double>(value)); 229 } 230 231 template <> 232 string FpValueToString<complex64>(complex64 value) { 233 return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); 234 } 235 236 template <> 237 string FpValueToString<complex128>(complex128 value) { 238 return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); 239 } 240 241 // Returns the absolute value of the given floating point value. This function 242 // is used instead of std::abs directly in order to allow type-dependent 243 // implementations for NearComparator. 244 template <typename NativeT> 245 float FpAbsoluteValue(NativeT value) { 246 return std::abs(value); 247 } 248 249 template <> 250 float FpAbsoluteValue(bfloat16 value) { 251 return FpAbsoluteValue<float>(static_cast<float>(value)); 252 } 253 254 template <> 255 float FpAbsoluteValue(half value) { 256 return FpAbsoluteValue<float>(static_cast<float>(value)); 257 } 258 259 // Helper class for comparing floating-point literals within an error bound. 260 template <typename NativeT> 261 class NearComparator { 262 public: 263 // Compares the two array literals elementwise and returns a comparison 264 // result. The comparison is ok() if all actual and expected elements are 265 // within the given error bound. In case of error, the status contains a 266 // detailed message about the discrepancy. 267 static Status Compare(const LiteralSlice& expected, 268 const LiteralSlice& actual, ErrorSpec error, 269 bool detailed_message, 270 const MiscompareCallback& miscompare_callback) { 271 NearComparator<NativeT> comparator(expected, actual, error, 272 detailed_message, miscompare_callback); 273 return comparator.Run(); 274 } 275 276 private: 277 // Data structure encapsulating metadata about a single element mismatch. 278 struct Mismatch { 279 NativeT actual; 280 NativeT expected; 281 float rel_error; 282 float abs_error; 283 284 // The linear index of the failure within the shape. This linear index is 285 // from the 'actual' literal. 286 int64 linear_index; 287 288 bool operator<(const Mismatch& other) const { 289 return rel_error < other.rel_error; 290 } 291 292 string ToString(const Shape& shape) const { 293 return absl::StrFormat( 294 "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", 295 FpValueToString(actual), FpValueToString(expected), 296 LiteralUtil::MultiIndexAsString( 297 IndexUtil::LinearIndexToMultidimensionalIndex(shape, 298 linear_index)), 299 rel_error, abs_error); 300 } 301 }; 302 303 NearComparator(const LiteralSlice& expected, const LiteralSlice& actual, 304 ErrorSpec error, bool detailed_message, 305 const MiscompareCallback& miscompare_callback) 306 : expected_(expected), 307 actual_(actual), 308 error_(error), 309 detailed_message_(detailed_message), 310 miscompare_callback_(miscompare_callback), 311 abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}), 312 abs_error_buckets_(kErrorBucketBounds.size(), 0), 313 rel_error_buckets_(kErrorBucketBounds.size(), 0) {} 314 315 // Runs the comparison between expected and actual literals. 316 Status Run() { 317 // If the shapes mismatch, we simply fail the expectation instead of 318 // printing out data, as it's a type error rather than a value error. 319 TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); 320 if (!expected_.shape().IsArray()) { 321 return InvalidArgument("Expected array shape; got %s.", 322 ShapeUtil::HumanString(expected_.shape())); 323 } 324 325 mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); 326 mismatches_.PopulateWithValue(false); 327 328 CompareLiterals(); 329 330 if (num_mismatches_ == 0) { 331 return Status::OK(); 332 } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { 333 miscompare_callback_(expected_, actual_, mismatches_); 334 } 335 return InvalidArgument("%s", ErrorMessage()); 336 } 337 338 // Insert the given absolute value into the absolute value bucket vector. The 339 // bounds of the buckets are given by kAbsValueBucketBounds. 340 void UpdateAbsValueBucket(NativeT value, bool is_mismatch) { 341 // Adjust the bucket containing the absolute values of the 'actual' 342 // elements. 343 const float abs_value = FpAbsoluteValue(value); 344 for (int i = 0; i < abs_value_buckets_.size(); ++i) { 345 if (i == abs_value_buckets_.size() - 1 || 346 (abs_value >= kAbsValueBucketBounds[i] && 347 abs_value < kAbsValueBucketBounds[i + 1])) { 348 // The first value of the pair is the count of elements in the bucket, 349 // the second is the count of mismatches in the bucket. 350 abs_value_buckets_[i].first++; 351 if (is_mismatch) { 352 abs_value_buckets_[i].second++; 353 } 354 return; 355 } 356 } 357 } 358 359 // Insert the given error into the given error bucket vector. 360 void UpdateErrorBucket(float error, absl::Span<int64> error_buckets) { 361 CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); 362 for (int i = 0; i < error_buckets.size(); ++i) { 363 if (error >= kErrorBucketBounds[i]) { 364 error_buckets[i]++; 365 } 366 } 367 } 368 369 // Compares the two given elements from the expected and actual literals at 370 // the given literal_index and keeps track of various mismatch statistics. 371 template <typename T> 372 void CompareValues(T expected, T actual, int64 linear_index) { 373 float abs_error; 374 float rel_error; 375 if (CompareEqual<T>(expected, actual, {linear_index})) { 376 abs_error = 0; 377 rel_error = 0; 378 } else if (IsNan(expected) || IsNan(actual)) { 379 if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) || 380 (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) { 381 num_nan_mismatches_++; 382 // A nan mismatch is considered to have infinite error. rel_error is 383 // used for sorting a std::set of the top mismatchs, and a nan value 384 // here will result in undefined behavior because nan's do not satisfy 385 // the strict weak ordering requirement of std containers. 386 abs_error = std::numeric_limits<float>::infinity(); 387 rel_error = std::numeric_limits<float>::infinity(); 388 } else { 389 abs_error = 0; 390 rel_error = 0; 391 } 392 } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) { 393 // `fewer_infs_ok` gives us the option of comparing as though `actual` 394 // were float_max/min rather than inf. 395 T actual_finite = actual > T{0} ? std::numeric_limits<T>::max() 396 : std::numeric_limits<T>::lowest(); 397 abs_error = FpAbsoluteValue(actual_finite - expected); 398 399 // Avoid division by 0 even though it's well-defined because ubsan can be 400 // configured to treat this as a fatal error. 401 if (expected != T{0}) { 402 rel_error = abs_error / FpAbsoluteValue(expected); 403 } else { 404 rel_error = std::numeric_limits<float>::infinity(); 405 } 406 } else if (IsInf(expected) || IsInf(actual)) { 407 // If either the expected or actual value is infinity but not both, 408 // then both absolute and relative error are regarded as inifity. 409 CHECK(!CompareEqual(expected, actual, {linear_index})); 410 abs_error = std::numeric_limits<float>::infinity(); 411 rel_error = std::numeric_limits<float>::infinity(); 412 } else { 413 abs_error = FpAbsoluteValue(actual - expected); 414 415 // Avoid division by 0 even though it's well-defined because ubsan can be 416 // configured to treat this as a fatal error. 417 if (expected != T{0}) { 418 rel_error = abs_error / FpAbsoluteValue(expected); 419 } else { 420 rel_error = std::numeric_limits<float>::infinity(); 421 } 422 } 423 const bool is_abs_mismatch = abs_error > error_.abs; 424 const bool is_rel_mismatch = rel_error > error_.rel; 425 const bool is_mismatch = is_abs_mismatch && is_rel_mismatch; 426 427 // Update the error of the relative bucket only if the *absolute* error 428 // bound is exceeded and vice versa. 429 if (is_abs_mismatch) { 430 num_abs_mismatches_++; 431 UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_)); 432 } 433 if (is_rel_mismatch) { 434 num_rel_mismatches_++; 435 UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_)); 436 } 437 438 UpdateAbsValueBucket(actual, is_mismatch); 439 440 if (!is_mismatch) { 441 return; 442 } 443 444 num_mismatches_++; 445 446 // Keep track of the kTopRelativeErrorCount relative error mismatches. 447 if (top_rel_mismatches_.size() < kTopRelativeErrorCount || 448 rel_error > top_rel_mismatches_.begin()->rel_error) { 449 Mismatch mismatch = {actual, expected, rel_error, abs_error, 450 linear_index}; 451 top_rel_mismatches_.insert(mismatch); 452 if (top_rel_mismatches_.size() > kTopRelativeErrorCount) { 453 top_rel_mismatches_.erase(top_rel_mismatches_.begin()); 454 } 455 } 456 457 mismatches_.data<bool>()[linear_index] = true; 458 } 459 460 // For complex types, we compare real and imaginary parts individually. 461 void CompareValues(complex64 expected, complex64 actual, int64 linear_index) { 462 bool mismatch = false; 463 CompareValues<float>(expected.real(), actual.real(), linear_index); 464 if (mismatches_.data<bool>()[linear_index] == true) { 465 mismatch = true; 466 // Delay the mismatch count increase for real part, instead increase 467 // mismatch by 1 for the entire complex number. 468 num_mismatches_--; 469 } 470 CompareValues<float>(expected.imag(), actual.imag(), linear_index); 471 if (mismatches_.data<bool>()[linear_index] == true) { 472 mismatch = true; 473 // Delay the mismatch count increase for imag part, instead increase 474 // mismatch by 1 for the entire complex number. 475 num_mismatches_--; 476 } 477 if (mismatch == true) { 478 num_mismatches_++; 479 } 480 mismatches_.data<bool>()[linear_index] = mismatch; 481 } 482 483 void CompareValues(complex128 expected, complex128 actual, 484 int64 linear_index) { 485 bool mismatch = false; 486 CompareValues<double>(expected.real(), actual.real(), linear_index); 487 if (mismatches_.data<bool>()[linear_index] == true) { 488 mismatch = true; 489 // Delay the mismatch count increase for real part, instead increase 490 // mismatch by 1 for the entire complex number. 491 num_mismatches_--; 492 } 493 CompareValues<double>(expected.imag(), actual.imag(), linear_index); 494 if (mismatches_.data<bool>()[linear_index] == true) { 495 mismatch = true; 496 // Delay the mismatch count increase for imag part, instead increase 497 // mismatch by 1 for the entire complex number. 498 num_mismatches_--; 499 } 500 if (mismatch == true) { 501 num_mismatches_++; 502 } 503 mismatches_.data<bool>()[linear_index] = mismatch; 504 } 505 506 // Compares the two literals elementwise. 507 void CompareLiterals() { 508 // Fast path optimization for the case were layouts match. 509 if (LayoutUtil::Equal(actual_.shape().layout(), 510 expected_.shape().layout())) { 511 absl::Span<const NativeT> expected_data = expected_.data<NativeT>(); 512 absl::Span<const NativeT> actual_data = actual_.data<NativeT>(); 513 const int64 len = expected_data.size(); 514 for (int64 i = 0; i < len; ++i) { 515 CompareValues(expected_data[i], actual_data[i], i); 516 } 517 return; 518 } 519 std::vector<int64> multi_index(actual_.shape().rank(), 0); 520 CompareLiteralsSlow(0, &multi_index); 521 } 522 523 // Slow path for CompareLiterals when 'actual' and 'expected' literals have 524 // different layouts. In this case, multidimensional indices are constructed 525 // and indexed for each element. 526 void CompareLiteralsSlow(int64 dimension, std::vector<int64>* multi_index) { 527 if (dimension == multi_index->size()) { 528 CompareValues(expected_.Get<NativeT>(*multi_index), 529 actual_.Get<NativeT>(*multi_index), 530 IndexUtil::MultidimensionalIndexToLinearIndex( 531 actual_.shape(), *multi_index)); 532 } else { 533 for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) { 534 (*multi_index)[dimension] = i; 535 CompareLiteralsSlow(dimension + 1, multi_index); 536 } 537 } 538 } 539 540 // Returns an error message string with a detailed breakdown of the 541 // mismatches. Called after calling Run(). 542 string ErrorMessage() { 543 string out; 544 int64 element_count = ShapeUtil::ElementsIn(actual_.shape()); 545 546 auto percent_string = [](float a, float b) { 547 float pct = b == 0.0 ? 0.0 : 100.0 * a / b; 548 return absl::StrFormat("%0.4f%%", pct); 549 }; 550 551 StrAppendFormat( 552 &out, 553 "\nMismatch count %d (%s) in shape %s (%d elements), abs bound " 554 "%g, rel bound %g\n", 555 num_mismatches_, percent_string(num_mismatches_, element_count), 556 ShapeUtil::HumanString(actual_.shape()), 557 ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); 558 if (num_nan_mismatches_ > 0) { 559 StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); 560 } 561 StrAppendFormat(&out, "Top relative error mismatches:\n"); 562 for (auto it = top_rel_mismatches_.rbegin(); 563 it != top_rel_mismatches_.rend(); ++it) { 564 StrAppend(&out, " ", it->ToString(actual_.shape()), "\n"); 565 } 566 567 if (!detailed_message_) { 568 return out; 569 } 570 571 StrAppend(&out, "Absolute magnitude breakdown of actual values:\n"); 572 CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size()); 573 for (int i = 0; i < abs_value_buckets_.size(); ++i) { 574 const int64 bucket_size = abs_value_buckets_[i].first; 575 const int64 bucket_mismatches = abs_value_buckets_[i].second; 576 string mismatch_str = 577 bucket_mismatches > 0 578 ? absl::StrFormat(", mismatches %d", bucket_mismatches) 579 : ""; 580 StrAppendFormat(&out, " %-6g <= x < %-6g : %7d (%9s)%s\n", 581 kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], 582 bucket_size, percent_string(bucket_size, element_count), 583 mismatch_str); 584 } 585 586 auto print_accum_buckets = [&](const string& header, int64 total, 587 absl::Span<const int64> buckets) { 588 StrAppend(&out, header, ":\n"); 589 StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0], 590 total - buckets[0], 591 percent_string(total - buckets[0], total)); 592 CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); 593 for (int i = 0; i < kErrorBucketBounds.size(); ++i) { 594 StrAppendFormat(&out, " >= %-6g : %7d (%s)\n", kErrorBucketBounds[i], 595 buckets[i], percent_string(buckets[i], total)); 596 } 597 }; 598 StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n", 599 error_.abs, num_abs_mismatches_, 600 percent_string(num_abs_mismatches_, element_count)); 601 print_accum_buckets( 602 "Relative error breakdown of elements exceeding abs error bound", 603 num_abs_mismatches_, rel_error_buckets_); 604 StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n", 605 error_.rel, num_rel_mismatches_, 606 percent_string(num_rel_mismatches_, element_count)); 607 print_accum_buckets( 608 "Absolute error breakdown of elements exceeding rel error bound", 609 num_rel_mismatches_, abs_error_buckets_); 610 return out; 611 } 612 613 // 'actual' and 'expected' literals being compared. 614 LiteralSlice expected_; 615 LiteralSlice actual_; 616 617 // The error bounds of the comparison. 618 ErrorSpec error_; 619 620 // Whether to include detailed breakdown of mismatches in the error message. 621 bool detailed_message_; 622 623 // Callback to invoke on miscompare. 624 MiscompareCallback miscompare_callback_; 625 626 // Number of element element mismatches encountered so far. 627 int64 num_mismatches_ = 0; 628 629 // Number of elements with a nan mismatch. 630 int64 num_nan_mismatches_ = 0; 631 632 // Number of elements which exceed the absolute/relative error bound. 633 int64 num_abs_mismatches_ = 0; 634 int64 num_rel_mismatches_ = 0; 635 636 // A Literal containing which elements did not match in the expected and 637 // actual literals. mismatches_ contains PREDs and is of the same sizes as 638 // the comparison literals. 639 Literal mismatches_; 640 641 // The number of mismatches to report in the output, sorted by relative error 642 // magnitude. 643 static constexpr int64 kTopRelativeErrorCount = 5; 644 645 // The set of mismatches with the largest relative error. The size of this set 646 // is bounded by kTopRelativeErrorCount. 647 std::multiset<Mismatch> top_rel_mismatches_; 648 649 // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the 650 // bounds of these buckets. abs_value_buckets_ contains a pair for each 651 // bucket: the element count and failure count. 652 static constexpr std::array<float, 7> kAbsValueBucketBounds = { 653 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()}; 654 std::vector<std::pair<int64, int64>> abs_value_buckets_; 655 656 // Buckets for relative and absolute errors. The relative error buckets only 657 // contains those elements which exceed the *absolute* error bound, and vice 658 // versa. This makes it easy to see the effect of adjusting the relative (or 659 // absolute) error bound on the success of the comparison. kErrorBucketBounds 660 // are the lower bounds of the buckets in both vectors. The error buckets are 661 // a cumulative distribution so an error value may appear in more than one 662 // bucket. For example an error value of 0.003 may appear in the buckets 663 // bounded by 0.01, 0.1, and 1.0. 664 static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001, 665 0.01, 0.1, 1}; 666 std::vector<int64> abs_error_buckets_; 667 std::vector<int64> rel_error_buckets_; 668 }; 669 670 template <typename NativeT> 671 constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds; 672 template <typename NativeT> 673 constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds; 674 675 Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { 676 TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); 677 std::vector<int64> multi_index(expected.shape().dimensions_size(), 0); 678 auto index = absl::MakeSpan(multi_index); 679 Status result; 680 switch (expected.shape().element_type()) { 681 case PRED: 682 result = Equal<bool>(expected, actual, index, 0); 683 break; 684 case U8: 685 result = Equal<uint8>(expected, actual, index, 0); 686 break; 687 case S32: 688 result = Equal<int32>(expected, actual, index, 0); 689 break; 690 case S64: 691 result = Equal<int64>(expected, actual, index, 0); 692 break; 693 case U32: 694 result = Equal<uint32>(expected, actual, index, 0); 695 break; 696 case U64: 697 result = Equal<uint64>(expected, actual, index, 0); 698 break; 699 case BF16: 700 result = Equal<bfloat16>(expected, actual, index, 0); 701 break; 702 case F16: 703 result = Equal<half>(expected, actual, index, 0); 704 break; 705 case F32: 706 result = Equal<float>(expected, actual, index, 0); 707 break; 708 case F64: 709 result = Equal<double>(expected, actual, index, 0); 710 break; 711 case C64: 712 result = Equal<complex64>(expected, actual, index, 0); 713 break; 714 case C128: 715 result = Equal<complex128>(expected, actual, index, 0); 716 break; 717 case TUPLE: { 718 for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { 719 result.Update(EqualHelper(LiteralSlice(expected, {i}), 720 LiteralSlice(actual, {i}))); 721 } 722 break; 723 } 724 case TOKEN: 725 // Tokens have no on-device representation and are trivially equal. 726 return Status::OK(); 727 default: 728 LOG(FATAL) << "Unsupported primitive type: " 729 << PrimitiveType_Name(expected.shape().element_type()); 730 } 731 732 return result; 733 } 734 735 // Helper function for comparing two literals for nearness. Handles tuple-shapes 736 // via recursion. shape_index is the ShapeIndex of expected (or actual) 737 // currently being compared. 738 Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, 739 const ErrorSpec& error, absl::optional<bool> detailed_message, 740 const MiscompareCallback& miscompare_callback, 741 const ShapeIndex& shape_index) { 742 TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); 743 744 if (expected.shape().IsTuple()) { 745 Status return_status; 746 for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { 747 const auto expected_element = LiteralSlice(expected, {i}); 748 const auto actual_element = LiteralSlice(actual, {i}); 749 ShapeIndex element_index = shape_index; 750 element_index.push_back(i); 751 Status element_result = 752 NearHelper(expected_element, actual_element, error, detailed_message, 753 miscompare_callback, element_index); 754 if (!element_result.ok()) { 755 element_result = InvalidArgument("Array at shape index %s, %s", 756 element_index.ToString(), 757 element_result.error_message()); 758 if (return_status.ok()) { 759 return_status = element_result; 760 } else { 761 return_status = 762 AppendStatus(return_status, element_result.error_message()); 763 } 764 } 765 } 766 if (!return_status.ok() && shape_index.empty()) { 767 // Emit a top-level error message containing the top-level shape in case 768 // of mismatch. 769 int64 total_elements = RecursiveElementCount(actual.shape()); 770 return_status = 771 InvalidArgument("\nMismatches in shape %s (%d elements):\n%s", 772 ShapeUtil::HumanString(actual.shape()), 773 total_elements, return_status.error_message()); 774 } 775 return return_status; 776 } 777 778 if (ShapeUtil::ElementIsFloating(expected.shape()) || 779 ShapeUtil::ElementIsComplex(expected.shape())) { 780 bool use_detailed_message = detailed_message.value_or( 781 ShapeUtil::ElementsIn(expected.shape()) >= 64); 782 switch (expected.shape().element_type()) { 783 case BF16: 784 return NearComparator<bfloat16>::Compare( 785 expected, actual, error, use_detailed_message, miscompare_callback); 786 break; 787 case F16: 788 return NearComparator<half>::Compare( 789 expected, actual, error, use_detailed_message, miscompare_callback); 790 break; 791 case F32: 792 return NearComparator<float>::Compare( 793 expected, actual, error, use_detailed_message, miscompare_callback); 794 break; 795 case F64: 796 return NearComparator<double>::Compare( 797 expected, actual, error, use_detailed_message, miscompare_callback); 798 break; 799 case C64: 800 return NearComparator<complex64>::Compare( 801 expected, actual, error, use_detailed_message, miscompare_callback); 802 break; 803 case C128: 804 return NearComparator<complex128>::Compare( 805 expected, actual, error, use_detailed_message, miscompare_callback); 806 break; 807 default: 808 LOG(FATAL) << "Unsupported primitive type in near comparator: " 809 << PrimitiveType_Name(expected.shape().element_type()) 810 << ". Must be floating-point type."; 811 } 812 } 813 814 // Non-floating point, non-tuple literal. 815 return EqualHelper(expected, actual); 816 } 817 818 } // namespace 819 820 Status EqualShapes(const Shape& expected, const Shape& actual) { 821 if (expected.element_type() != actual.element_type()) { 822 return InvalidArgument("element type mismatch, want: %s got %s", 823 ShapeUtil::HumanString(expected), 824 ShapeUtil::HumanString(actual)); 825 } 826 if (expected.IsTuple()) { 827 if (ShapeUtil::TupleElementCount(expected) != 828 ShapeUtil::TupleElementCount(actual)) { 829 return InvalidArgument( 830 "want tuple element count: %d got tuple element count: %d", 831 ShapeUtil::TupleElementCount(expected), 832 ShapeUtil::TupleElementCount(actual)); 833 } 834 for (int i = 0; i < expected.tuple_shapes_size(); ++i) { 835 Status result = 836 EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); 837 if (!result.ok()) { 838 return AppendStatus(result, StrCat("mismatch in tuple index", i)); 839 } 840 } 841 } else if (expected.IsArray()) { 842 if (expected.rank() != actual.rank()) { 843 return InvalidArgument("want rank of %s got rank of %s", 844 ShapeUtil::HumanString(expected), 845 ShapeUtil::HumanString(actual)); 846 } 847 if (expected.element_type() != actual.element_type()) { 848 return InvalidArgument("mismatch in primitive type %s vs %s", 849 PrimitiveType_Name(expected.element_type()), 850 PrimitiveType_Name(actual.element_type())); 851 } 852 if (expected.dimensions_size() != actual.dimensions_size()) { 853 return InvalidArgument("want dimensions_size %d got dimensions_size %d", 854 expected.dimensions_size(), 855 actual.dimensions_size()); 856 } 857 for (int i = 0; i < expected.dimensions_size(); ++i) { 858 if (expected.dimensions(i) != actual.dimensions(i)) { 859 return InvalidArgument( 860 "mismatch in dimension #%d expected: %s actual: %s", i, 861 ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); 862 } 863 } 864 } 865 // Non-array, non-tuple shapes are trivially equivalent. 866 return Status::OK(); 867 } 868 869 namespace { 870 871 // If result is an error, extend the error message with the expected and actual 872 // literals. 873 Status EmitLiteralsInErrorMessage(const Status& result, 874 const LiteralSlice& expected, 875 const LiteralSlice& actual) { 876 if (result.ok()) { 877 return result; 878 } 879 return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s", 880 result.error_message(), ToStringTruncated(expected), 881 ToStringTruncated(actual)); 882 } 883 884 } // namespace 885 886 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { 887 VLOG(1) << "expected:"; 888 XLA_VLOG_LINES(1, expected.ToString()); 889 VLOG(1) << "actual:"; 890 XLA_VLOG_LINES(1, actual.ToString()); 891 Status result = EqualHelper(expected, actual); 892 return EmitLiteralsInErrorMessage(result, expected, actual); 893 } 894 895 Status Near(const LiteralSlice& expected, const LiteralSlice& actual, 896 const ErrorSpec& error, absl::optional<bool> detailed_message, 897 const MiscompareCallback& miscompare_callback) { 898 VLOG(1) << "Expected literal:"; 899 XLA_VLOG_LINES(1, expected.ToString()); 900 VLOG(1) << "Actual literal:"; 901 XLA_VLOG_LINES(1, actual.ToString()); 902 Status result = 903 NearHelper(expected, actual, error, detailed_message, miscompare_callback, 904 /*shape_index=*/{}); 905 return EmitLiteralsInErrorMessage(result, expected, actual); 906 } 907 908 string ToStringTruncated(const LiteralSlice& literal) { 909 return RecursiveElementCount(literal.shape()) < 1000 910 ? literal.ToString() 911 : "[TRUNCATED, Literal with more than 1000 values]"; 912 } 913 914 } // namespace literal_comparison 915 } // namespace xla 916