1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ 17 #define TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ 18 19 #include <initializer_list> 20 #include <memory> 21 #include <random> 22 #include <string> 23 24 #include "absl/types/optional.h" 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/array2d.h" 27 #include "tensorflow/compiler/xla/array3d.h" 28 #include "tensorflow/compiler/xla/array4d.h" 29 #include "tensorflow/compiler/xla/error_spec.h" 30 #include "tensorflow/compiler/xla/literal.h" 31 #include "tensorflow/compiler/xla/literal_util.h" 32 #include "tensorflow/compiler/xla/test.h" 33 #include "tensorflow/compiler/xla/test_helpers.h" 34 #include "tensorflow/compiler/xla/types.h" 35 #include "tensorflow/compiler/xla/xla_data.pb.h" 36 #include "tensorflow/core/lib/core/errors.h" 37 #include "tensorflow/core/platform/macros.h" 38 #include "tensorflow/core/platform/test.h" 39 #include "tensorflow/core/platform/types.h" 40 41 namespace xla { 42 43 // Utility class for making expectations/assertions related to XLA literals. 44 class LiteralTestUtil { 45 public: 46 // Asserts that the given shapes have the same rank, dimension sizes, and 47 // primitive types. 48 static ::testing::AssertionResult EqualShapes( 49 const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; 50 51 // Asserts that the provided shapes are equal as defined in AssertEqualShapes 52 // and that they have the same layout. 53 static ::testing::AssertionResult EqualShapesAndLayouts( 54 const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; 55 56 static ::testing::AssertionResult Equal(const LiteralSlice& expected, 57 const LiteralSlice& actual) 58 TF_MUST_USE_RESULT; 59 60 // Asserts the given literal are (bitwise) equal to given expected values. 61 template <typename NativeT> 62 static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual); 63 64 template <typename NativeT> 65 static void ExpectR1Equal(absl::Span<const NativeT> expected, 66 const LiteralSlice& actual); 67 template <typename NativeT> 68 static void ExpectR2Equal( 69 std::initializer_list<std::initializer_list<NativeT>> expected, 70 const LiteralSlice& actual); 71 72 template <typename NativeT> 73 static void ExpectR3Equal( 74 std::initializer_list< 75 std::initializer_list<std::initializer_list<NativeT>>> 76 expected, 77 const LiteralSlice& actual); 78 79 // Asserts the given literal are (bitwise) equal to given array. 80 template <typename NativeT> 81 static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected, 82 const LiteralSlice& actual); 83 template <typename NativeT> 84 static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected, 85 const LiteralSlice& actual); 86 template <typename NativeT> 87 static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected, 88 const LiteralSlice& actual); 89 90 // Decorates literal_comparison::Near() with an AssertionResult return type. 91 // 92 // See comment on literal_comparison::Near(). 93 static ::testing::AssertionResult Near( 94 const LiteralSlice& expected, const LiteralSlice& actual, 95 const ErrorSpec& error_spec, 96 absl::optional<bool> detailed_message = absl::nullopt) TF_MUST_USE_RESULT; 97 98 // Asserts the given literal are within the given error bound of the given 99 // expected values. Only supported for floating point values. 100 template <typename NativeT> 101 static void ExpectR0Near(NativeT expected, const LiteralSlice& actual, 102 const ErrorSpec& error); 103 104 template <typename NativeT> 105 static void ExpectR1Near(absl::Span<const NativeT> expected, 106 const LiteralSlice& actual, const ErrorSpec& error); 107 108 template <typename NativeT> 109 static void ExpectR2Near( 110 std::initializer_list<std::initializer_list<NativeT>> expected, 111 const LiteralSlice& actual, const ErrorSpec& error); 112 113 template <typename NativeT> 114 static void ExpectR3Near( 115 std::initializer_list< 116 std::initializer_list<std::initializer_list<NativeT>>> 117 expected, 118 const LiteralSlice& actual, const ErrorSpec& error); 119 120 template <typename NativeT> 121 static void ExpectR4Near( 122 std::initializer_list<std::initializer_list< 123 std::initializer_list<std::initializer_list<NativeT>>>> 124 expected, 125 const LiteralSlice& actual, const ErrorSpec& error); 126 127 // Asserts the given literal are within the given error bound to the given 128 // array. Only supported for floating point values. 129 template <typename NativeT> 130 static void ExpectR2NearArray2D(const Array2D<NativeT>& expected, 131 const LiteralSlice& actual, 132 const ErrorSpec& error); 133 134 template <typename NativeT> 135 static void ExpectR3NearArray3D(const Array3D<NativeT>& expected, 136 const LiteralSlice& actual, 137 const ErrorSpec& error); 138 139 template <typename NativeT> 140 static void ExpectR4NearArray4D(const Array4D<NativeT>& expected, 141 const LiteralSlice& actual, 142 const ErrorSpec& error); 143 144 // If the error spec is given, returns whether the expected and the actual are 145 // within the error bound; otherwise, returns whether they are equal. Tuples 146 // will be compared recursively. 147 static ::testing::AssertionResult NearOrEqual( 148 const LiteralSlice& expected, const LiteralSlice& actual, 149 const absl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT; 150 151 private: 152 TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); 153 }; 154 155 template <typename NativeT> 156 /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, 157 const LiteralSlice& actual) { 158 EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual)); 159 } 160 161 template <typename NativeT> 162 /* static */ void LiteralTestUtil::ExpectR1Equal( 163 absl::Span<const NativeT> expected, const LiteralSlice& actual) { 164 EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual)); 165 } 166 167 template <typename NativeT> 168 /* static */ void LiteralTestUtil::ExpectR2Equal( 169 std::initializer_list<std::initializer_list<NativeT>> expected, 170 const LiteralSlice& actual) { 171 EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual)); 172 } 173 174 template <typename NativeT> 175 /* static */ void LiteralTestUtil::ExpectR3Equal( 176 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> 177 expected, 178 const LiteralSlice& actual) { 179 EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual)); 180 } 181 182 template <typename NativeT> 183 /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( 184 const Array2D<NativeT>& expected, const LiteralSlice& actual) { 185 EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual)); 186 } 187 188 template <typename NativeT> 189 /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( 190 const Array3D<NativeT>& expected, const LiteralSlice& actual) { 191 EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual)); 192 } 193 194 template <typename NativeT> 195 /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( 196 const Array4D<NativeT>& expected, const LiteralSlice& actual) { 197 EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual)); 198 } 199 200 template <typename NativeT> 201 /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, 202 const LiteralSlice& actual, 203 const ErrorSpec& error) { 204 EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error)); 205 } 206 207 template <typename NativeT> 208 /* static */ void LiteralTestUtil::ExpectR1Near( 209 absl::Span<const NativeT> expected, const LiteralSlice& actual, 210 const ErrorSpec& error) { 211 EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error)); 212 } 213 214 template <typename NativeT> 215 /* static */ void LiteralTestUtil::ExpectR2Near( 216 std::initializer_list<std::initializer_list<NativeT>> expected, 217 const LiteralSlice& actual, const ErrorSpec& error) { 218 EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error)); 219 } 220 221 template <typename NativeT> 222 /* static */ void LiteralTestUtil::ExpectR3Near( 223 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> 224 expected, 225 const LiteralSlice& actual, const ErrorSpec& error) { 226 EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error)); 227 } 228 229 template <typename NativeT> 230 /* static */ void LiteralTestUtil::ExpectR4Near( 231 std::initializer_list<std::initializer_list< 232 std::initializer_list<std::initializer_list<NativeT>>>> 233 expected, 234 const LiteralSlice& actual, const ErrorSpec& error) { 235 EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error)); 236 } 237 238 template <typename NativeT> 239 /* static */ void LiteralTestUtil::ExpectR2NearArray2D( 240 const Array2D<NativeT>& expected, const LiteralSlice& actual, 241 const ErrorSpec& error) { 242 EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error)); 243 } 244 245 template <typename NativeT> 246 /* static */ void LiteralTestUtil::ExpectR3NearArray3D( 247 const Array3D<NativeT>& expected, const LiteralSlice& actual, 248 const ErrorSpec& error) { 249 EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error)); 250 } 251 252 template <typename NativeT> 253 /* static */ void LiteralTestUtil::ExpectR4NearArray4D( 254 const Array4D<NativeT>& expected, const LiteralSlice& actual, 255 const ErrorSpec& error) { 256 EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error)); 257 } 258 259 } // namespace xla 260 261 #endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ 262