1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Utilities for dealing with Literal protobufs. 17 18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ 19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ 20 21 #include <functional> 22 #include <initializer_list> 23 #include <iterator> 24 #include <memory> 25 #include <ostream> 26 #include <string> 27 #include <type_traits> 28 #include <vector> 29 30 #include "absl/memory/memory.h" 31 #include "absl/strings/string_view.h" 32 #include "absl/types/span.h" 33 #include "tensorflow/compiler/xla/array2d.h" 34 #include "tensorflow/compiler/xla/array3d.h" 35 #include "tensorflow/compiler/xla/array4d.h" 36 #include "tensorflow/compiler/xla/index_util.h" 37 #include "tensorflow/compiler/xla/layout_util.h" 38 #include "tensorflow/compiler/xla/literal.h" 39 #include "tensorflow/compiler/xla/primitive_util.h" 40 #include "tensorflow/compiler/xla/shape_util.h" 41 #include "tensorflow/compiler/xla/sparse_index_array.h" 42 #include "tensorflow/compiler/xla/status_macros.h" 43 #include "tensorflow/compiler/xla/types.h" 44 #include "tensorflow/compiler/xla/util.h" 45 #include "tensorflow/compiler/xla/xla_data.pb.h" 46 #include "tensorflow/core/lib/core/bitmap.h" 47 #include "tensorflow/core/lib/core/status.h" 48 #include "tensorflow/core/platform/logging.h" 49 #include "tensorflow/core/platform/macros.h" 50 #include "tensorflow/core/platform/protobuf.h" 51 #include "tensorflow/core/platform/types.h" 52 53 namespace xla { 54 55 class LiteralUtil { 56 public: 57 LiteralUtil() = delete; 58 59 // Returns a literal scalar representing the first element. 60 static Literal GetFirstScalarLiteral(const LiteralSlice& literal); 61 62 // Creates a new literal of a given rank. To minimize ambiguity (for users 63 // and the compiler) these CreateR[0-2] methods should explicitly specify the 64 // native type. For example: 65 // 66 // CreateR1<float>({1.0, 42.0}); 67 // CreateR2<uint32>({{1, 2}, {3, 4}}); 68 // 69 // The variants not ending with WithLayout use the default XLA layout for the 70 // literal's linear representation in memory. 71 template <typename NativeT> 72 static Literal CreateR0(NativeT value); 73 template <typename NativeT> 74 static Literal CreateR1(absl::Span<const NativeT> values); 75 static Literal CreateR1(const tensorflow::core::Bitmap& values); 76 template <typename NativeT> 77 static Literal CreateR2( 78 std::initializer_list<std::initializer_list<NativeT>> values); 79 template <typename NativeT> 80 static Literal CreateR2WithLayout( 81 std::initializer_list<std::initializer_list<NativeT>> values, 82 const Layout& layout); 83 template <typename NativeT> 84 static Literal CreateR3(std::initializer_list< 85 std::initializer_list<std::initializer_list<NativeT>>> 86 values); 87 template <typename NativeT> 88 static Literal CreateR3WithLayout( 89 std::initializer_list< 90 std::initializer_list<std::initializer_list<NativeT>>> 91 values, 92 const Layout& layout); 93 template <typename NativeT> 94 static Literal CreateR4( 95 std::initializer_list<std::initializer_list< 96 std::initializer_list<std::initializer_list<NativeT>>>> 97 values); 98 template <typename NativeT> 99 static Literal CreateR4WithLayout( 100 std::initializer_list<std::initializer_list< 101 std::initializer_list<std::initializer_list<NativeT>>>> 102 values, 103 const Layout& layout); 104 105 // Creates a literal with a sparse layout and the given indices and values. 106 // The shape is initialized from the given dimensions. The minor dimension of 107 // the indices array must equal the rank of the shape (i.e. size of the 108 // dimensions array). The major dimension of the indices array must equal the 109 // number of elements in the values array. The maximum number of elements in 110 // the array is taken from the max_indices() value of the index array. 111 // 112 // XLA assumes that sparse literals are in sorted order for all operations. If 113 // the `sort` argument is true, then the indices and values will be sorted 114 // while copying them into the literal. If you have ensured that the indices 115 // and values are already sorted, then you may set the `sort` argument to 116 // false to skip the sorting step. 117 // 118 // For example: 119 // 120 // CreateSparse( 121 // {12, 12, 12}, 122 // SparseIndexArray(10, 3, 123 // Array2D{ 124 // {0, 1, 2}, 125 // {3, 4, 5}, 126 // {6, 7, 8}, 127 // {9, 10, 11}, 128 // }), 129 // {1.0, 2.0 3.0, 4.0}) 130 // 131 // This creates an array with shape F64[12,12,12]sparse{10}, that has the 132 // following non-zero values: 133 // 134 // [0, 1, 2]: 1.0 135 // [3, 4, 5]: 2.0 136 // [6, 7, 8]: 3.0 137 // [9, 10, 11]: 4.0 138 // 139 template <typename NativeT> 140 static Literal CreateSparse(absl::Span<const int64> dimensions, 141 SparseIndexArray indices, 142 absl::Span<const NativeT> values, 143 bool sort = true); 144 145 // Creates a scalar literal value zero of the given primitive type. 146 static Literal Zero(PrimitiveType primitive_type); 147 // Creates a scalar literal value one of the given primitive type. 148 static Literal One(PrimitiveType primitive_type); 149 // Creates a scalar literal value containing the minimum value of the given 150 // primitive type. For floating-point types, returns -inf. 151 static Literal MinValue(PrimitiveType primitive_type); 152 // Creates a scalar literal value containing the maximum value of the given 153 // primitive type. For floating-point types, returns inf. 154 static Literal MaxValue(PrimitiveType primitive_type); 155 // Creates a literal of the given shape where each element is `value`. 156 template <typename NativeT> 157 static Literal CreateFullWithDescendingLayout( 158 absl::Span<const int64> dimensions, NativeT value); 159 160 // Creates a new literal from an Array type. The variants not ending with 161 // WithLayout use the default XLA layout for the literal's linear 162 // representation in memory. 163 template <typename NativeT> 164 static Literal CreateFromArray(const Array<NativeT>& values); 165 template <typename NativeT> 166 static Literal CreateFromArrayWithLayout(const Array<NativeT>& values, 167 const Layout& layout); 168 template <typename NativeT> 169 static Literal CreateR2FromArray2D(const Array2D<NativeT>& values); 170 template <typename NativeT> 171 static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values, 172 const Layout& layout); 173 template <typename NativeT> 174 static Literal CreateR3FromArray3D(const Array3D<NativeT>& values); 175 template <typename NativeT> 176 static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values, 177 const Layout& layout); 178 template <typename NativeT> 179 static Literal CreateR4FromArray4D(const Array4D<NativeT>& values); 180 template <typename NativeT> 181 static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values, 182 const Layout& layout); 183 184 // Creates a new vector of U8s literal value from a string. 185 static Literal CreateR1U8(absl::string_view value); 186 187 // Creates a linspace-populated literal with the given number of rows and 188 // columns. 189 static Literal CreateR2F32Linspace(float from, float to, int64 rows, 190 int64 cols); 191 192 // Creates a literal that projects the (x, y) dimensions given in values into 193 // the z dimension given by "projection". 194 template <typename NativeT> 195 static Literal CreateR3Projected( 196 std::initializer_list<std::initializer_list<NativeT>> values, 197 int64 projection); 198 199 // Creates a literal that projects the (x, y) dimensions given in values into 200 // the z and p dimensions given. 201 template <typename NativeT> 202 static Literal CreateR4Projected( 203 std::initializer_list<std::initializer_list<NativeT>> values, 204 int64 projection_p, int64 projection_z); 205 206 // Returns an identity matrix (rank 2) with the given row and column count. 207 template <typename NativeT> 208 static Literal MakeIdentityR2(int64 size); 209 210 // Returns a tuple literal composed of given literals. Data is copied from the 211 // given elements into the returned literal. 212 static Literal MakeTuple(absl::Span<const Literal* const> elements); 213 214 static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements); 215 216 // As above, but intended to be invoked with move semantics; i.e. 217 // 218 // std::vector<Literal> elements = ...; 219 // auto result = LiteralUtil::MakeTupleOwned(std::move(elements)); 220 // 221 // This would have been declared as an overload, but there is ambiguity 222 // in invocation between the above signature and this one. 223 static Literal MakeTupleOwned(std::vector<Literal> elements); 224 225 // This overload lets you pass a braced list of Literals to 226 // MakeTupleOwned: 227 // 228 // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). 229 // 230 // Simply relying on the MakeTupleOwned(std::vector<Literal>) 231 // overload doesn't work because std::initializer_list's elements are always 232 // const. 233 // 234 // The arguments to this function must all be Literal. 235 template <typename... Ts> 236 static Literal MakeTupleOwned(Ts... elements) { 237 std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...}; 238 std::vector<Literal> v; 239 v.insert(v.begin(), std::make_move_iterator(arr.begin()), 240 std::make_move_iterator(arr.end())); 241 return MakeTupleOwned(std::move(v)); 242 } 243 244 // Create a constant token literal. Token types have no value. 245 static Literal CreateToken(); 246 247 // Creates a new Literal object with its values havings the primitive_type 248 // type, and with dimensions defined by the dimensions parameter. 249 // The content of the literal values is the default value of the primitive 250 // type of literal itself (0 for numeric types, and false for predicates). 251 static Literal CreateFromDimensions(PrimitiveType primitive_type, 252 absl::Span<const int64> dimensions); 253 254 // If the given literal's data type is bfloat16, converts it to a float 255 // literal; otherwise, returns a copy of it. If the literal is a tuple, 256 // recursively converts its elements. 257 static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); 258 259 // If the given literal's data type is float, converts it to a bfloat16 260 // literal; otherwise, returns a copy of it. If the literal is a tuple, 261 // recursively converts its elements. 262 static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); 263 264 // Creates a literal with a new shape with the given new dimensions using the 265 // data in the given input literal. For reshaping purposes the (flat) data 266 // buffer of the input literal is assumed to have the given minor_to_major 267 // layout order. 268 static Literal ReshapeSlice(absl::Span<const int64> new_dimensions, 269 absl::Span<const int64> minor_to_major, 270 const LiteralSlice& literal); 271 272 // Creates a literal with the supplied shape, and uses the provided value 273 // generator to populate the literal's values. 274 // Returns the new literal object, or an error Status if failed. 275 template < 276 PrimitiveType type, 277 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> 278 static StatusOr<Literal> CreateRandomLiteral( 279 const Shape& shape, 280 const std::function<T(absl::Span<const int64>)>& generator); 281 282 // Creates a literal with the supplied shape, and initializes the literal 283 // values using a normal distribution with given mean and stddev standard 284 // deviation, and using the engine as entropy generator. 285 // Returns the new literal object, or an error Status if failed. 286 template < 287 PrimitiveType type, typename E, 288 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> 289 static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine, 290 T mean, T stddev); 291 292 // Creates a literal with the supplied shape, and initializes the literal 293 // values using a normal distribution with given mean and stddev standard 294 // deviation. 295 // Returns the new literal object, or an error Status if failed. 296 template < 297 PrimitiveType type, 298 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> 299 static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean, 300 T stddev); 301 302 // 303 // End of factory methods. 304 305 // Returns a multi-dimensional index as a string. For example: '{7, 8}' will 306 // be returned for a 2-dimensional index with dimension 0 index equal to 7, 307 // dimension 1 equal to 8. 308 static string MultiIndexAsString(absl::Span<const int64> multi_index); 309 }; 310 311 std::ostream& operator<<(std::ostream& out, const Literal& literal); 312 313 template <typename NativeT> 314 /* static */ Literal LiteralUtil::CreateR0(NativeT value) { 315 Literal literal(ShapeUtil::MakeShape( 316 primitive_util::NativeToPrimitiveType<NativeT>(), {})); 317 literal.Set({}, value); 318 return literal; 319 } 320 321 template <typename NativeT> 322 /* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) { 323 Literal literal( 324 ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(), 325 {static_cast<int64>(values.size())})); 326 literal.PopulateR1(values); 327 return literal; 328 } 329 330 template <typename NativeT> 331 /* static */ Literal LiteralUtil::CreateR2WithLayout( 332 std::initializer_list<std::initializer_list<NativeT>> values, 333 const Layout& layout) { 334 Literal literal(ShapeUtil::MakeShapeWithLayout( 335 primitive_util::NativeToPrimitiveType<NativeT>(), 336 {static_cast<int64>(values.size()), 337 static_cast<int64>(values.begin()->size())}, 338 AsInt64Slice(layout.minor_to_major()))); 339 literal.PopulateR2(values); 340 return literal; 341 } 342 343 template <typename NativeT> 344 /* static */ Literal LiteralUtil::CreateR2( 345 std::initializer_list<std::initializer_list<NativeT>> values) { 346 return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); 347 } 348 349 template <typename NativeT> 350 /* static */ Literal LiteralUtil::CreateR3WithLayout( 351 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> 352 values, 353 const Layout& layout) { 354 const int64 d0 = values.size(); 355 const int64 d1 = values.begin()->size(); 356 const int64 d2 = values.begin()->begin()->size(); 357 Array3D<NativeT> tmp(d0, d1, d2); 358 int64 i0 = 0; 359 for (auto d1_values : values) { 360 int64 i1 = 0; 361 for (auto d2_values : d1_values) { 362 int64 i2 = 0; 363 for (auto value : d2_values) { 364 tmp(i0, i1, i2) = value; 365 ++i2; 366 } 367 ++i1; 368 } 369 ++i0; 370 } 371 return CreateR3FromArray3DWithLayout(tmp, layout); 372 } 373 374 template <typename NativeT> 375 /* static */ Literal LiteralUtil::CreateR3( 376 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> 377 values) { 378 return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); 379 } 380 381 template <typename NativeT> 382 /* static */ Literal LiteralUtil::CreateR4WithLayout( 383 std::initializer_list<std::initializer_list< 384 std::initializer_list<std::initializer_list<NativeT>>>> 385 values, 386 const Layout& layout) { 387 const int64 d0 = values.size(); 388 const int64 d1 = values.begin()->size(); 389 const int64 d2 = values.begin()->begin()->size(); 390 const int64 d3 = values.begin()->begin()->begin()->size(); 391 Array4D<NativeT> tmp(d0, d1, d2, d3); 392 int64 i0 = 0; 393 for (auto d1_values : values) { 394 int64 i1 = 0; 395 for (auto d2_values : d1_values) { 396 int64 i2 = 0; 397 for (auto d3_values : d2_values) { 398 int64 i3 = 0; 399 for (auto value : d3_values) { 400 tmp(i0, i1, i2, i3) = value; 401 ++i3; 402 } 403 ++i2; 404 } 405 ++i1; 406 } 407 ++i0; 408 } 409 return CreateR4FromArray4DWithLayout(tmp, layout); 410 } 411 412 template <typename NativeT> 413 /* static */ Literal LiteralUtil::CreateSparse( 414 absl::Span<const int64> dimensions, SparseIndexArray indices, 415 absl::Span<const NativeT> values, bool sort) { 416 int64 num_elements = values.size(); 417 int64 rank = dimensions.size(); 418 CHECK_EQ(num_elements, indices.index_count()); 419 CHECK_EQ(rank, indices.rank()); 420 Literal literal(ShapeUtil::MakeShapeWithSparseLayout( 421 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions, 422 indices.max_indices())); 423 literal.PopulateSparse(indices, values, sort); 424 return literal; 425 } 426 427 template <typename NativeT> 428 /* static */ Literal LiteralUtil::CreateR4( 429 std::initializer_list<std::initializer_list< 430 std::initializer_list<std::initializer_list<NativeT>>>> 431 values) { 432 return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); 433 } 434 435 template <typename NativeT> 436 /* static */ Literal LiteralUtil::CreateFromArrayWithLayout( 437 const Array<NativeT>& values, const Layout& layout) { 438 Literal literal(ShapeUtil::MakeShapeWithLayout( 439 primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(), 440 AsInt64Slice(layout.minor_to_major()))); 441 literal.PopulateFromArray(values); 442 return literal; 443 } 444 445 template <typename NativeT> 446 /* static */ Literal LiteralUtil::CreateFromArray( 447 const Array<NativeT>& values) { 448 return CreateFromArrayWithLayout( 449 values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); 450 } 451 452 template <typename NativeT> 453 /* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout( 454 const Array2D<NativeT>& values, const Layout& layout) { 455 return CreateFromArrayWithLayout(values, layout); 456 } 457 458 template <typename NativeT> 459 /* static */ Literal LiteralUtil::CreateR2FromArray2D( 460 const Array2D<NativeT>& values) { 461 return CreateFromArray(values); 462 } 463 464 template <typename NativeT> 465 /* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout( 466 const Array3D<NativeT>& values, const Layout& layout) { 467 return CreateFromArrayWithLayout(values, layout); 468 } 469 470 template <typename NativeT> 471 /* static */ Literal LiteralUtil::CreateR3FromArray3D( 472 const Array3D<NativeT>& values) { 473 return CreateFromArray(values); 474 } 475 476 template <typename NativeT> 477 /* static */ Literal LiteralUtil::CreateR3Projected( 478 std::initializer_list<std::initializer_list<NativeT>> values, 479 int64 projection) { 480 int64 dim0_size = projection; 481 int64 dim1_size = values.size(); 482 int64 dim2_size = values.begin()->size(); 483 484 Array3D<NativeT> array(dim0_size, dim1_size, dim2_size); 485 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { 486 int64 dim1 = 0; 487 for (auto inner_list : values) { 488 int64 dim2 = 0; 489 for (auto value : inner_list) { 490 array(dim0, dim1, dim2) = value; 491 ++dim2; 492 } 493 CHECK_EQ(dim2_size, dim2); 494 ++dim1; 495 } 496 CHECK_EQ(dim1_size, dim1); 497 } 498 return CreateR3FromArray3D(array); 499 } 500 501 template <typename NativeT> 502 /* static */ Literal LiteralUtil::CreateR4Projected( 503 std::initializer_list<std::initializer_list<NativeT>> values, 504 int64 projection_p, int64 projection_z) { 505 int64 dim0_size = projection_p; 506 int64 dim1_size = projection_z; 507 int64 dim2_size = values.size(); 508 int64 dim3_size = values.begin()->size(); 509 510 Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size); 511 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { 512 for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { 513 int64 dim2 = 0; 514 for (auto inner_list : values) { 515 int64 dim3 = 0; 516 for (auto value : inner_list) { 517 array(dim0, dim1, dim2, dim3) = value; 518 ++dim3; 519 } 520 CHECK_EQ(dim3_size, dim3); 521 ++dim2; 522 } 523 CHECK_EQ(dim2_size, dim2); 524 } 525 } 526 return CreateR4FromArray4D(array); 527 } 528 529 template <typename NativeT> 530 /* static */ Literal LiteralUtil::CreateR4FromArray4D( 531 const Array4D<NativeT>& values) { 532 return CreateFromArray(values); 533 } 534 535 template <typename NativeT> 536 /* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout( 537 const Array4D<NativeT>& values, const Layout& layout) { 538 return CreateFromArrayWithLayout(values, layout); 539 } 540 541 // Returns an identity matrix (rank 2) with the given row and column count. 542 template <typename NativeT> 543 /* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) { 544 Array2D<NativeT> array(size, size, 0); 545 for (int64 i = 0; i < size; ++i) { 546 array(i, i) = 1; 547 } 548 return CreateR2FromArray2D(array); 549 } 550 551 template <typename NativeT> 552 /* static */ Literal LiteralUtil::CreateFullWithDescendingLayout( 553 absl::Span<const int64> dimensions, NativeT value) { 554 Literal literal(ShapeUtil::MakeShapeWithDescendingLayout( 555 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions)); 556 literal.PopulateWithValue(value); 557 return literal; 558 } 559 560 template <PrimitiveType type, typename T> 561 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral( 562 const Shape& shape, 563 const std::function<T(absl::Span<const int64>)>& generator) { 564 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type; 565 TF_RET_CHECK(shape.element_type() == type); 566 Literal literal(shape); 567 TF_RETURN_IF_ERROR(literal.Populate<NativeT>( 568 [&](absl::Span<const int64> indexes) { return generator(indexes); })); 569 return std::move(literal); 570 } 571 572 template <PrimitiveType type, typename E, typename T> 573 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral( 574 const Shape& shape, E* engine, T mean, T stddev) { 575 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type; 576 std::normal_distribution<NativeT> generator(mean, stddev); 577 return CreateRandomLiteral<type, NativeT>( 578 shape, 579 [&](absl::Span<const int64> /*indexes*/) { return generator(*engine); }); 580 } 581 582 template <PrimitiveType type, typename T> 583 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral( 584 const Shape& shape, T mean, T stddev) { 585 std::minstd_rand0 engine; 586 return CreateRandomLiteral<type>(shape, &engine, mean, stddev); 587 } 588 589 } // namespace xla 590 591 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ 592