Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     17 
     18 #include <unistd.h>
     19 #include <cmath>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/xla/index_util.h"
     23 #include "tensorflow/compiler/xla/layout_util.h"
     24 #include "tensorflow/compiler/xla/literal_util.h"
     25 #include "tensorflow/compiler/xla/ptr_util.h"
     26 #include "tensorflow/compiler/xla/shape_util.h"
     27 #include "tensorflow/compiler/xla/test.h"
     28 #include "tensorflow/compiler/xla/types.h"
     29 #include "tensorflow/core/lib/core/casts.h"
     30 #include "tensorflow/core/lib/io/path.h"
     31 #include "tensorflow/core/lib/strings/str_util.h"
     32 #include "tensorflow/core/lib/strings/strcat.h"
     33 #include "tensorflow/core/lib/strings/stringprintf.h"
     34 #include "tensorflow/core/platform/env.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 #include "tensorflow/core/platform/protobuf.h"
     37 #include "tensorflow/core/platform/test.h"
     38 #include "tensorflow/core/platform/types.h"
     39 
     40 namespace xla {
     41 
     42 /* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes(
     43     const Shape& expected, const Shape& actual) {
     44   if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
     45     return ::testing::AssertionFailure()
     46            << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected)
     47            << " got: " << ShapeUtil::HumanString(actual);
     48   }
     49   if (ShapeUtil::IsTuple(expected)) {
     50     if (ShapeUtil::TupleElementCount(expected) !=
     51         ShapeUtil::TupleElementCount(actual)) {
     52       return ::testing::AssertionFailure()
     53              << "want tuple element count: "
     54              << ShapeUtil::TupleElementCount(expected)
     55              << " got tuple element count: "
     56              << ShapeUtil::TupleElementCount(actual);
     57     }
     58     for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
     59       ::testing::AssertionResult result =
     60           EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i))
     61           << "mismatch in tuple index " << i;
     62       if (!result) {
     63         return result;
     64       }
     65     }
     66   } else {
     67     if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
     68       return ::testing::AssertionFailure()
     69              << "want rank of: " << ShapeUtil::HumanString(expected)
     70              << " got rank of: " << ShapeUtil::HumanString(actual);
     71     }
     72     if (expected.element_type() != actual.element_type()) {
     73       return ::testing::AssertionFailure()
     74              << PrimitiveType_Name(expected.element_type()) << " vs "
     75              << PrimitiveType_Name(actual.element_type());
     76     }
     77     if (expected.dimensions_size() != actual.dimensions_size()) {
     78       return ::testing::AssertionFailure()
     79              << "want dimensions_size " << expected.dimensions_size()
     80              << " got dimensions_size " << actual.dimensions_size();
     81     }
     82     for (int i = 0; i < expected.dimensions_size(); ++i) {
     83       if (expected.dimensions(i) != actual.dimensions(i)) {
     84         return ::testing::AssertionFailure()
     85                << "mismatch in dimension #" << i
     86                << " expected: " << ShapeUtil::HumanString(expected)
     87                << " actual: " << ShapeUtil::HumanString(actual);
     88       }
     89     }
     90   }
     91   return ::testing::AssertionSuccess();
     92 }
     93 
     94 /* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected,
     95                                                      const Shape& actual) {
          // Fatal (ASSERT_TRUE) wrapper around EqualShapes. On a mismatch it
          // records a fatal gtest failure carrying EqualShapes' message and
          // returns from this helper only; callers continue executing. Layouts
          // are not compared.
     96   ASSERT_TRUE(EqualShapes(expected, actual));
     97 }
     98 
     99 /* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts(
    100     const Shape& expected, const Shape& actual) {
          // Stricter than AssertEqualShapes: compares the complete
          // ShortDebugString() text of both shape protos, so any differing
          // field (including the layout) produces a fatal ASSERT failure.
    101   ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
    102 }
    103 
    104 namespace {
    105 
    106 // Return a literal with all arrays of type FromNativeT converted to type
    107 // ToNativeT in the given literal.
    108 template <typename FromNativeT, typename ToNativeT>
    109 std::unique_ptr<Literal> ConvertType(const Literal& literal) {
    110   // First construct shape of the result.
    111   Shape result_shape(literal.shape());
    112   ShapeUtil::ForEachMutableSubshape(
    113       &result_shape, [](Shape* subshape, const ShapeIndex&) {
    114         if (subshape->element_type() ==
    115             primitive_util::NativeToPrimitiveType<FromNativeT>()) {
    116           subshape->set_element_type(
    117               primitive_util::NativeToPrimitiveType<ToNativeT>());
    118         }
    119       });
    120   auto result = MakeUnique<Literal>(result_shape);
    121 
    122   // Then copy over the data from 'literal' converting FromNativeT values to
    123   // ToNativeT values as necessary.
    124   ShapeUtil::ForEachSubshape(
    125       literal.shape(),
    126       [&](const Shape& subshape, const ShapeIndex& shape_index) {
    127         if (ShapeUtil::IsArray(subshape)) {
    128           if (subshape.element_type() ==
    129               primitive_util::NativeToPrimitiveType<FromNativeT>()) {
    130             auto src = literal.data<FromNativeT>(shape_index);
    131             auto dest = result->data<ToNativeT>(shape_index);
    132             for (int64 i = 0; i < src.size(); ++i) {
    133               dest[i] = static_cast<ToNativeT>(src[i]);
    134             }
    135           } else {
    136             TF_CHECK_OK(result->CopyFrom(literal,
    137                                          /*dest_shape_index=*/shape_index,
    138                                          /*src_shape_index=*/shape_index));
    139           }
    140         }
    141       });
    142   return result;
    143 }
    144 
    145 }  // namespace
    146 
    147 /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
    148     const Literal& literal) {
    149   return ConvertType<bfloat16, float>(literal);
    150 }
    151 
    152 /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
    153     const Literal& literal) {
    154   return ConvertType<float, bfloat16>(literal);
    155 }
    156 
    157 namespace {
    158 
    // Returns the name of the current host, or "unknown" if gethostname()
// fails. The buffer is zero-initialized and explicitly NUL-terminated because
// POSIX leaves termination unspecified when the name had to be truncated.
string Hostname() {
  char hostname[1024] = {};
  if (gethostname(hostname, sizeof hostname) != 0) {
    // Don't trust the buffer contents when the call failed.
    return string("unknown");
  }
  hostname[sizeof hostname - 1] = '\0';
  return string(hostname);
}
    165 
    166 // Helper function for comparing a floating point type, FloatT, bitwise equal
    167 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
    168 // -- on miscompare, a nice error message is given in the AssertionFailure.
    169 template <typename FloatT, typename UnsignedT>
    170 ::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
    171   auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
    172   auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
    173   auto lhs_double = static_cast<double>(lhs);
    174   auto rhs_double = static_cast<double>(rhs);
    175   if (ulhs != urhs) {
    176     return ::testing::AssertionFailure() << tensorflow::strings::Printf(
    177                "floating values are not bitwise-equal; and equality testing "
    178                "was requested: %s=%g=%a vs %s=%g=%a",
    179                tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs))
    180                    .c_str(),
    181                lhs_double, lhs_double,
    182                tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs))
    183                    .c_str(),
    184                rhs_double, rhs_double);
    185   }
    186   return ::testing::AssertionSuccess();
    187 }
    188 
    189 // Templated comparator that specializes for float equality comparison with the
    190 // bitwise helper above (this is the un-specialized fallback, to just use the
    191 // default gunit implementation).
    192 template <typename NativeT>
    193 ::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
    194   if (lhs == rhs) {
    195     return ::testing::AssertionSuccess();
    196   }
    197   ::testing::Message msg;
    198   msg << "Expected equality of these values:";
    199   msg << "\n  " << lhs;
    200   msg << "\n  " << rhs;
    201 
    202   return ::testing::AssertionFailure() << msg;
    203 }
    204 
    205 // Specializations for floating types that do bitwise comparisons when equality
    206 // comparison is requested.
    207 template <>
    208 ::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
    209   return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
    210 }
    211 template <>
    212 ::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
    213   return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
    214 }
    215 template <>
    216 ::testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
    217   return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
    218 }
    219 template <>
    220 ::testing::AssertionResult CompareEqual<complex64>(complex64 lhs,
    221                                                    complex64 rhs) {
    222   auto res = CompareEqual<float>(lhs.real(), rhs.real());
    223   if (!res) {
    224     return res;
    225   }
    226   return CompareEqual<float>(lhs.imag(), rhs.imag());
    227 }
    228 
    229 // A recursive function which iterates through every index of expected and
    230 // actual literal and compares their values elementwise. Returns true if all
    231 // elements are equal.
    232 template <typename NativeT>
    233 bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
    234                          tensorflow::gtl::MutableArraySlice<int64> multi_index,
    235                          int64 dimension) {
    236   if (dimension == expected.shape().dimensions_size()) {
    237     NativeT expected_value = expected.Get<NativeT>(multi_index);
    238     NativeT actual_value = actual.Get<NativeT>(multi_index);
    239     ::testing::AssertionResult result =
    240         CompareEqual<NativeT>(expected_value, actual_value);
    241     return result;  // Defines implicit coersion to bool.
    242   }
    243 
    244   bool all_match = true;
    245   for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
    246     multi_index[dimension] = i;
    247     all_match = all_match && ExpectLiteralsEqual<NativeT>(
    248                                  expected, actual, multi_index, dimension + 1);
    249   }
    250   return all_match;
    251 }
    252 
    253 }  // namespace
    254 
    255 /* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected,
    256                                                const Literal& actual,
    257                                                const string& message) {
          // Non-fatal (EXPECT_TRUE) check that Equal(expected, actual) succeeds.
          // On failure the report includes both literals rendered with
          // ToString() and, when `message` is non-empty, the caller's message.
    258   EXPECT_TRUE(Equal(expected, actual))
    259       << "expected:\n"
    260       << expected.ToString() << "\n\tvs actual:\n"
    261       << actual.ToString()
    262       << (message.empty()
    263               ? ""
    264               : tensorflow::strings::StrCat("\nmessage: ", message));
    265 }
    266 
    267 /* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected,
    268                                                   const Literal& actual) {
          // Non-fatal (EXPECT_FALSE) check that Equal(expected, actual) reports
          // a difference between the two literals.
    269   EXPECT_FALSE(Equal(expected, actual));
    270 }
    271 
    272 /* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
    273     const Literal& expected, const Literal& actual) {
    274   VLOG(1) << "expected:";
    275   XLA_VLOG_LINES(1, expected.ToString());
    276   VLOG(1) << "actual:";
    277   XLA_VLOG_LINES(1, actual.ToString());
    278 
    279   AssertEqualShapes(expected.shape(), actual.shape());
    280   std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
    281   bool match = false;
    282   switch (expected.shape().element_type()) {
    283     case PRED:
    284       match = ExpectLiteralsEqual<bool>(expected, actual, &multi_index, 0);
    285       break;
    286     case U8:
    287       match = ExpectLiteralsEqual<uint8>(expected, actual, &multi_index, 0);
    288       break;
    289     case S32:
    290       match = ExpectLiteralsEqual<int32>(expected, actual, &multi_index, 0);
    291       break;
    292     case S64:
    293       match = ExpectLiteralsEqual<int64>(expected, actual, &multi_index, 0);
    294       break;
    295     case U32:
    296       match = ExpectLiteralsEqual<uint32>(expected, actual, &multi_index, 0);
    297       break;
    298     case U64:
    299       match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
    300       break;
    301     case BF16:
    302       match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
    303       break;
    304     case F16:
    305       match = ExpectLiteralsEqual<half>(expected, actual, &multi_index, 0);
    306       break;
    307     case F32:
    308       match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
    309       break;
    310     case F64:
    311       match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
    312       break;
    313     case C64:
    314       match = ExpectLiteralsEqual<complex64>(expected, actual, &multi_index, 0);
    315       break;
    316     case TUPLE: {
    317       bool tuple_match = true;
    318       for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
    319         SCOPED_TRACE(tensorflow::strings::StrCat(
    320             "Tuple index ", i, " in ",
    321             ShapeUtil::HumanString(expected.shape())));
    322 
    323         // Create LiteralViews of the expected and actual elements.
    324         auto result = Equal(LiteralView::Create(expected, {i}),
    325                             LiteralView::Create(actual, {i}));
    326         tuple_match = tuple_match ? !!result : false;
    327       }
    328       match = tuple_match;
    329       break;
    330     }
    331     default:
    332       LOG(FATAL)
    333           << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
    334           << PrimitiveType_Name(expected.shape().element_type());
    335   }
    336   ::testing::AssertionResult result = ::testing::AssertionSuccess();
    337   if (!match) {
    338     result = ::testing::AssertionFailure()
    339              << "expected: " << expected.ToString()
    340              << "\nactual:   " << actual.ToString();
    341     VLOG(1) << result.message();
    342   }
    343   return result;
    344 }
    345 
    346 namespace {
    347 
    348 // Helper class for comparing floating-point literals within an error bound.
    349 class NearComparator {
    350  public:
    351   explicit NearComparator(ErrorSpec error) : error_(error) {}
    352 
    353   // Compares the two literals elementwise. EXPECTs each pair of elements to be
    354   // within the error bound. Emits useful log messages and dumps literals to
    355   // temporary files on failure. Returns true if  literals match.
    356   bool ExpectNear(const Literal& expected, const Literal& actual) {
    357     VLOG(1) << "expected:";
    358     XLA_VLOG_LINES(1, TruncateHugeLiteral(expected));
    359     VLOG(1) << "actual:";
    360     XLA_VLOG_LINES(1, TruncateHugeLiteral(actual));
    361 
    362     // If the shapes mismatch, we simply fail the expectation instead of
    363     // printing out data, as it's a type error rather than a value error.
    364     ::testing::AssertionResult equal_shapes =
    365         LiteralTestUtil::EqualShapes(expected.shape(), actual.shape());
    366     if (!equal_shapes) {
    367       EXPECT_TRUE(equal_shapes);
    368       return false;
    369     }
    370 
    371     // Set up members used during the comparison.
    372     num_miscompares_ = 0;
    373     abs_diff_sum_ = 0.0;
    374     abs_expected_sum_ = 0.0;
    375     abs_diff_miscompare_sum_ = 0.0;
    376     abs_expected_miscompare_sum_ = 0.0;
    377     max_rel_err_ = 0.0;
    378     max_abs_err_ = 0.0;
    379     first_linear_index_ = -1;
    380     last_linear_index_ = -1;
    381     max_rel_linear_index_ = -1;
    382     max_abs_linear_index_ = -1;
    383     miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED));
    384     miscompares_.PopulateWithValue(false);
    385     multi_index_.resize(expected.shape().dimensions_size(), 0);
    386 
    387     switch (expected.shape().element_type()) {
    388       case BF16:
    389         ExpectLiteralsNear<bfloat16>(expected, actual, 0);
    390         break;
    391       case F16:
    392         ExpectLiteralsNear<half>(expected, actual, 0);
    393         break;
    394       case F32:
    395         ExpectLiteralsNear<float>(expected, actual, 0);
    396         break;
    397       case F64:
    398         ExpectLiteralsNear<double>(expected, actual, 0);
    399         break;
    400       case C64:
    401         ExpectLiteralsNear<complex64>(expected, actual, 0);
    402         break;
    403       default:
    404         LOG(FATAL) << "Unsupported primitive type in near comparator: "
    405                    << PrimitiveType_Name(expected.shape().element_type())
    406                    << ". Must be floating-point type.";
    407     }
    408 
    409     if (num_miscompares_ > 0) {
    410       if (!VLOG_IS_ON(1)) {
    411         LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape())
    412                   << " " << TruncateHugeLiteral(expected);
    413         LOG(INFO) << "actual:   " << ShapeUtil::HumanString(actual.shape())
    414                   << " " << TruncateHugeLiteral(actual);
    415         LOG(INFO) << "Dumping literals to temp files...";
    416         WriteLiteralToTempFile(expected, "expected");
    417         WriteLiteralToTempFile(actual, "actual");
    418         WriteLiteralToTempFile(miscompares_, "miscompares");
    419       }
    420       EXPECT_TRUE(num_miscompares_ == 0)
    421           << "\nmax relative mismatch at index "
    422           << LiteralTestUtil::MultiIndexAsString(
    423                  IndexUtil::LinearIndexToMultidimensionalIndex(
    424                      actual.shape(), max_rel_linear_index_))
    425           << "\nmaximum relative error " << max_rel_err_
    426           << "\nmax absolute mismatch at index "
    427           << LiteralTestUtil::MultiIndexAsString(
    428                  IndexUtil::LinearIndexToMultidimensionalIndex(
    429                      actual.shape(), max_abs_linear_index_))
    430           << "\nmaximum absolute error " << max_abs_err_
    431           << "\nfirst mismatch at index "
    432           << LiteralTestUtil::MultiIndexAsString(
    433                  IndexUtil::LinearIndexToMultidimensionalIndex(
    434                      actual.shape(), first_linear_index_))
    435           << "\nlast mismatch at index "
    436           << LiteralTestUtil::MultiIndexAsString(
    437                  IndexUtil::LinearIndexToMultidimensionalIndex(
    438                      actual.shape(), last_linear_index_))
    439           << "\ntotal absolute error " << abs_diff_sum_
    440           << "\ntotal absolute error of miscompares "
    441           << abs_diff_miscompare_sum_ << "\ntotal relative error "
    442           << (abs_diff_sum_ / abs_expected_sum_)
    443           << "\ntotal relative error of miscompares "
    444           << (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_)
    445           << "\nfailure count " << num_miscompares_;
    446     }
    447     return num_miscompares_ == 0;
    448   }
    449 
    450  private:
    451   template <typename NativeT>
    452   bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) {
    453     if (relaxed_nans) {
    454       return !std::isnan(expected) && std::isnan(actual);
    455     } else {
    456       return std::isnan(expected) != std::isnan(actual);
    457     }
    458   }
    459 
    460   template <typename NativeT>
    461   void ExpectNear(NativeT expected, NativeT actual,
    462                   const ::testing::Message& message) {
    463     EXPECT_NEAR(expected, actual, error_.abs)
    464         << "expected:\n  " << expected << "\n\tvs actual:\n  " << actual << "\n"
    465         << message;
    466   }
    467 
    468   // EXPECTs that the two given scalar values are within the error bound. Keeps
    469   // track of how many mismatches have occurred to keep the size of the output
    470   // manageable.
    471   template <typename NativeT>
    472   bool ExpectValuesNear(NativeT expected, NativeT actual) {
    473     if (expected == actual) {
    474       return true;
    475     }
    476 
    477     const float abs_diff = std::abs(actual - expected);
    478     const float rel_err = abs_diff / std::abs(expected);
    479     const bool nan_mismatch =
    480         NanMismatch<NativeT>(expected, actual, error_.relaxed_nans);
    481     const bool mismatch =
    482         (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
    483     return !mismatch;
    484   }
    485 
    486   // Assumes that expected vs actual fail ExpectValuesNear.
    487   template <typename NativeT>
    488   void UpdateAndLogMiscompares(const NativeT expected, const NativeT actual,
    489                                const Shape& shape, const int64 linear_index) {
    490     const float abs_diff = std::abs(actual - expected);
    491     const float rel_err = abs_diff / std::abs(expected);
    492     abs_diff_sum_ += abs_diff;
    493     abs_expected_sum_ += std::abs(expected);
    494     if (rel_err > max_rel_err_ || std::isnan(rel_err)) {
    495       max_rel_err_ = rel_err;
    496       max_rel_linear_index_ = linear_index;
    497     }
    498     if (abs_diff > max_abs_err_ || std::isnan(abs_diff)) {
    499       max_abs_err_ = abs_diff;
    500       max_abs_linear_index_ = linear_index;
    501     }
    502     if (VLOG_IS_ON(10)) {
    503       VLOG(10) << tensorflow::strings::Printf(
    504           "index %s abs_diff %f rel_err %f",
    505           LiteralTestUtil::MultiIndexAsString(
    506               IndexUtil::LinearIndexToMultidimensionalIndex(shape,
    507                                                             linear_index))
    508               .c_str(),
    509           abs_diff, rel_err);
    510     }
    511     abs_diff_miscompare_sum_ += abs_diff;
    512     abs_expected_miscompare_sum_ += std::abs(expected);
    513     const int64 kMaxFailures = 2;
    514     if (num_miscompares_ < kMaxFailures) {
    515       const auto multi_index =
    516           IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index);
    517       ::testing::Message msg;
    518       msg << "mismatch at index "
    519           << LiteralTestUtil::MultiIndexAsString(multi_index) << " abs diff "
    520           << abs_diff << " rel err " << rel_err << " failure #"
    521           << num_miscompares_;
    522       ExpectNear<NativeT>(expected, actual, msg);
    523     } else if (num_miscompares_ == kMaxFailures) {
    524       LOG(ERROR) << "reached max 'loud' failure count; silently proceeding...";
    525     }
    526     if (num_miscompares_ == 0) {
    527       first_linear_index_ = linear_index;
    528     }
    529     num_miscompares_++;
    530     last_linear_index_ = linear_index;
    531     miscompares_.data<bool>()[linear_index] = true;
    532   }
    533 
    534   // Recursive function which compares the two given literals elementwise.
    535   template <typename NativeT>
    536   void ExpectLiteralsNear(const Literal& expected, const Literal& actual,
    537                           int64 dimension) {
    538     // Fast path optimization for the case were layouts match.
    539     if (LayoutUtil::Equal(actual.shape().layout(), expected.shape().layout())) {
    540       tensorflow::gtl::ArraySlice<const NativeT> expected_data =
    541           expected.data<NativeT>();
    542       tensorflow::gtl::ArraySlice<const NativeT> actual_data =
    543           actual.data<NativeT>();
    544       const int64 len = expected_data.size();
    545       for (int64 i = 0; i < len; ++i) {
    546         const bool near = ExpectValuesNear(expected_data[i], actual_data[i]);
    547         if (!near) {
    548           UpdateAndLogMiscompares<NativeT>(expected_data[i], actual_data[i],
    549                                            actual.shape(), i);
    550         }
    551       }
    552       return;
    553     }
    554 
    555     if (dimension == expected.shape().dimensions_size()) {
    556       bool near = ExpectValuesNear(expected.Get<NativeT>(multi_index_),
    557                                    actual.Get<NativeT>(multi_index_));
    558       if (!near) {
    559         UpdateAndLogMiscompares<NativeT>(
    560             expected.Get<NativeT>(multi_index_),
    561             actual.Get<NativeT>(multi_index_), actual.shape(),
    562             IndexUtil::MultidimensionalIndexToLinearIndex(actual.shape(),
    563                                                           multi_index_));
    564       }
    565     } else {
    566       for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
    567         multi_index_[dimension] = i;
    568         ExpectLiteralsNear<NativeT>(expected, actual, dimension + 1);
    569       }
    570     }
    571   }
    572 
    573   // Writes the given literal to a file in the test temporary directory.
    574   void WriteLiteralToTempFile(const Literal& literal, const string& name) {
    575     int64 now_usec = tensorflow::Env::Default()->NowMicros();
    576     string filename = tensorflow::io::JoinPath(
    577         tensorflow::testing::TmpDir(),
    578         tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
    579                                     now_usec, name.c_str()));
    580     TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
    581                                              filename, literal.ToProto()));
    582     LOG(ERROR) << "wrote to " << name << " file: " << filename;
    583   }
    584 
    585   // Gets the total element count.  For tuples, this is not the count of tuple
    586   // elements, but the sum of elements of each tuple element.
    587   int64 RecursiveElementCount(const Shape& shape) {
    588     if (ShapeUtil::IsTuple(shape)) {
    589       const int64 tuple_elements = ShapeUtil::TupleElementCount(shape);
    590       int64 total = 0;
    591       for (int64 i = 0; i < tuple_elements; ++i) {
    592         total +=
    593             RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
    594       }
    595       return total;
    596     } else {
    597       return ShapeUtil::ElementsIn(shape);
    598     }
    599   }
    600 
    601   // Calling ToString on a literal with over 100 million elements takes around
    602   // 3 minutes.  The utility of printing a literal with >1000 elements is
    603   // questionable, especially when writing the Literal proto to disk is orders
    604   // of magnitude faster.
    605   string TruncateHugeLiteral(const Literal& literal) {
    606     return RecursiveElementCount(literal.shape()) < 1000
    607                ? literal.ToString()
    608                : "[TRUNCATED, Literal with more than 1000 values]";
    609   }
    610 
    611   ErrorSpec error_;
    612 
    613   // Number of element miscomparisons encountered so far.
    614   int64 num_miscompares_;
    615 
    616   // A Literal containing which elements did not match in the expected and
    617   // actual literals. miscompares_ contains PREDs and is of the same sizes as
    618   // the comparison literals.
    619   Literal miscompares_;
    620 
    621   // A multidimensional index used when performing the recursive comparison.
    622   std::vector<int64> multi_index_;
    623 
    624   // Aggregated Statistics on input.
    625   double abs_diff_sum_;
    626   double abs_expected_sum_;
    627   double abs_diff_miscompare_sum_;
    628   double abs_expected_miscompare_sum_;
    629   float max_rel_err_;
    630   float max_abs_err_;
    631   int64 first_linear_index_;
    632   int64 last_linear_index_;
    633   int64 max_rel_linear_index_;
    634   int64 max_abs_linear_index_;
    635 };
    636 
    637 template <>
    638 bool NearComparator::NanMismatch<complex64>(complex64 expected,
    639                                             complex64 actual,
    640                                             bool relaxed_nans) {
    641   return NanMismatch(expected.real(), actual.real(), relaxed_nans) ||
    642          NanMismatch(expected.imag(), actual.imag(), relaxed_nans);
    643 }
    644 
    645 template <>
    646 void NearComparator::ExpectNear<complex64>(complex64 expected, complex64 actual,
    647                                            const ::testing::Message& message) {
          // complex64 specialization: EXPECT_NEAR the real and imaginary parts
          // separately against the same absolute bound error_.abs. Each failing
          // component reports both full complex values followed by `message`.
    648   EXPECT_NEAR(expected.real(), actual.real(), error_.abs)
    649       << "expected:\n  " << expected << "\n\tvs actual:\n  " << actual << "\n"
    650       << message;
    651   EXPECT_NEAR(expected.imag(), actual.imag(), error_.abs)
    652       << "expected:\n  " << expected << "\n\tvs actual:\n  " << actual << "\n"
    653       << message;
    654 }
    655 
    656 template <>
    657 bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
    658                                                 bfloat16 actual) {
    659   return ExpectValuesNear(static_cast<float>(expected),
    660                           static_cast<float>(actual));
    661 }
    662 
    663 template <>
    664 bool NearComparator::ExpectValuesNear<half>(half expected, half actual) {
    665   return ExpectValuesNear(static_cast<float>(std::move(expected)),
    666                           static_cast<float>(std::move(actual)));
    667 }
    668 
    669 template <>
    670 void NearComparator::UpdateAndLogMiscompares<bfloat16>(
    671     const bfloat16 expected, const bfloat16 actual, const Shape& shape,
    672     const int64 linear_index) {
    673   UpdateAndLogMiscompares(static_cast<float>(expected),
    674                           static_cast<float>(actual), shape, linear_index);
    675 }
    676 
    677 template <>
    678 void NearComparator::UpdateAndLogMiscompares<half>(half expected, half actual,
    679                                                    const Shape& shape,
    680                                                    const int64 linear_index) {
    681   UpdateAndLogMiscompares(static_cast<float>(std::move(expected)),
    682                           static_cast<float>(std::move(actual)), shape,
    683                           linear_index);
    684 }
    685 
    686 }  // namespace
    687 
    688 /* static */ ::testing::AssertionResult LiteralTestUtil::Near(
    689     const Literal& expected, const Literal& actual, const ErrorSpec& error) {
    690   ::testing::AssertionResult err =
    691       EqualShapes(expected.shape(), actual.shape());
    692   if (!err) {
    693     return err;
    694   }
    695 
    696   if (ShapeUtil::IsTuple(expected.shape())) {
    697     for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
    698       SCOPED_TRACE(tensorflow::strings::StrCat(
    699           "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape())));
    700       const auto expected_element = LiteralView::Create(expected, {i});
    701       const auto actual_element = LiteralView::Create(actual, {i});
    702 
    703       ::testing::AssertionResult res =
    704           Near(expected_element, actual_element, error);
    705       if (err && !res) {
    706         err = res;
    707       }
    708     }
    709     return err;
    710   }
    711 
    712   if (ShapeUtil::ElementIsFloating(expected.shape()) ||
    713       ShapeUtil::ElementIsComplex(expected.shape())) {
    714     NearComparator comparator(error);
    715     return comparator.ExpectNear(expected, actual)
    716                ? ::testing::AssertionSuccess()
    717                : ::testing::AssertionFailure() << "values were not near";
    718   }
    719 
    720   return Equal(expected, actual);
    721 }
    722 
    723 /* static */ void LiteralTestUtil::ExpectNear(const Literal& expected,
    724                                               const Literal& actual,
    725                                               const ErrorSpec& error,
    726                                               const string& message) {
    727   EXPECT_TRUE(Near(expected, actual, error))
    728       << (message.empty()
    729               ? ""
    730               : tensorflow::strings::StrCat("\nmessage: ", message));
    731 }
    732 
    733 /*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
    734     const Literal& expected, const Literal& actual,
    735     const tensorflow::gtl::optional<ErrorSpec>& error) {
    736   if (error.has_value()) {
    737     VLOG(1) << "Expects near";
    738     return Near(expected, actual, *error);
    739   }
    740   VLOG(1) << "Expects equal";
    741   return Equal(expected, actual);
    742 }
    743 
// Non-fatal gtest expectation wrapping NearOrEqual: exact equality when
// `error` is empty, approximate comparison with `*error` otherwise.
/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
    const Literal& expected, const Literal& actual,
    const tensorflow::gtl::optional<ErrorSpec>& error) {
  // The AssertionResult's message (from NearOrEqual) is included in the
  // EXPECT_TRUE failure output.
  EXPECT_TRUE(NearOrEqual(expected, actual, error));
}
    749 
    750 /* static */ string LiteralTestUtil::MultiIndexAsString(
    751     tensorflow::gtl::ArraySlice<int64> multi_index) {
    752   return tensorflow::strings::StrCat(
    753       "{", tensorflow::str_util::Join(multi_index, ","), "}");
    754 }
    755 
    756 /* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
    757     tensorflow::gtl::ArraySlice<int64> new_dimensions,
    758     tensorflow::gtl::ArraySlice<int64> minor_to_major, const Literal& literal) {
    759   int64 new_num_elements = 1;
    760   for (int64 i = 0; i < new_dimensions.size(); ++i) {
    761     new_num_elements *= new_dimensions[i];
    762   }
    763   CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
    764   CHECK_EQ(new_dimensions.size(), minor_to_major.size());
    765 
    766   auto new_literal = MakeUnique<Literal>(
    767       ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
    768 
    769   // Create a new shape with the given minor-to-major layout. This shape is used
    770   // solely for converting linear address to multi-dimensional addresses when
    771   // writing elements to the new literal.
    772   Shape shape_with_layout = new_literal->shape();
    773   *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
    774 
    775   // Copy data into new literal, element-by-element.
    776   for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
    777     std::vector<int64> from_multi_index =
    778         IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
    779     std::vector<int64> to_multi_index =
    780         IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
    781     switch (literal.shape().element_type()) {
    782       case PRED:
    783         new_literal->Set<bool>(to_multi_index,
    784                                literal.Get<bool>(from_multi_index));
    785         break;
    786       case U8:
    787         new_literal->Set<uint8>(to_multi_index,
    788                                 literal.Get<uint8>(from_multi_index));
    789         break;
    790       case U32:
    791         new_literal->Set<uint32>(to_multi_index,
    792                                  literal.Get<uint32>(from_multi_index));
    793         break;
    794       case S32:
    795         new_literal->Set<int32>(to_multi_index,
    796                                 literal.Get<int32>(from_multi_index));
    797         break;
    798       case U64:
    799         new_literal->Set<uint64>(to_multi_index,
    800                                  literal.Get<uint64>(from_multi_index));
    801         break;
    802       case S64:
    803         new_literal->Set<int64>(to_multi_index,
    804                                 literal.Get<int64>(from_multi_index));
    805         break;
    806       case F32:
    807         new_literal->Set<float>(to_multi_index,
    808                                 literal.Get<float>(from_multi_index));
    809         break;
    810       case F64:
    811         new_literal->Set<double>(to_multi_index,
    812                                  literal.Get<double>(from_multi_index));
    813         break;
    814       case C64:
    815         new_literal->Set<complex64>(to_multi_index,
    816                                     literal.Get<complex64>(from_multi_index));
    817         break;
    818       default:
    819         LOG(FATAL) << "Unhandled primitive element type: "
    820                    << PrimitiveType_Name(literal.shape().element_type());
    821     }
    822   }
    823 
    824   return new_literal;
    825 }
    826 
    827 }  // namespace xla
    828