1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ 17 #define TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ 18 19 #include <initializer_list> 20 #include <memory> 21 #include <random> 22 #include <string> 23 24 #include "tensorflow/compiler/xla/array2d.h" 25 #include "tensorflow/compiler/xla/array3d.h" 26 #include "tensorflow/compiler/xla/array4d.h" 27 #include "tensorflow/compiler/xla/literal_util.h" 28 #include "tensorflow/compiler/xla/test.h" 29 #include "tensorflow/compiler/xla/test_helpers.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/compiler/xla/xla_data.pb.h" 32 #include "tensorflow/core/lib/core/errors.h" 33 #include "tensorflow/core/lib/gtl/array_slice.h" 34 #include "tensorflow/core/lib/gtl/optional.h" 35 #include "tensorflow/core/platform/macros.h" 36 #include "tensorflow/core/platform/test.h" 37 #include "tensorflow/core/platform/types.h" 38 39 namespace xla { 40 41 // Structure describing permissible absolute and relative error bounds. 42 struct ErrorSpec { 43 explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) 44 : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} 45 46 float abs; // Absolute error bound. 47 float rel; // Relative error bound. 48 49 // If relaxed_nans is true then any result is valid if we are expecting NaNs. 50 // In effect, this allows the tested operation to produce incorrect results 51 // for inputs outside its mathematical domain. 52 bool relaxed_nans; 53 }; 54 55 // Utility class for making expectations/assertions related to XLA literals. 56 class LiteralTestUtil { 57 public: 58 // Asserts that the given shapes have the same rank, dimension sizes, and 59 // primitive types. 60 static ::testing::AssertionResult EqualShapes(const Shape& expected, 61 const Shape& actual); 62 static void AssertEqualShapes(const Shape& expected, const Shape& actual); 63 64 // Asserts that the provided shapes are equal as defined in AssertEqualShapes 65 // and that they have the same layout. 66 static void AssertEqualShapesAndLayouts(const Shape& expected, 67 const Shape& actual); 68 69 // If the given literal's data type is bfloat16, converts it to a float 70 // literal; otherwise, returns a copy of it. If the literal is a tuple, 71 // recursively converts its elements. 72 static std::unique_ptr<Literal> ConvertBF16ToF32(const Literal& bf16_literal); 73 74 // If the given literal's data type is float, converts it to a bfloat16 75 // literal; otherwise, returns a copy of it. If the literal is a tuple, 76 // recursively converts its elements. 77 static std::unique_ptr<Literal> ConvertF32ToBF16(const Literal& f32_literal); 78 79 // Asserts that the expected and actual literals are (bitwise) equal for all 80 // elements in the literal. Also, asserts that the rank, dimensions sizes, and 81 // primitive type are equal. 82 static ::testing::AssertionResult Equal( 83 const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; 84 85 // Expects that expected and actual are Equal. 86 static void ExpectEqual(const Literal& expected, const Literal& actual, 87 const string& message = ""); 88 89 // Expects that expected and actual are Not Equal. 90 static void ExpectNotEqual(const Literal& expected, const Literal& actual); 91 92 // Asserts the given literal are (bitwise) equal to given expected values. 93 template <typename NativeT> 94 static void ExpectR0Equal(NativeT expected, const Literal& actual); 95 template <typename NativeT> 96 static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected, 97 const Literal& actual); 98 template <typename NativeT> 99 static void ExpectR2Equal( 100 std::initializer_list<std::initializer_list<NativeT>> expected, 101 const Literal& actual); 102 template <typename NativeT> 103 static void ExpectR3Equal( 104 std::initializer_list< 105 std::initializer_list<std::initializer_list<NativeT>>> 106 expected, 107 const Literal& actual); 108 109 // Asserts the given literal are (bitwise) equal to given array. 110 template <typename NativeT> 111 static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected, 112 const Literal& actual); 113 template <typename NativeT> 114 static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected, 115 const Literal& actual); 116 template <typename NativeT> 117 static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected, 118 const Literal& actual); 119 120 // Asserts that the expected and actual literals are within the given error 121 // bound for all elements. Also, asserts that the rank, dimensions sizes, and 122 // bounds are equivalent. 123 // 124 // Tuples are matched recursively. When comparing tensors of 125 // non-floating-point type, checks for exact equality, ignoring the ErroSpec. 126 // 127 // If the shape of the literals is neither a complex/floating-point tensor nor 128 // a tuple which contains a complex/floating-point tensor, Near() is 129 // equivalent to Equal(). We don't raise an error in this case, because we 130 // want to allow callers to call Near() even if they have no preconceptions 131 // about the shapes being compared. 132 static ::testing::AssertionResult Near( 133 const Literal& expected, const Literal& actual, 134 const ErrorSpec& error) TF_MUST_USE_RESULT; 135 136 // Expects expected and actual to be Near with the given error. 137 static void ExpectNear(const Literal& expected, const Literal& actual, 138 const ErrorSpec& error, const string& message = ""); 139 140 // Asserts the given literal are within the given error bound of the given 141 // expected values. Only supported for floating point values. 142 template <typename NativeT> 143 static void ExpectR0Near(NativeT expected, const Literal& actual, 144 const ErrorSpec& error); 145 template <typename NativeT> 146 static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected, 147 const Literal& actual, const ErrorSpec& error); 148 template <typename NativeT> 149 static void ExpectR2Near( 150 std::initializer_list<std::initializer_list<NativeT>> expected, 151 const Literal& actual, const ErrorSpec& error); 152 template <typename NativeT> 153 static void ExpectR3Near( 154 std::initializer_list< 155 std::initializer_list<std::initializer_list<NativeT>>> 156 expected, 157 const Literal& actual, const ErrorSpec& error); 158 template <typename NativeT> 159 static void ExpectR4Near( 160 std::initializer_list<std::initializer_list< 161 std::initializer_list<std::initializer_list<NativeT>>>> 162 expected, 163 const Literal& actual, const ErrorSpec& error); 164 165 // Asserts the given literal are within the given error bound to the given 166 // array. Only supported for floating point values. 167 template <typename NativeT> 168 static void ExpectR2NearArray2D(const Array2D<NativeT>& expected, 169 const Literal& actual, 170 const ErrorSpec& error); 171 template <typename NativeT> 172 static void ExpectR3NearArray3D(const Array3D<NativeT>& expected, 173 const Literal& actual, 174 const ErrorSpec& error); 175 template <typename NativeT> 176 static void ExpectR4NearArray4D(const Array4D<NativeT>& expected, 177 const Literal& actual, 178 const ErrorSpec& error); 179 180 // If the error spec is given, returns whether the expected and the actual are 181 // within the error bound; otherwise, returns whether they are equal. Tuples 182 // will be compared recursively. 183 static ::testing::AssertionResult NearOrEqual( 184 const Literal& expected, const Literal& actual, 185 const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT; 186 187 // If the error spec is given, expects the expected and the actual to be near; 188 // otherwise, expects them to be equal. Tuples will be compared recursively. 189 static void ExpectNearOrEqual( 190 const Literal& expected, const Literal& actual, 191 const tensorflow::gtl::optional<ErrorSpec>& error); 192 193 // Returns a multi-dimensional index as a string. For example: '{7, 8}' will 194 // be returned for a 2-dimensional index with dimension 0 index equal to 7, 195 // dimension 1 equal to 8. 196 static string MultiIndexAsString( 197 tensorflow::gtl::ArraySlice<int64> multi_index); 198 199 // Creates a literal with a new shape with the given new dimensions using the 200 // data in the given input literal. For reshaping purposes the (flat) data 201 // buffer of the input literal is assumed to have the given minor_to_major 202 // layout order. 203 static std::unique_ptr<Literal> Reshape( 204 tensorflow::gtl::ArraySlice<int64> new_dimensions, 205 tensorflow::gtl::ArraySlice<int64> minor_to_major, 206 const Literal& literal); 207 208 // Creates a literal with the supplied shape, and uses the provided value 209 // generator to populate the literal's values. 210 // Returns the new literal object, or an error Status if failed. 211 template < 212 PrimitiveType type, 213 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> 214 static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral( 215 const Shape& shape, 216 const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator); 217 218 // Creates a literal with the supplied shape, and initializes the literal 219 // values using a normal distribution with given mean and stddev standard 220 // deviation, and using the engine as entropy generator. 221 // Returns the new literal object, or an error Status if failed. 222 template < 223 PrimitiveType type, typename E, 224 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> 225 static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral( 226 const Shape& shape, E* engine, T mean, T stddev); 227 228 // Creates a literal with the supplied shape, and initializes the literal 229 // values using a normal distribution with given mean and stddev standard 230 // deviation. 231 // Returns the new literal object, or an error Status if failed. 232 template < 233 PrimitiveType type, 234 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> 235 static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral( 236 const Shape& shape, T mean, T stddev); 237 238 private: 239 TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); 240 }; 241 242 template <typename NativeT> 243 /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, 244 const Literal& actual) { 245 ExpectEqual(*Literal::CreateR0<NativeT>(expected), actual); 246 } 247 248 template <typename NativeT> 249 /* static */ void LiteralTestUtil::ExpectR1Equal( 250 tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual) { 251 ExpectEqual(*Literal::CreateR1<NativeT>(expected), actual); 252 } 253 254 template <typename NativeT> 255 /* static */ void LiteralTestUtil::ExpectR2Equal( 256 std::initializer_list<std::initializer_list<NativeT>> expected, 257 const Literal& actual) { 258 ExpectEqual(*Literal::CreateR2<NativeT>(expected), actual); 259 } 260 261 template <typename NativeT> 262 /* static */ void LiteralTestUtil::ExpectR3Equal( 263 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> 264 expected, 265 const Literal& actual) { 266 ExpectEqual(*Literal::CreateR3<NativeT>(expected), actual); 267 } 268 269 template <typename NativeT> 270 /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( 271 const Array2D<NativeT>& expected, const Literal& actual) { 272 ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); 273 } 274 275 template <typename NativeT> 276 /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( 277 const Array3D<NativeT>& expected, const Literal& actual) { 278 ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); 279 } 280 281 template <typename NativeT> 282 /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( 283 const Array4D<NativeT>& expected, const Literal& actual) { 284 ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); 285 } 286 287 template <typename NativeT> 288 /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, 289 const Literal& actual, 290 const ErrorSpec& error) { 291 ExpectNear(*Literal::CreateR0<NativeT>(expected), actual, error); 292 } 293 294 template <typename NativeT> 295 /* static */ void LiteralTestUtil::ExpectR1Near( 296 tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual, 297 const ErrorSpec& error) { 298 ExpectNear(*Literal::CreateR1<NativeT>(expected), actual, error); 299 } 300 301 template <typename NativeT> 302 /* static */ void LiteralTestUtil::ExpectR2Near( 303 std::initializer_list<std::initializer_list<NativeT>> expected, 304 const Literal& actual, const ErrorSpec& error) { 305 ExpectNear(*Literal::CreateR2<NativeT>(expected), actual, error); 306 } 307 308 template <typename NativeT> 309 /* static */ void LiteralTestUtil::ExpectR3Near( 310 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> 311 expected, 312 const Literal& actual, const ErrorSpec& error) { 313 ExpectNear(*Literal::CreateR3<NativeT>(expected), actual, error); 314 } 315 316 template <typename NativeT> 317 /* static */ void LiteralTestUtil::ExpectR4Near( 318 std::initializer_list<std::initializer_list< 319 std::initializer_list<std::initializer_list<NativeT>>>> 320 expected, 321 const Literal& actual, const ErrorSpec& error) { 322 ExpectNear(*Literal::CreateR4<NativeT>(expected), actual, error); 323 } 324 325 template <typename NativeT> 326 /* static */ void LiteralTestUtil::ExpectR2NearArray2D( 327 const Array2D<NativeT>& expected, const Literal& actual, 328 const ErrorSpec& error) { 329 ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); 330 } 331 332 template <typename NativeT> 333 /* static */ void LiteralTestUtil::ExpectR3NearArray3D( 334 const Array3D<NativeT>& expected, const Literal& actual, 335 const ErrorSpec& error) { 336 ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); 337 } 338 339 template <typename NativeT> 340 /* static */ void LiteralTestUtil::ExpectR4NearArray4D( 341 const Array4D<NativeT>& expected, const Literal& actual, 342 const ErrorSpec& error) { 343 ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error); 344 } 345 346 template <PrimitiveType type, typename T> 347 /* static */ StatusOr<std::unique_ptr<Literal>> 348 LiteralTestUtil::CreateRandomLiteral( 349 const Shape& shape, 350 const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) { 351 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type; 352 TF_RET_CHECK(shape.element_type() == type); 353 std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape); 354 TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>( 355 [&](tensorflow::gtl::ArraySlice<int64> indexes) { 356 return generator(indexes); 357 })); 358 return std::move(literal); 359 } 360 361 template <PrimitiveType type, typename E, typename T> 362 /* static */ StatusOr<std::unique_ptr<Literal>> 363 LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, 364 T stddev) { 365 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type; 366 std::normal_distribution<NativeT> generator(mean, stddev); 367 return CreateRandomLiteral<type, NativeT>( 368 shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) { 369 return generator(*engine); 370 }); 371 } 372 373 template <PrimitiveType type, typename T> 374 /* static */ StatusOr<std::unique_ptr<Literal>> 375 LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { 376 std::minstd_rand0 engine; 377 return CreateRandomLiteral<type>(shape, &engine, mean, stddev); 378 } 379 380 } // namespace xla 381 382 #endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ 383