Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
     17 #define TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
     18 
     19 #include <initializer_list>
     20 #include <memory>
     21 #include <random>
     22 #include <string>
     23 
     24 #include "absl/types/optional.h"
     25 #include "absl/types/span.h"
     26 #include "tensorflow/compiler/xla/array2d.h"
     27 #include "tensorflow/compiler/xla/array3d.h"
     28 #include "tensorflow/compiler/xla/array4d.h"
     29 #include "tensorflow/compiler/xla/error_spec.h"
     30 #include "tensorflow/compiler/xla/literal.h"
     31 #include "tensorflow/compiler/xla/literal_util.h"
     32 #include "tensorflow/compiler/xla/test.h"
     33 #include "tensorflow/compiler/xla/test_helpers.h"
     34 #include "tensorflow/compiler/xla/types.h"
     35 #include "tensorflow/compiler/xla/xla_data.pb.h"
     36 #include "tensorflow/core/lib/core/errors.h"
     37 #include "tensorflow/core/platform/macros.h"
     38 #include "tensorflow/core/platform/test.h"
     39 #include "tensorflow/core/platform/types.h"
     40 
     41 namespace xla {
     42 
     43 // Utility class for making expectations/assertions related to XLA literals.
     44 class LiteralTestUtil {
     45  public:
     46   // Asserts that the given shapes have the same rank, dimension sizes, and
     47   // primitive types.
     48   static ::testing::AssertionResult EqualShapes(
     49       const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT;
     50 
     51   // Asserts that the provided shapes are equal as defined in AssertEqualShapes
     52   // and that they have the same layout.
     53   static ::testing::AssertionResult EqualShapesAndLayouts(
     54       const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT;
     55 
     56   static ::testing::AssertionResult Equal(const LiteralSlice& expected,
     57                                           const LiteralSlice& actual)
     58       TF_MUST_USE_RESULT;
     59 
     60   // Asserts the given literal are (bitwise) equal to given expected values.
     61   template <typename NativeT>
     62   static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
     63 
     64   template <typename NativeT>
     65   static void ExpectR1Equal(absl::Span<const NativeT> expected,
     66                             const LiteralSlice& actual);
     67   template <typename NativeT>
     68   static void ExpectR2Equal(
     69       std::initializer_list<std::initializer_list<NativeT>> expected,
     70       const LiteralSlice& actual);
     71 
     72   template <typename NativeT>
     73   static void ExpectR3Equal(
     74       std::initializer_list<
     75           std::initializer_list<std::initializer_list<NativeT>>>
     76           expected,
     77       const LiteralSlice& actual);
     78 
     79   // Asserts the given literal are (bitwise) equal to given array.
     80   template <typename NativeT>
     81   static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
     82                                    const LiteralSlice& actual);
     83   template <typename NativeT>
     84   static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
     85                                    const LiteralSlice& actual);
     86   template <typename NativeT>
     87   static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
     88                                    const LiteralSlice& actual);
     89 
     90   // Decorates literal_comparison::Near() with an AssertionResult return type.
     91   //
     92   // See comment on literal_comparison::Near().
     93   static ::testing::AssertionResult Near(
     94       const LiteralSlice& expected, const LiteralSlice& actual,
     95       const ErrorSpec& error_spec,
     96       absl::optional<bool> detailed_message = absl::nullopt) TF_MUST_USE_RESULT;
     97 
     98   // Asserts the given literal are within the given error bound of the given
     99   // expected values. Only supported for floating point values.
    100   template <typename NativeT>
    101   static void ExpectR0Near(NativeT expected, const LiteralSlice& actual,
    102                            const ErrorSpec& error);
    103 
    104   template <typename NativeT>
    105   static void ExpectR1Near(absl::Span<const NativeT> expected,
    106                            const LiteralSlice& actual, const ErrorSpec& error);
    107 
    108   template <typename NativeT>
    109   static void ExpectR2Near(
    110       std::initializer_list<std::initializer_list<NativeT>> expected,
    111       const LiteralSlice& actual, const ErrorSpec& error);
    112 
    113   template <typename NativeT>
    114   static void ExpectR3Near(
    115       std::initializer_list<
    116           std::initializer_list<std::initializer_list<NativeT>>>
    117           expected,
    118       const LiteralSlice& actual, const ErrorSpec& error);
    119 
    120   template <typename NativeT>
    121   static void ExpectR4Near(
    122       std::initializer_list<std::initializer_list<
    123           std::initializer_list<std::initializer_list<NativeT>>>>
    124           expected,
    125       const LiteralSlice& actual, const ErrorSpec& error);
    126 
    127   // Asserts the given literal are within the given error bound to the given
    128   // array. Only supported for floating point values.
    129   template <typename NativeT>
    130   static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
    131                                   const LiteralSlice& actual,
    132                                   const ErrorSpec& error);
    133 
    134   template <typename NativeT>
    135   static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
    136                                   const LiteralSlice& actual,
    137                                   const ErrorSpec& error);
    138 
    139   template <typename NativeT>
    140   static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
    141                                   const LiteralSlice& actual,
    142                                   const ErrorSpec& error);
    143 
    144   // If the error spec is given, returns whether the expected and the actual are
    145   // within the error bound; otherwise, returns whether they are equal. Tuples
    146   // will be compared recursively.
    147   static ::testing::AssertionResult NearOrEqual(
    148       const LiteralSlice& expected, const LiteralSlice& actual,
    149       const absl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
    150 
    151  private:
    152   TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
    153 };
    154 
    155 template <typename NativeT>
    156 /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
    157                                                  const LiteralSlice& actual) {
    158   EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual));
    159 }
    160 
    161 template <typename NativeT>
    162 /* static */ void LiteralTestUtil::ExpectR1Equal(
    163     absl::Span<const NativeT> expected, const LiteralSlice& actual) {
    164   EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual));
    165 }
    166 
    167 template <typename NativeT>
    168 /* static */ void LiteralTestUtil::ExpectR2Equal(
    169     std::initializer_list<std::initializer_list<NativeT>> expected,
    170     const LiteralSlice& actual) {
    171   EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual));
    172 }
    173 
    174 template <typename NativeT>
    175 /* static */ void LiteralTestUtil::ExpectR3Equal(
    176     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
    177         expected,
    178     const LiteralSlice& actual) {
    179   EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual));
    180 }
    181 
    182 template <typename NativeT>
    183 /* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
    184     const Array2D<NativeT>& expected, const LiteralSlice& actual) {
    185   EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual));
    186 }
    187 
    188 template <typename NativeT>
    189 /* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
    190     const Array3D<NativeT>& expected, const LiteralSlice& actual) {
    191   EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual));
    192 }
    193 
    194 template <typename NativeT>
    195 /* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
    196     const Array4D<NativeT>& expected, const LiteralSlice& actual) {
    197   EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual));
    198 }
    199 
    200 template <typename NativeT>
    201 /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
    202                                                 const LiteralSlice& actual,
    203                                                 const ErrorSpec& error) {
    204   EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error));
    205 }
    206 
    207 template <typename NativeT>
    208 /* static */ void LiteralTestUtil::ExpectR1Near(
    209     absl::Span<const NativeT> expected, const LiteralSlice& actual,
    210     const ErrorSpec& error) {
    211   EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error));
    212 }
    213 
    214 template <typename NativeT>
    215 /* static */ void LiteralTestUtil::ExpectR2Near(
    216     std::initializer_list<std::initializer_list<NativeT>> expected,
    217     const LiteralSlice& actual, const ErrorSpec& error) {
    218   EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error));
    219 }
    220 
    221 template <typename NativeT>
    222 /* static */ void LiteralTestUtil::ExpectR3Near(
    223     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
    224         expected,
    225     const LiteralSlice& actual, const ErrorSpec& error) {
    226   EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error));
    227 }
    228 
    229 template <typename NativeT>
    230 /* static */ void LiteralTestUtil::ExpectR4Near(
    231     std::initializer_list<std::initializer_list<
    232         std::initializer_list<std::initializer_list<NativeT>>>>
    233         expected,
    234     const LiteralSlice& actual, const ErrorSpec& error) {
    235   EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error));
    236 }
    237 
    238 template <typename NativeT>
    239 /* static */ void LiteralTestUtil::ExpectR2NearArray2D(
    240     const Array2D<NativeT>& expected, const LiteralSlice& actual,
    241     const ErrorSpec& error) {
    242   EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error));
    243 }
    244 
    245 template <typename NativeT>
    246 /* static */ void LiteralTestUtil::ExpectR3NearArray3D(
    247     const Array3D<NativeT>& expected, const LiteralSlice& actual,
    248     const ErrorSpec& error) {
    249   EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error));
    250 }
    251 
    252 template <typename NativeT>
    253 /* static */ void LiteralTestUtil::ExpectR4NearArray4D(
    254     const Array4D<NativeT>& expected, const LiteralSlice& actual,
    255     const ErrorSpec& error) {
    256   EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error));
    257 }
    258 
    259 }  // namespace xla
    260 
    261 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
    262