1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 17 18 #include <unistd.h> 19 #include <cmath> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/index_util.h" 23 #include "tensorflow/compiler/xla/layout_util.h" 24 #include "tensorflow/compiler/xla/literal_util.h" 25 #include "tensorflow/compiler/xla/ptr_util.h" 26 #include "tensorflow/compiler/xla/shape_util.h" 27 #include "tensorflow/compiler/xla/test.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/core/lib/core/casts.h" 30 #include "tensorflow/core/lib/io/path.h" 31 #include "tensorflow/core/lib/strings/str_util.h" 32 #include "tensorflow/core/lib/strings/strcat.h" 33 #include "tensorflow/core/lib/strings/stringprintf.h" 34 #include "tensorflow/core/platform/env.h" 35 #include "tensorflow/core/platform/logging.h" 36 #include "tensorflow/core/platform/protobuf.h" 37 #include "tensorflow/core/platform/test.h" 38 #include "tensorflow/core/platform/types.h" 39 40 namespace xla { 41 42 /* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( 43 const Shape& expected, const Shape& actual) { 44 if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { 45 return ::testing::AssertionFailure() 46 << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected) 47 << " got: " << ShapeUtil::HumanString(actual); 48 } 49 if (ShapeUtil::IsTuple(expected)) { 50 if (ShapeUtil::TupleElementCount(expected) != 51 ShapeUtil::TupleElementCount(actual)) { 52 return ::testing::AssertionFailure() 53 << "want tuple element count: " 54 << ShapeUtil::TupleElementCount(expected) 55 << " got tuple element count: " 56 << ShapeUtil::TupleElementCount(actual); 57 } 58 for (int i = 0; i < expected.tuple_shapes_size(); ++i) { 59 ::testing::AssertionResult result = 60 EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)) 61 << "mismatch in tuple index " << i; 62 if (!result) { 63 return result; 64 } 65 } 66 } else { 67 if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { 68 return ::testing::AssertionFailure() 69 << "want rank of: " << ShapeUtil::HumanString(expected) 70 << " got rank of: " << ShapeUtil::HumanString(actual); 71 } 72 if (expected.element_type() != actual.element_type()) { 73 return ::testing::AssertionFailure() 74 << PrimitiveType_Name(expected.element_type()) << " vs " 75 << PrimitiveType_Name(actual.element_type()); 76 } 77 if (expected.dimensions_size() != actual.dimensions_size()) { 78 return ::testing::AssertionFailure() 79 << "want dimensions_size " << expected.dimensions_size() 80 << " got dimensions_size " << actual.dimensions_size(); 81 } 82 for (int i = 0; i < expected.dimensions_size(); ++i) { 83 if (expected.dimensions(i) != actual.dimensions(i)) { 84 return ::testing::AssertionFailure() 85 << "mismatch in dimension #" << i 86 << " expected: " << ShapeUtil::HumanString(expected) 87 << " actual: " << ShapeUtil::HumanString(actual); 88 } 89 } 90 } 91 return ::testing::AssertionSuccess(); 92 } 93 94 /* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, 95 const Shape& actual) { 96 ASSERT_TRUE(EqualShapes(expected, actual)); 97 } 98 99 /* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( 100 const Shape& expected, const Shape& actual) { 101 ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); 102 } 103 104 namespace { 105 106 // Return a literal with all arrays of type FromNativeT converted to type 107 // ToNativeT in the given literal. 108 template <typename FromNativeT, typename ToNativeT> 109 std::unique_ptr<Literal> ConvertType(const Literal& literal) { 110 // First construct shape of the result. 111 Shape result_shape(literal.shape()); 112 ShapeUtil::ForEachMutableSubshape( 113 &result_shape, [](Shape* subshape, const ShapeIndex&) { 114 if (subshape->element_type() == 115 primitive_util::NativeToPrimitiveType<FromNativeT>()) { 116 subshape->set_element_type( 117 primitive_util::NativeToPrimitiveType<ToNativeT>()); 118 } 119 }); 120 auto result = MakeUnique<Literal>(result_shape); 121 122 // Then copy over the data from 'literal' converting FromNativeT values to 123 // ToNativeT values as necessary. 124 ShapeUtil::ForEachSubshape( 125 literal.shape(), 126 [&](const Shape& subshape, const ShapeIndex& shape_index) { 127 if (ShapeUtil::IsArray(subshape)) { 128 if (subshape.element_type() == 129 primitive_util::NativeToPrimitiveType<FromNativeT>()) { 130 auto src = literal.data<FromNativeT>(shape_index); 131 auto dest = result->data<ToNativeT>(shape_index); 132 for (int64 i = 0; i < src.size(); ++i) { 133 dest[i] = static_cast<ToNativeT>(src[i]); 134 } 135 } else { 136 TF_CHECK_OK(result->CopyFrom(literal, 137 /*dest_shape_index=*/shape_index, 138 /*src_shape_index=*/shape_index)); 139 } 140 } 141 }); 142 return result; 143 } 144 145 } // namespace 146 147 /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32( 148 const Literal& literal) { 149 return ConvertType<bfloat16, float>(literal); 150 } 151 152 /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16( 153 const Literal& literal) { 154 return ConvertType<float, bfloat16>(literal); 155 } 156 157 namespace { 158 159 string Hostname() { 160 char hostname[1024]; 161 gethostname(hostname, sizeof hostname); 162 hostname[sizeof hostname - 1] = 0; 163 return string(hostname); 164 } 165 166 // Helper function for comparing a floating point type, FloatT, bitwise equal 167 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT 168 // -- on miscompare, a nice error message is given in the AssertionFailure. 169 template <typename FloatT, typename UnsignedT> 170 ::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { 171 auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs); 172 auto urhs = tensorflow::bit_cast<UnsignedT>(rhs); 173 auto lhs_double = static_cast<double>(lhs); 174 auto rhs_double = static_cast<double>(rhs); 175 if (ulhs != urhs) { 176 return ::testing::AssertionFailure() << tensorflow::strings::Printf( 177 "floating values are not bitwise-equal; and equality testing " 178 "was requested: %s=%g=%a vs %s=%g=%a", 179 tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs)) 180 .c_str(), 181 lhs_double, lhs_double, 182 tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs)) 183 .c_str(), 184 rhs_double, rhs_double); 185 } 186 return ::testing::AssertionSuccess(); 187 } 188 189 // Templated comparator that specializes for float equality comparison with the 190 // bitwise helper above (this is the un-specialized fallback, to just use the 191 // default gunit implementation). 192 template <typename NativeT> 193 ::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) { 194 if (lhs == rhs) { 195 return ::testing::AssertionSuccess(); 196 } 197 ::testing::Message msg; 198 msg << "Expected equality of these values:"; 199 msg << "\n " << lhs; 200 msg << "\n " << rhs; 201 202 return ::testing::AssertionFailure() << msg; 203 } 204 205 // Specializations for floating types that do bitwise comparisons when equality 206 // comparison is requested. 207 template <> 208 ::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) { 209 return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs); 210 } 211 template <> 212 ::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) { 213 return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs); 214 } 215 template <> 216 ::testing::AssertionResult CompareEqual<double>(double lhs, double rhs) { 217 return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs); 218 } 219 template <> 220 ::testing::AssertionResult CompareEqual<complex64>(complex64 lhs, 221 complex64 rhs) { 222 auto res = CompareEqual<float>(lhs.real(), rhs.real()); 223 if (!res) { 224 return res; 225 } 226 return CompareEqual<float>(lhs.imag(), rhs.imag()); 227 } 228 229 // A recursive function which iterates through every index of expected and 230 // actual literal and compares their values elementwise. Returns true if all 231 // elements are equal. 232 template <typename NativeT> 233 bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, 234 tensorflow::gtl::MutableArraySlice<int64> multi_index, 235 int64 dimension) { 236 if (dimension == expected.shape().dimensions_size()) { 237 NativeT expected_value = expected.Get<NativeT>(multi_index); 238 NativeT actual_value = actual.Get<NativeT>(multi_index); 239 ::testing::AssertionResult result = 240 CompareEqual<NativeT>(expected_value, actual_value); 241 return result; // Defines implicit coersion to bool. 242 } 243 244 bool all_match = true; 245 for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { 246 multi_index[dimension] = i; 247 all_match = all_match && ExpectLiteralsEqual<NativeT>( 248 expected, actual, multi_index, dimension + 1); 249 } 250 return all_match; 251 } 252 253 } // namespace 254 255 /* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, 256 const Literal& actual, 257 const string& message) { 258 EXPECT_TRUE(Equal(expected, actual)) 259 << "expected:\n" 260 << expected.ToString() << "\n\tvs actual:\n" 261 << actual.ToString() 262 << (message.empty() 263 ? "" 264 : tensorflow::strings::StrCat("\nmessage: ", message)); 265 } 266 267 /* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, 268 const Literal& actual) { 269 EXPECT_FALSE(Equal(expected, actual)); 270 } 271 272 /* static */ ::testing::AssertionResult LiteralTestUtil::Equal( 273 const Literal& expected, const Literal& actual) { 274 VLOG(1) << "expected:"; 275 XLA_VLOG_LINES(1, expected.ToString()); 276 VLOG(1) << "actual:"; 277 XLA_VLOG_LINES(1, actual.ToString()); 278 279 AssertEqualShapes(expected.shape(), actual.shape()); 280 std::vector<int64> multi_index(expected.shape().dimensions_size(), 0); 281 bool match = false; 282 switch (expected.shape().element_type()) { 283 case PRED: 284 match = ExpectLiteralsEqual<bool>(expected, actual, &multi_index, 0); 285 break; 286 case U8: 287 match = ExpectLiteralsEqual<uint8>(expected, actual, &multi_index, 0); 288 break; 289 case S32: 290 match = ExpectLiteralsEqual<int32>(expected, actual, &multi_index, 0); 291 break; 292 case S64: 293 match = ExpectLiteralsEqual<int64>(expected, actual, &multi_index, 0); 294 break; 295 case U32: 296 match = ExpectLiteralsEqual<uint32>(expected, actual, &multi_index, 0); 297 break; 298 case U64: 299 match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0); 300 break; 301 case BF16: 302 match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0); 303 break; 304 case F16: 305 match = ExpectLiteralsEqual<half>(expected, actual, &multi_index, 0); 306 break; 307 case F32: 308 match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0); 309 break; 310 case F64: 311 match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0); 312 break; 313 case C64: 314 match = ExpectLiteralsEqual<complex64>(expected, actual, &multi_index, 0); 315 break; 316 case TUPLE: { 317 bool tuple_match = true; 318 for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { 319 SCOPED_TRACE(tensorflow::strings::StrCat( 320 "Tuple index ", i, " in ", 321 ShapeUtil::HumanString(expected.shape()))); 322 323 // Create LiteralViews of the expected and actual elements. 324 auto result = Equal(LiteralView::Create(expected, {i}), 325 LiteralView::Create(actual, {i})); 326 tuple_match = tuple_match ? !!result : false; 327 } 328 match = tuple_match; 329 break; 330 } 331 default: 332 LOG(FATAL) 333 << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " 334 << PrimitiveType_Name(expected.shape().element_type()); 335 } 336 ::testing::AssertionResult result = ::testing::AssertionSuccess(); 337 if (!match) { 338 result = ::testing::AssertionFailure() 339 << "expected: " << expected.ToString() 340 << "\nactual: " << actual.ToString(); 341 VLOG(1) << result.message(); 342 } 343 return result; 344 } 345 346 namespace { 347 348 // Helper class for comparing floating-point literals within an error bound. 349 class NearComparator { 350 public: 351 explicit NearComparator(ErrorSpec error) : error_(error) {} 352 353 // Compares the two literals elementwise. EXPECTs each pair of elements to be 354 // within the error bound. Emits useful log messages and dumps literals to 355 // temporary files on failure. Returns true if literals match. 356 bool ExpectNear(const Literal& expected, const Literal& actual) { 357 VLOG(1) << "expected:"; 358 XLA_VLOG_LINES(1, TruncateHugeLiteral(expected)); 359 VLOG(1) << "actual:"; 360 XLA_VLOG_LINES(1, TruncateHugeLiteral(actual)); 361 362 // If the shapes mismatch, we simply fail the expectation instead of 363 // printing out data, as it's a type error rather than a value error. 364 ::testing::AssertionResult equal_shapes = 365 LiteralTestUtil::EqualShapes(expected.shape(), actual.shape()); 366 if (!equal_shapes) { 367 EXPECT_TRUE(equal_shapes); 368 return false; 369 } 370 371 // Set up members used during the comparison. 372 num_miscompares_ = 0; 373 abs_diff_sum_ = 0.0; 374 abs_expected_sum_ = 0.0; 375 abs_diff_miscompare_sum_ = 0.0; 376 abs_expected_miscompare_sum_ = 0.0; 377 max_rel_err_ = 0.0; 378 max_abs_err_ = 0.0; 379 first_linear_index_ = -1; 380 last_linear_index_ = -1; 381 max_rel_linear_index_ = -1; 382 max_abs_linear_index_ = -1; 383 miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED)); 384 miscompares_.PopulateWithValue(false); 385 multi_index_.resize(expected.shape().dimensions_size(), 0); 386 387 switch (expected.shape().element_type()) { 388 case BF16: 389 ExpectLiteralsNear<bfloat16>(expected, actual, 0); 390 break; 391 case F16: 392 ExpectLiteralsNear<half>(expected, actual, 0); 393 break; 394 case F32: 395 ExpectLiteralsNear<float>(expected, actual, 0); 396 break; 397 case F64: 398 ExpectLiteralsNear<double>(expected, actual, 0); 399 break; 400 case C64: 401 ExpectLiteralsNear<complex64>(expected, actual, 0); 402 break; 403 default: 404 LOG(FATAL) << "Unsupported primitive type in near comparator: " 405 << PrimitiveType_Name(expected.shape().element_type()) 406 << ". Must be floating-point type."; 407 } 408 409 if (num_miscompares_ > 0) { 410 if (!VLOG_IS_ON(1)) { 411 LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) 412 << " " << TruncateHugeLiteral(expected); 413 LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) 414 << " " << TruncateHugeLiteral(actual); 415 LOG(INFO) << "Dumping literals to temp files..."; 416 WriteLiteralToTempFile(expected, "expected"); 417 WriteLiteralToTempFile(actual, "actual"); 418 WriteLiteralToTempFile(miscompares_, "miscompares"); 419 } 420 EXPECT_TRUE(num_miscompares_ == 0) 421 << "\nmax relative mismatch at index " 422 << LiteralTestUtil::MultiIndexAsString( 423 IndexUtil::LinearIndexToMultidimensionalIndex( 424 actual.shape(), max_rel_linear_index_)) 425 << "\nmaximum relative error " << max_rel_err_ 426 << "\nmax absolute mismatch at index " 427 << LiteralTestUtil::MultiIndexAsString( 428 IndexUtil::LinearIndexToMultidimensionalIndex( 429 actual.shape(), max_abs_linear_index_)) 430 << "\nmaximum absolute error " << max_abs_err_ 431 << "\nfirst mismatch at index " 432 << LiteralTestUtil::MultiIndexAsString( 433 IndexUtil::LinearIndexToMultidimensionalIndex( 434 actual.shape(), first_linear_index_)) 435 << "\nlast mismatch at index " 436 << LiteralTestUtil::MultiIndexAsString( 437 IndexUtil::LinearIndexToMultidimensionalIndex( 438 actual.shape(), last_linear_index_)) 439 << "\ntotal absolute error " << abs_diff_sum_ 440 << "\ntotal absolute error of miscompares " 441 << abs_diff_miscompare_sum_ << "\ntotal relative error " 442 << (abs_diff_sum_ / abs_expected_sum_) 443 << "\ntotal relative error of miscompares " 444 << (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_) 445 << "\nfailure count " << num_miscompares_; 446 } 447 return num_miscompares_ == 0; 448 } 449 450 private: 451 template <typename NativeT> 452 bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { 453 if (relaxed_nans) { 454 return !std::isnan(expected) && std::isnan(actual); 455 } else { 456 return std::isnan(expected) != std::isnan(actual); 457 } 458 } 459 460 template <typename NativeT> 461 void ExpectNear(NativeT expected, NativeT actual, 462 const ::testing::Message& message) { 463 EXPECT_NEAR(expected, actual, error_.abs) 464 << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" 465 << message; 466 } 467 468 // EXPECTs that the two given scalar values are within the error bound. Keeps 469 // track of how many mismatches have occurred to keep the size of the output 470 // manageable. 471 template <typename NativeT> 472 bool ExpectValuesNear(NativeT expected, NativeT actual) { 473 if (expected == actual) { 474 return true; 475 } 476 477 const float abs_diff = std::abs(actual - expected); 478 const float rel_err = abs_diff / std::abs(expected); 479 const bool nan_mismatch = 480 NanMismatch<NativeT>(expected, actual, error_.relaxed_nans); 481 const bool mismatch = 482 (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel)); 483 return !mismatch; 484 } 485 486 // Assumes that expected vs actual fail ExpectValuesNear. 487 template <typename NativeT> 488 void UpdateAndLogMiscompares(const NativeT expected, const NativeT actual, 489 const Shape& shape, const int64 linear_index) { 490 const float abs_diff = std::abs(actual - expected); 491 const float rel_err = abs_diff / std::abs(expected); 492 abs_diff_sum_ += abs_diff; 493 abs_expected_sum_ += std::abs(expected); 494 if (rel_err > max_rel_err_ || std::isnan(rel_err)) { 495 max_rel_err_ = rel_err; 496 max_rel_linear_index_ = linear_index; 497 } 498 if (abs_diff > max_abs_err_ || std::isnan(abs_diff)) { 499 max_abs_err_ = abs_diff; 500 max_abs_linear_index_ = linear_index; 501 } 502 if (VLOG_IS_ON(10)) { 503 VLOG(10) << tensorflow::strings::Printf( 504 "index %s abs_diff %f rel_err %f", 505 LiteralTestUtil::MultiIndexAsString( 506 IndexUtil::LinearIndexToMultidimensionalIndex(shape, 507 linear_index)) 508 .c_str(), 509 abs_diff, rel_err); 510 } 511 abs_diff_miscompare_sum_ += abs_diff; 512 abs_expected_miscompare_sum_ += std::abs(expected); 513 const int64 kMaxFailures = 2; 514 if (num_miscompares_ < kMaxFailures) { 515 const auto multi_index = 516 IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index); 517 ::testing::Message msg; 518 msg << "mismatch at index " 519 << LiteralTestUtil::MultiIndexAsString(multi_index) << " abs diff " 520 << abs_diff << " rel err " << rel_err << " failure #" 521 << num_miscompares_; 522 ExpectNear<NativeT>(expected, actual, msg); 523 } else if (num_miscompares_ == kMaxFailures) { 524 LOG(ERROR) << "reached max 'loud' failure count; silently proceeding..."; 525 } 526 if (num_miscompares_ == 0) { 527 first_linear_index_ = linear_index; 528 } 529 num_miscompares_++; 530 last_linear_index_ = linear_index; 531 miscompares_.data<bool>()[linear_index] = true; 532 } 533 534 // Recursive function which compares the two given literals elementwise. 535 template <typename NativeT> 536 void ExpectLiteralsNear(const Literal& expected, const Literal& actual, 537 int64 dimension) { 538 // Fast path optimization for the case were layouts match. 539 if (LayoutUtil::Equal(actual.shape().layout(), expected.shape().layout())) { 540 tensorflow::gtl::ArraySlice<const NativeT> expected_data = 541 expected.data<NativeT>(); 542 tensorflow::gtl::ArraySlice<const NativeT> actual_data = 543 actual.data<NativeT>(); 544 const int64 len = expected_data.size(); 545 for (int64 i = 0; i < len; ++i) { 546 const bool near = ExpectValuesNear(expected_data[i], actual_data[i]); 547 if (!near) { 548 UpdateAndLogMiscompares<NativeT>(expected_data[i], actual_data[i], 549 actual.shape(), i); 550 } 551 } 552 return; 553 } 554 555 if (dimension == expected.shape().dimensions_size()) { 556 bool near = ExpectValuesNear(expected.Get<NativeT>(multi_index_), 557 actual.Get<NativeT>(multi_index_)); 558 if (!near) { 559 UpdateAndLogMiscompares<NativeT>( 560 expected.Get<NativeT>(multi_index_), 561 actual.Get<NativeT>(multi_index_), actual.shape(), 562 IndexUtil::MultidimensionalIndexToLinearIndex(actual.shape(), 563 multi_index_)); 564 } 565 } else { 566 for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { 567 multi_index_[dimension] = i; 568 ExpectLiteralsNear<NativeT>(expected, actual, dimension + 1); 569 } 570 } 571 } 572 573 // Writes the given literal to a file in the test temporary directory. 574 void WriteLiteralToTempFile(const Literal& literal, const string& name) { 575 int64 now_usec = tensorflow::Env::Default()->NowMicros(); 576 string filename = tensorflow::io::JoinPath( 577 tensorflow::testing::TmpDir(), 578 tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(), 579 now_usec, name.c_str())); 580 TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), 581 filename, literal.ToProto())); 582 LOG(ERROR) << "wrote to " << name << " file: " << filename; 583 } 584 585 // Gets the total element count. For tuples, this is not the count of tuple 586 // elements, but the sum of elements of each tuple element. 587 int64 RecursiveElementCount(const Shape& shape) { 588 if (ShapeUtil::IsTuple(shape)) { 589 const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); 590 int64 total = 0; 591 for (int64 i = 0; i < tuple_elements; ++i) { 592 total += 593 RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); 594 } 595 return total; 596 } else { 597 return ShapeUtil::ElementsIn(shape); 598 } 599 } 600 601 // Calling ToString on a literal with over 100 million elements takes around 602 // 3 minutes. The utility of printing a literal with >1000 elements is 603 // questionable, especially when writing the Literal proto to disk is orders 604 // of magnitude faster. 605 string TruncateHugeLiteral(const Literal& literal) { 606 return RecursiveElementCount(literal.shape()) < 1000 607 ? literal.ToString() 608 : "[TRUNCATED, Literal with more than 1000 values]"; 609 } 610 611 ErrorSpec error_; 612 613 // Number of element miscomparisons encountered so far. 614 int64 num_miscompares_; 615 616 // A Literal containing which elements did not match in the expected and 617 // actual literals. miscompares_ contains PREDs and is of the same sizes as 618 // the comparison literals. 619 Literal miscompares_; 620 621 // A multidimensional index used when performing the recursive comparison. 622 std::vector<int64> multi_index_; 623 624 // Aggregated Statistics on input. 625 double abs_diff_sum_; 626 double abs_expected_sum_; 627 double abs_diff_miscompare_sum_; 628 double abs_expected_miscompare_sum_; 629 float max_rel_err_; 630 float max_abs_err_; 631 int64 first_linear_index_; 632 int64 last_linear_index_; 633 int64 max_rel_linear_index_; 634 int64 max_abs_linear_index_; 635 }; 636 637 template <> 638 bool NearComparator::NanMismatch<complex64>(complex64 expected, 639 complex64 actual, 640 bool relaxed_nans) { 641 return NanMismatch(expected.real(), actual.real(), relaxed_nans) || 642 NanMismatch(expected.imag(), actual.imag(), relaxed_nans); 643 } 644 645 template <> 646 void NearComparator::ExpectNear<complex64>(complex64 expected, complex64 actual, 647 const ::testing::Message& message) { 648 EXPECT_NEAR(expected.real(), actual.real(), error_.abs) 649 << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" 650 << message; 651 EXPECT_NEAR(expected.imag(), actual.imag(), error_.abs) 652 << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" 653 << message; 654 } 655 656 template <> 657 bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected, 658 bfloat16 actual) { 659 return ExpectValuesNear(static_cast<float>(expected), 660 static_cast<float>(actual)); 661 } 662 663 template <> 664 bool NearComparator::ExpectValuesNear<half>(half expected, half actual) { 665 return ExpectValuesNear(static_cast<float>(std::move(expected)), 666 static_cast<float>(std::move(actual))); 667 } 668 669 template <> 670 void NearComparator::UpdateAndLogMiscompares<bfloat16>( 671 const bfloat16 expected, const bfloat16 actual, const Shape& shape, 672 const int64 linear_index) { 673 UpdateAndLogMiscompares(static_cast<float>(expected), 674 static_cast<float>(actual), shape, linear_index); 675 } 676 677 template <> 678 void NearComparator::UpdateAndLogMiscompares<half>(half expected, half actual, 679 const Shape& shape, 680 const int64 linear_index) { 681 UpdateAndLogMiscompares(static_cast<float>(std::move(expected)), 682 static_cast<float>(std::move(actual)), shape, 683 linear_index); 684 } 685 686 } // namespace 687 688 /* static */ ::testing::AssertionResult LiteralTestUtil::Near( 689 const Literal& expected, const Literal& actual, const ErrorSpec& error) { 690 ::testing::AssertionResult err = 691 EqualShapes(expected.shape(), actual.shape()); 692 if (!err) { 693 return err; 694 } 695 696 if (ShapeUtil::IsTuple(expected.shape())) { 697 for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { 698 SCOPED_TRACE(tensorflow::strings::StrCat( 699 "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); 700 const auto expected_element = LiteralView::Create(expected, {i}); 701 const auto actual_element = LiteralView::Create(actual, {i}); 702 703 ::testing::AssertionResult res = 704 Near(expected_element, actual_element, error); 705 if (err && !res) { 706 err = res; 707 } 708 } 709 return err; 710 } 711 712 if (ShapeUtil::ElementIsFloating(expected.shape()) || 713 ShapeUtil::ElementIsComplex(expected.shape())) { 714 NearComparator comparator(error); 715 return comparator.ExpectNear(expected, actual) 716 ? ::testing::AssertionSuccess() 717 : ::testing::AssertionFailure() << "values were not near"; 718 } 719 720 return Equal(expected, actual); 721 } 722 723 /* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, 724 const Literal& actual, 725 const ErrorSpec& error, 726 const string& message) { 727 EXPECT_TRUE(Near(expected, actual, error)) 728 << (message.empty() 729 ? "" 730 : tensorflow::strings::StrCat("\nmessage: ", message)); 731 } 732 733 /*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( 734 const Literal& expected, const Literal& actual, 735 const tensorflow::gtl::optional<ErrorSpec>& error) { 736 if (error.has_value()) { 737 VLOG(1) << "Expects near"; 738 return Near(expected, actual, *error); 739 } 740 VLOG(1) << "Expects equal"; 741 return Equal(expected, actual); 742 } 743 744 /*static*/ void LiteralTestUtil::ExpectNearOrEqual( 745 const Literal& expected, const Literal& actual, 746 const tensorflow::gtl::optional<ErrorSpec>& error) { 747 EXPECT_TRUE(NearOrEqual(expected, actual, error)); 748 } 749 750 /* static */ string LiteralTestUtil::MultiIndexAsString( 751 tensorflow::gtl::ArraySlice<int64> multi_index) { 752 return tensorflow::strings::StrCat( 753 "{", tensorflow::str_util::Join(multi_index, ","), "}"); 754 } 755 756 /* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape( 757 tensorflow::gtl::ArraySlice<int64> new_dimensions, 758 tensorflow::gtl::ArraySlice<int64> minor_to_major, const Literal& literal) { 759 int64 new_num_elements = 1; 760 for (int64 i = 0; i < new_dimensions.size(); ++i) { 761 new_num_elements *= new_dimensions[i]; 762 } 763 CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); 764 CHECK_EQ(new_dimensions.size(), minor_to_major.size()); 765 766 auto new_literal = MakeUnique<Literal>( 767 ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); 768 769 // Create a new shape with the given minor-to-major layout. This shape is used 770 // solely for converting linear address to multi-dimensional addresses when 771 // writing elements to the new literal. 772 Shape shape_with_layout = new_literal->shape(); 773 *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); 774 775 // Copy data into new literal, element-by-element. 776 for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { 777 std::vector<int64> from_multi_index = 778 IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); 779 std::vector<int64> to_multi_index = 780 IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); 781 switch (literal.shape().element_type()) { 782 case PRED: 783 new_literal->Set<bool>(to_multi_index, 784 literal.Get<bool>(from_multi_index)); 785 break; 786 case U8: 787 new_literal->Set<uint8>(to_multi_index, 788 literal.Get<uint8>(from_multi_index)); 789 break; 790 case U32: 791 new_literal->Set<uint32>(to_multi_index, 792 literal.Get<uint32>(from_multi_index)); 793 break; 794 case S32: 795 new_literal->Set<int32>(to_multi_index, 796 literal.Get<int32>(from_multi_index)); 797 break; 798 case U64: 799 new_literal->Set<uint64>(to_multi_index, 800 literal.Get<uint64>(from_multi_index)); 801 break; 802 case S64: 803 new_literal->Set<int64>(to_multi_index, 804 literal.Get<int64>(from_multi_index)); 805 break; 806 case F32: 807 new_literal->Set<float>(to_multi_index, 808 literal.Get<float>(from_multi_index)); 809 break; 810 case F64: 811 new_literal->Set<double>(to_multi_index, 812 literal.Get<double>(from_multi_index)); 813 break; 814 case C64: 815 new_literal->Set<complex64>(to_multi_index, 816 literal.Get<complex64>(from_multi_index)); 817 break; 818 default: 819 LOG(FATAL) << "Unhandled primitive element type: " 820 << PrimitiveType_Name(literal.shape().element_type()); 821 } 822 } 823 824 return new_literal; 825 } 826 827 } // namespace xla 828