1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Utilities for dealing with Literal protobufs. 17 18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ 19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ 20 21 #include <functional> 22 #include <initializer_list> 23 #include <iterator> 24 #include <memory> 25 #include <ostream> 26 #include <string> 27 #include <type_traits> 28 #include <vector> 29 30 #include "tensorflow/compiler/xla/array2d.h" 31 #include "tensorflow/compiler/xla/array3d.h" 32 #include "tensorflow/compiler/xla/array4d.h" 33 #include "tensorflow/compiler/xla/index_util.h" 34 #include "tensorflow/compiler/xla/layout_util.h" 35 #include "tensorflow/compiler/xla/primitive_util.h" 36 #include "tensorflow/compiler/xla/ptr_util.h" 37 #include "tensorflow/compiler/xla/shape_tree.h" 38 #include "tensorflow/compiler/xla/shape_util.h" 39 #include "tensorflow/compiler/xla/sparse_index_array.h" 40 #include "tensorflow/compiler/xla/status_macros.h" 41 #include "tensorflow/compiler/xla/types.h" 42 #include "tensorflow/compiler/xla/util.h" 43 #include "tensorflow/compiler/xla/xla_data.pb.h" 44 #include "tensorflow/core/lib/core/bitmap.h" 45 #include "tensorflow/core/lib/core/status.h" 46 #include "tensorflow/core/lib/core/stringpiece.h" 47 #include "tensorflow/core/lib/gtl/array_slice.h" 48 #include 
"tensorflow/core/platform/logging.h" 49 #include "tensorflow/core/platform/macros.h" 50 #include "tensorflow/core/platform/protobuf.h" 51 #include "tensorflow/core/platform/types.h" 52 53 namespace xla { 54 55 // Class representing literal values in XLA. 56 // 57 // TODO(b/67651157): The methods in this class should be reduced to a minimal 58 // set of methods which construct Literals and accessors methods. Other methods 59 // which perform computation on Literals (Reshape, Slice, etc) should be moved 60 // elsewhere, and perhaps combined with evaluator code which operates on 61 // Literals. 62 class Literal { 63 public: 64 Literal() : Literal(ShapeUtil::MakeNil()) {} 65 66 // Create a literal of the given shape. The literal is allocated sufficient 67 // memory to hold the shape. Memory is uninitialized. 68 explicit Literal(const Shape& shape); 69 virtual ~Literal(); 70 71 // Literals are moveable, but not copyable. To copy a literal use 72 // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies 73 // of literals which can be expensive. 74 Literal(const Literal& other) = delete; 75 Literal& operator=(const Literal& other) = delete; 76 Literal(Literal&& other); 77 Literal& operator=(Literal&& other); 78 79 // Literals are equal if they have compatible shapes and the same data 80 // values. Layout is not compared. 81 bool operator==(const Literal& other) const; 82 bool operator!=(const Literal& other) const { return !(*this == other); } 83 84 // Serialize to and from a proto. 85 static StatusOr<std::unique_ptr<Literal>> CreateFromProto( 86 const LiteralProto& proto); 87 LiteralProto ToProto() const; 88 89 // Return the shape of the literal. 90 const Shape& shape() const { return shape_; } 91 92 // TODO(b/67651157): Remove this accessor. Literal users should not be able to 93 // mutate the shape as this can produce malformed Literals. 
94 Shape* mutable_shape_do_not_use() { return &shape_; } 95 96 // Returns a (Mutable)ArraySlice view of the array for this literal for the 97 // given NativeT (e.g., float). CHECKs if the subshape of the literal at the 98 // given ShapeIndex is not array. See primitive_util.h for the mapping from 99 // XLA type to native type. 100 template <typename NativeT> 101 tensorflow::gtl::ArraySlice<NativeT> data( 102 const ShapeIndex& shape_index = {}) const; 103 template <typename NativeT> 104 tensorflow::gtl::MutableArraySlice<NativeT> data( 105 const ShapeIndex& shape_index = {}); 106 107 // Returns a pointer to the sparse index array. Returns nullptr if the literal 108 // is not a sparse array. 109 const SparseIndexArray* sparse_indices( 110 const ShapeIndex& shape_index = {}) const; 111 SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); 112 113 // Returns a pointer to (or size of) the underlying buffer holding the array 114 // at the given shape index. CHECKs if the subshape of the literal at the 115 // given ShapeIndex is not array. 116 const void* untyped_data(const ShapeIndex& shape_index = {}) const; 117 void* untyped_data(const ShapeIndex& shape_index = {}); 118 int64 size_bytes(const ShapeIndex& shape_index = {}) const; 119 120 // Creates a new literal of a given rank. To minimize ambiguity (for users 121 // and the compiler) these CreateR[0-2] methods should explicitly specify the 122 // native type. For example: 123 // 124 // CreateR1<float>({1.0, 42.0}); 125 // CreateR2<uint32>({{1, 2}, {3, 4}}); 126 // 127 // The variants not ending with WithLayout use the default XLA layout for the 128 // literal's linear representation in memory. 
129 template <typename NativeT> 130 static std::unique_ptr<Literal> CreateR0(NativeT value); 131 template <typename NativeT> 132 static std::unique_ptr<Literal> CreateR1( 133 tensorflow::gtl::ArraySlice<NativeT> values); 134 static std::unique_ptr<Literal> CreateR1( 135 const tensorflow::core::Bitmap& values); 136 template <typename NativeT> 137 static std::unique_ptr<Literal> CreateR2( 138 std::initializer_list<std::initializer_list<NativeT>> values); 139 template <typename NativeT> 140 static std::unique_ptr<Literal> CreateR2WithLayout( 141 std::initializer_list<std::initializer_list<NativeT>> values, 142 const Layout& layout); 143 template <typename NativeT> 144 static std::unique_ptr<Literal> CreateR3( 145 std::initializer_list< 146 std::initializer_list<std::initializer_list<NativeT>>> 147 values); 148 template <typename NativeT> 149 static std::unique_ptr<Literal> CreateR3WithLayout( 150 std::initializer_list< 151 std::initializer_list<std::initializer_list<NativeT>>> 152 values, 153 const Layout& layout); 154 template <typename NativeT> 155 static std::unique_ptr<Literal> CreateR4( 156 std::initializer_list<std::initializer_list< 157 std::initializer_list<std::initializer_list<NativeT>>>> 158 values); 159 template <typename NativeT> 160 static std::unique_ptr<Literal> CreateR4WithLayout( 161 std::initializer_list<std::initializer_list< 162 std::initializer_list<std::initializer_list<NativeT>>>> 163 values, 164 const Layout& layout); 165 166 // Returns this literal's data as a string. This literal must be a rank-1 U8 167 // array. 168 string GetR1U8AsString() const; 169 170 // Creates a literal with a sparse layout and the given indices and values. 171 // The shape is initialized from the given dimensions. The minor dimension of 172 // the indices array must equal the rank of the shape (i.e. size of the 173 // dimensions array). The major dimension of the indices array must equal the 174 // number of elements in the values array. 
The maximum number of elements in 175 // the array is taken from the max_indices() value of the index array. 176 // 177 // XLA assumes that sparse literals are in sorted order for all operations. If 178 // the `sort` argument is true, then the indices and values will be sorted 179 // while copying them into the literal. If you have ensured that the indices 180 // and values are already sorted, then you may set the `sort` argument to 181 // false to skip the sorting step. 182 // 183 // For example: 184 // 185 // CreateSparse( 186 // {12, 12, 12}, 187 // SparseIndexArray(10, 3, 188 // Array2D{ 189 // {0, 1, 2}, 190 // {3, 4, 5}, 191 // {6, 7, 8}, 192 // {9, 10, 11}, 193 // }), 194 // {1.0, 2.0 3.0, 4.0}) 195 // 196 // This creates an array with shape F64[12,12,12]sparse{10}, that has the 197 // following non-zero values: 198 // 199 // [0, 1, 2]: 1.0 200 // [3, 4, 5]: 2.0 201 // [6, 7, 8]: 3.0 202 // [9, 10, 11]: 4.0 203 // 204 template <typename NativeT> 205 static std::unique_ptr<Literal> CreateSparse( 206 tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices, 207 tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true); 208 209 // Populates a literal with a sparse layout with the given indices and values. 210 // Each index in the indices array is CHECKed against the dimensions in the 211 // literal's shape. If sort is true, then the indices and values will be 212 // sorted. If sort is false, then the indices and values are assumed to 213 // already be in sorted order. See CreateSparse for an example of how data 214 // are populated. 215 template <typename NativeT> 216 void PopulateSparse(SparseIndexArray indices, 217 tensorflow::gtl::ArraySlice<NativeT> values, 218 bool sort = true); 219 220 // Creates a new Literal object with the shape specified as parameter. 221 // The content of the literal values is the default value of the primitive 222 // type of literal itself (0 for numeric types, and false for predicates). 
223 static std::unique_ptr<Literal> CreateFromShape(const Shape& shape); 224 225 // Creates a new Literal object with its values havings the primitive_type 226 // type, and with dimensions defined by the dimensions parameter. 227 // The content of the literal values is the default value of the primitive 228 // type of literal itself (0 for numeric types, and false for predicates). 229 static std::unique_ptr<Literal> CreateFromDimensions( 230 PrimitiveType primitive_type, 231 tensorflow::gtl::ArraySlice<int64> dimensions); 232 233 // Copy values from 'src_literal' rooted at 'src_shape_index' into this 234 // literal rooted at 'dest_shape_index'. The subshape of this literal rooted 235 // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' 236 // rooted at 'src_shape_index', but need not be arrays. 237 Status CopyFrom(const Literal& src_literal, 238 const ShapeIndex& dest_shape_index = {}, 239 const ShapeIndex& src_shape_index = {}); 240 241 // Similar to CopyFrom, but with move semantincs. The subshape of this literal 242 // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' 243 // (layouts and shapes must match), but need not be arrays. The memory 244 // allocated in this literal for the subshape at dest_shape_index is 245 // deallocated, and the respective buffers are replaced with those in 246 // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). 247 Status MoveFrom(Literal&& src_literal, 248 const ShapeIndex& dest_shape_index = {}); 249 250 // Copies the values from src_literal, starting at src_base shape indexes, 251 // to this literal, starting at dest_base, where the copy size in each 252 // dimension is specified by copy_size. 253 // The src_literal and this literal must have the same primitive type, 254 // src_base+copy_size must fit the source literal dimensions, as well as 255 // dest_base+copy_size must fit the destination literal dimensions. 
256 // Note: if either src_literal or this literal contains dimensions with zero 257 // element, then copy_size must be 0 in these dimensions while the 258 // corresponding base indices being 0. 259 // This literal and 'src_literal' must be arrays. 260 Status CopySliceFrom(const Literal& src_literal, 261 tensorflow::gtl::ArraySlice<int64> src_base, 262 tensorflow::gtl::ArraySlice<int64> dest_base, 263 tensorflow::gtl::ArraySlice<int64> copy_size); 264 265 // Returns a vector containing the tuple elements of this Literal as separate 266 // Literals. This Literal must be tuple-shaped and can be a nested tuple. The 267 // elements are moved into the new Literals; no data is copied. Upon return 268 // this Literal is set to a nil shape (empty tuple) 269 std::vector<Literal> DecomposeTuple(); 270 271 // This operation is the inverse of DecomposeTuple. The given elements are 272 // moved into the tuple elements of a new tuple-shaped Literal which is 273 // returned. Upon return, each of the Literals in 'elements' is set to a nil 274 // shape (empty tuple). 275 static Literal MoveIntoTuple( 276 tensorflow::gtl::MutableArraySlice<Literal> elements); 277 278 // Creates a new value that has the equivalent value as this literal, but 279 // conforms to new_layout; e.g. a literal matrix that was in {0, 1} 280 // minor-to-major dimension layout can be re-layed-out as {1, 0} 281 // minor-to-major dimension layout and the value in the cell at any given 282 // logical index (i0, i1) will be the same. 283 // 284 // For tuple shaped literals, shape_index should be used to select the inner 285 // array that the new layout applies to. 286 // 287 // Note: this is useful when the client wants to ensure that a value placed in 288 // the XLA allocation tracker has a particular layout; for efficiency 289 // purposes or avoiding unimplemented operation/layout combinations. 
290 std::unique_ptr<Literal> Relayout(const Layout& new_layout, 291 const ShapeIndex& shape_index = {}) const; 292 293 // An overload of Relayout which changes the layout of the entire shape rather 294 // than being limited to a single array within the shape. 295 std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const; 296 297 // Creates a new literal by reshaping this literal to have the given 298 // dimensions. The total number of elements must not change; The 299 // implementation currently only supports monotonic dim0-major layouts. 300 // This literal must be an array. 301 StatusOr<std::unique_ptr<Literal>> Reshape( 302 tensorflow::gtl::ArraySlice<int64> dimensions) const; 303 304 // Creates a new literal by reordering the dimensions of this literal. 305 // The given `permutation` must be a permutation of the dimension numbers 306 // in the original literal, and it specifies the order of the new dimensions 307 // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). 308 // For example, a transpose call on a literal of shape [3 x 8 x 4] and 309 // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. 310 // This literal must be an array. 311 std::unique_ptr<Literal> Transpose( 312 tensorflow::gtl::ArraySlice<int64> permutation) const; 313 314 // Creates a sub-array from this literal by extracting the indices 315 // [start_index, limit_index) of each dimension. The result literal has the 316 // same rank and layout as for the given literal. The number of indices in 317 // start_indices and limit_indices must be the rank of the literal, and the 318 // indices follow the order of the dimensions. 319 // This literal must be an array. 320 std::unique_ptr<Literal> Slice( 321 tensorflow::gtl::ArraySlice<int64> start_indices, 322 tensorflow::gtl::ArraySlice<int64> limit_indices) const; 323 324 // Creates a literal with a prepended dimension with bound "times"; e.g. 
a 325 // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this 326 // literal replicated four times. 327 // This literal must be an array. 328 template <typename NativeT> 329 std::unique_ptr<Literal> Replicate(int64 times) const; 330 331 // Converts this literal to another primitive type. Returns an error if the 332 // conversion is not possible. This literal must be array-shaped. 333 StatusOr<std::unique_ptr<Literal>> Convert( 334 PrimitiveType primitive_dest_type) const; 335 336 // Creates a scalar literal value zero of the given primitive type. 337 static Literal Zero(PrimitiveType primitive_type); 338 339 // Creates a scalar literal value one of the given primitive type. 340 static Literal One(PrimitiveType primitive_type); 341 342 // Creates a scalar literal value containing the minimum value of the given 343 // primitive type. For floating-point types, returns -inf. 344 static Literal MinValue(PrimitiveType primitive_type); 345 346 // Creates a scalar literal value containing the maximum value of the given 347 // primitive type. For floating-point types, returns inf. 348 static Literal MaxValue(PrimitiveType primitive_type); 349 350 // Creates a literal of the given shape where each element is `value`. 351 template <typename NativeT> 352 static std::unique_ptr<Literal> CreateFullWithDescendingLayout( 353 tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value); 354 355 // Creates a new literal from an Array type. The variants not ending with 356 // WithLayout use the default XLA layout for the literal's linear 357 // representation in memory. 
358 template <typename NativeT> 359 static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values); 360 template <typename NativeT> 361 static std::unique_ptr<Literal> CreateFromArrayWithLayout( 362 const Array<NativeT>& values, const Layout& layout); 363 template <typename NativeT> 364 static std::unique_ptr<Literal> CreateR2FromArray2D( 365 const Array2D<NativeT>& values); 366 template <typename NativeT> 367 static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout( 368 const Array2D<NativeT>& values, const Layout& layout); 369 template <typename NativeT> 370 static std::unique_ptr<Literal> CreateR3FromArray3D( 371 const Array3D<NativeT>& values); 372 template <typename NativeT> 373 static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout( 374 const Array3D<NativeT>& values, const Layout& layout); 375 template <typename NativeT> 376 static std::unique_ptr<Literal> CreateR4FromArray4D( 377 const Array4D<NativeT>& values); 378 template <typename NativeT> 379 static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout( 380 const Array4D<NativeT>& values, const Layout& layout); 381 382 // Creates a new vector of U8s literal value from a string. 383 static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value); 384 385 // Creates a linspace-populated literal with the given number of rows and 386 // columns. 387 static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to, 388 int64 rows, int64 cols); 389 390 // Creates a literal that projects the (x, y) dimensions given in values into 391 // the z dimension given by "projection". 392 template <typename NativeT> 393 static std::unique_ptr<Literal> CreateR3Projected( 394 std::initializer_list<std::initializer_list<NativeT>> values, 395 int64 projection); 396 397 // Creates a literal that projects the (x, y) dimensions given in values into 398 // the z and p dimensions given. 
399 template <typename NativeT> 400 static std::unique_ptr<Literal> CreateR4Projected( 401 std::initializer_list<std::initializer_list<NativeT>> values, 402 int64 projection_p, int64 projection_z); 403 404 // Clones this literal into a new Literal, or new std::unique_ptr<Literal>. 405 Literal Clone() const; 406 std::unique_ptr<Literal> CloneToUnique() const; 407 408 // Gets or sets an element in the literal at the given index. The multi_index 409 // is CHECKed against the dimension sizes. 410 template <typename NativeT> 411 NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index, 412 const ShapeIndex& shape_index) const; 413 template <typename NativeT> 414 void Set(tensorflow::gtl::ArraySlice<int64> multi_index, 415 const ShapeIndex& shape_index, NativeT value); 416 417 // Overloads of Get and Set for array literals. CHECKs if the literal is not 418 // array-shaped and dense. 419 template <typename NativeT> 420 NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const; 421 template <typename NativeT> 422 void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value); 423 424 // Returns the multi-index of the element in a sparse literal at the given 425 // sparse element number. The sparse element number is the position with in 426 // the sparse array's list of (index, value) pairs, and is checked against the 427 // total number of (index, value) pairs in the sparse array. 428 tensorflow::gtl::ArraySlice<int64> GetSparseIndex( 429 int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; 430 431 // Returns the value of the element in a sparse literal at the given sparse 432 // element number. The sparse element number is the position with in the 433 // sparse array's list of (index, value) pairs, and is checked against the 434 // total number of (index, value) pairs in the sparse array. 
435 template <typename NativeT> 436 NativeT GetSparseElement(int64 sparse_element_number, 437 const ShapeIndex& shape_index = {}) const; 438 439 // Appends the given element to the literal. If the elements are not appended 440 // in sorted order, then SortSparseElements should be called before calling 441 // other methods. This literal must have a sparse layout. 442 template <typename NativeT> 443 void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index, 444 NativeT value, const ShapeIndex& shape_index = {}); 445 446 // Sorts the elements in a sparse array. 447 void SortSparseElements(const ShapeIndex& shape_index = {}); 448 449 // Returns the element value at index (0, ..., 0), however many zeroes are 450 // required for that index. 451 template <typename NativeT> 452 NativeT GetFirstElement() const; 453 454 // As Get(), but determines the correct type and converts the value 455 // into text. 456 string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index, 457 const ShapeIndex& shape_index = {}) const; 458 459 // As GetSparseElement(), but determines the correct type and converts the 460 // value into text. 461 string GetSparseElementAsString(int64 sparse_element_number, 462 const ShapeIndex& shape_index = {}) const; 463 464 // As Get(), but determines the correct type and converts the value into 465 // int64. This literal must be an array. 466 StatusOr<int64> GetIntegralAsS64( 467 tensorflow::gtl::ArraySlice<int64> multi_index) const; 468 469 // Returns an identity matrix (rank 2) with the given row and column count. 470 template <typename NativeT> 471 static std::unique_ptr<Literal> MakeIdentityR2(int64 size); 472 473 // Returns a tuple literal composed of given literals. Data is copied from the 474 // given elements into the returned literal. 475 static std::unique_ptr<Literal> MakeTuple( 476 tensorflow::gtl::ArraySlice<const Literal*> elements); 477 478 // As above, but intended to be invoked with move semantics; i.e. 
479 // 480 // std::vector<std::unique_ptr<Literal>> elements = ...; 481 // auto result = Literal::MakeTupleOwned(std::move(elements)); 482 // 483 // This would have been declared as an overload, but there is ambiguity 484 // in invocation between the above signature and this one. 485 static std::unique_ptr<Literal> MakeTupleOwned( 486 std::vector<std::unique_ptr<Literal>> elements); 487 488 // This overload lets you pass a braced list of unique_ptr<Literal>s to 489 // MakeTupleOwned: 490 // 491 // Literal::MakeTupleOwned(Literal::CreateR1(...), ...). 492 // 493 // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>) 494 // overload doesn't work because std::initializer_list's elements are always 495 // const. 496 // 497 // The arguments to this function must all be unique_ptr<Literal>. 498 template <typename... Ts> 499 static std::unique_ptr<Literal> MakeTupleOwned( 500 std::unique_ptr<Ts>... elements) { 501 std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{ 502 std::move(elements)...}; 503 std::vector<std::unique_ptr<Literal>> v; 504 v.insert(v.begin(), std::make_move_iterator(arr.begin()), 505 std::make_move_iterator(arr.end())); 506 return MakeTupleOwned(std::move(v)); 507 } 508 509 // Returns a string representation of the literal value. 510 // Warning: this function can take minutes for multi-million element Literals. 511 string ToString(bool print_layout = false) const; 512 513 // Invokes the "per cell" callback for each element in the provided 514 // literal with the element's indices and a string representation of 515 // the element's value. 516 // 517 // This function is useful if you want a polymorphic representation 518 // of the tensor's elements (turning it to a string for something 519 // like representation in a protobuf). 520 // 521 // This literal must have a dense layout. 
522 void EachCellAsString( 523 const std::function<void(tensorflow::gtl::ArraySlice<int64> indices, 524 const string& value)>& per_cell) const; 525 template <typename NativeT> 526 void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices, 527 NativeT value)> 528 per_cell) const; 529 530 // Populate this literal with the given values. Examples: 531 // 532 // // Populate with floats. 533 // Array2D<float> float_values = ... 534 // literal.PopulateR2FromArray2D(values); 535 // 536 // // Populate with int32s. 537 // literal.PopulateR2<int32>({{1, 2}, {3, 4}}); 538 // 539 // The shape and element type of this literal must match given values. For 540 // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 541 // array of S32. 542 template <typename NativeT> 543 void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values); 544 void PopulateR1(const tensorflow::core::Bitmap& values); 545 template <typename NativeT> 546 void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values); 547 template <typename NativeT> 548 void PopulateFromArray(const Array<NativeT>& values); 549 template <typename NativeT> 550 void PopulateR2FromArray2D(const Array2D<NativeT>& values); 551 template <typename NativeT> 552 void PopulateR3FromArray3D(const Array3D<NativeT>& values); 553 template <typename NativeT> 554 void PopulateR4FromArray4D(const Array4D<NativeT>& values); 555 556 // Populates literal values by calling the generator function for every cell 557 // in this literal object. 558 // 559 // generator must be a callable of the type 560 // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible. 561 // 562 // This literal must have a dense layout. 563 template <typename NativeT, typename FnType> 564 Status Populate(const FnType& generator); 565 566 // Fills this literal with the given value. 
567 template <typename NativeT> 568 void PopulateWithValue(NativeT value); 569 570 // Returns whether every element in this literal is equal to value. 571 // 572 // value is an int8 because we expect this to be called with small 573 // compile-time constants (0, -1, etc.) and so that whatever value you pass 574 // can be represented exactly by floating-point types as small as 16 bits. 575 // 576 // If value doesn't fit in this literal's type, returns false. Values of 1/0 577 // are considered equal to true/false; other values are not considered equal 578 // to true. Also if this literal is not array-shaped false is returned. 579 bool IsAll(int8 value) const; 580 581 // Like IsAll(const Literal&, int8), except we check whether the literal is 582 // equal to a particular floating-point number. 583 // 584 // If the literal is not a floating-point value, this always returns false. 585 // 586 // This casts value to the type of literal, then compares using ==. The usual 587 // admonishments about floating-point equality checks apply. We expect you to 588 // use this to check for values that can be expressed precisely as a float, 589 // e.g. -0.5. Also if this literal is not array-shaped false is returned. 590 bool IsAllFloat(float value) const; 591 592 // Like IsAll(const Literal&, int8), except we check whether the literal is 593 // equal to a particular complex number. 594 // 595 // If the literal is not a complex value, this always returns false. 596 // 597 // This casts value to the type of literal, then compares using ==. The usual 598 // admonishments about floating-point equality checks apply. We expect you to 599 // use this to check for complex values that can be expressed precisely as 600 // float pairs e.g. (-0.5, 1.0). 601 // 602 // This literal must have a dense layout. 603 bool IsAllComplex(complex64 value) const; 604 605 // Returns whether this literal is zero at the specified index. This literal 606 // must be an array with a dense layout. 
607 bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const; 608 609 // Return the count of the elements in the array at the given shape index in 610 // this literal. 611 int64 element_count(const ShapeIndex& index = {}) const { 612 return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); 613 } 614 615 // Return the count of the elements in the sparse array at the given shape 616 // index in this literal, which will be no larger than 617 // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). 618 int64 sparse_element_count() const; 619 620 protected: 621 // 'allocate_arrays' indicates whether to allocate memory for the arrays in 622 // the shape. If false, buffer pointers inside of the Literal::Pieces are set 623 // to nullptr. 624 Literal(const Shape& shape, bool allocate_arrays); 625 626 // Internal template helper for the Literal::CopySliceFrom(), matching its 627 // arguments one by one. 628 template <typename NativeT> 629 Status CopySliceFromInternal(const Literal& src_literal, 630 tensorflow::gtl::ArraySlice<int64> src_base, 631 tensorflow::gtl::ArraySlice<int64> dest_base, 632 tensorflow::gtl::ArraySlice<int64> copy_size); 633 634 // Utility structure which is used to create the optimal configuration for 635 // a ShapeUtil::ForEachIndex() scan across two literals. 636 struct StrideConfig { 637 StrideConfig(const Shape& source_shape, const Shape& dest_shape, 638 tensorflow::gtl::ArraySlice<int64> dimensions); 639 640 // The dimensions of the stride operation. Essentially every dimension 641 // will be iterated from base[i] to base[i]+dimensions[i], in step[i] 642 // steps. 643 tensorflow::gtl::ArraySlice<int64> dimensions; 644 DimensionVector base; 645 DimensionVector step; 646 int64 minor_dimension = 0; 647 // The size of the strides for source and destination. 
// One of the two
    // (the one looping through its most minor dimension) will be 1, while
    // the other will be the stride size at the dimension matching the other
    // shape most minor dimension being scanned.
    int64 dest_stride = 1;
    int64 source_stride = 1;
    // The size of the inner loop on the most minor dimension.
    int64 minor_loop_size = 1;
  };

  // A data structure representing a subshape at a particular ShapeIndex within
  // the literal. For array-shaped ShapeIndexes, this data structure holds the
  // pointer to the memory allocated for the array data.
  //
  // A Piece only stores raw pointers (buffer, sparse indices, subshape);
  // whether the buffers are owned is tracked by the containing Literal (see
  // Literal::owns_buffers_).
  class Piece {
   public:
    // Return the buffer holding the array data for this piece as an array
    // slice. This piece must be array-shaped and its element type must match
    // NativeT; both conditions are CHECKed.
    template <typename NativeT>
    tensorflow::gtl::ArraySlice<NativeT> data() const;
    template <typename NativeT>
    tensorflow::gtl::MutableArraySlice<NativeT> data();

    // Return the buffer holding the array data for this piece as a void*. This
    // piece must be array-shaped.
    void* untyped_data();
    const void* untyped_data() const;

    // Gets or sets an element in the array at the given multi-dimensional
    // index. The multi_index is CHECKed against the dimension sizes of the
    // array. This piece must be a dense array (CHECKed) whose element type
    // matches NativeT.
    template <typename NativeT>
    NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
    template <typename NativeT>
    void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);

    // Gets/sets the buffer holding the array data.
    char* buffer() const { return buffer_; }
    void set_buffer(char* buffer) { buffer_ = buffer; }

    // The array of multi-indices that provide the locations of non-zero
    // elements in a sparse array. Only used if
    // LayoutUtil::IsSparseArray(subshape()) is true.
    SparseIndexArray* sparse_indices() const { return sparse_indices_; }
    void set_sparse_indices(SparseIndexArray* sparse_indices) {
      sparse_indices_ = sparse_indices;
    }

    // Gets or sets the subshape of this piece. This reference points to a
    // subshape within the shape in the containing Literal (Literal::shape_).
    // subshape() dereferences the stored pointer, so set_subshape() must have
    // been called with a non-null pointer first.
    const Shape& subshape() const { return *subshape_; }
    void set_subshape(const Shape* subshape) { subshape_ = subshape; }

    // Returns the size in bytes of the buffer holding the array data, as
    // reported by ShapeUtil::ByteSizeOf for this piece's subshape.
    int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }

    // Returns the number of elements in this piece's array, as reported by
    // ShapeUtil::ElementsIn for this piece's subshape.
    int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }

    // Copy the data from 'src' into this piece's buffer. Shapes of this piece
    // and src must be compatible.
    Status CopyFrom(const Piece& src);

    // Returns true if this piece and 'other' contain the same data. This piece
    // and 'other' must be array-shaped and compatible.
    bool EqualElements(const Piece& other) const;

    // Writes the shape and data (if array-shaped) into the given proto.
    void WriteToProto(LiteralProto* proto) const;

    // Copies the data from the given proto into this piece. The shape of this
    // piece must be equal (not just compatible) to the shape of the proto.
    Status CopyFromProto(const LiteralProto& proto);

    // Sorts the elements in a sparse array.
    void SortSparseElements();

   private:
    // Recursive helper for EqualElements.
    template <typename NativeT>
    bool EqualElementsInternal(const Piece& other,
                               std::vector<int64>* multi_index) const;

    // Helper for SortSparseElements that has the element type as a template
    // parameter.
    template <typename NativeT>
    void SortSparseElementsInternal();

    // For array-shaped pieces, this is the buffer holding the literal data.
    char* buffer_ = nullptr;

    // For sparse arrays, this is the array of indices.
    SparseIndexArray* sparse_indices_ = nullptr;

    // The shape of piece. This points into the shape of the containing Literal
    // (Literal::shape_).
    const Shape* subshape_ = nullptr;
  };

  // Returns the piece at the given ShapeIndex.
  Piece& piece(const ShapeIndex& shape_index) {
    return *pieces_.mutable_element(shape_index);
  }
  const Piece& piece(const ShapeIndex& shape_index) const {
    return pieces_.element(shape_index);
  }

  // Returns the piece at the root of the shape (empty ShapeIndex).
  Piece& root_piece() { return piece({}); }
  const Piece& root_piece() const { return piece({}); }

  // Deallocate the buffers held by this literal (if the literal owns the
  // buffer).
  void DeallocateBuffers();

  // The shape of this literal. Piece::subshape_ pointers point into it.
  Shape shape_;
  // One Piece per ShapeIndex of shape_.
  ShapeTree<Piece> pieces_;

  // Whether the buffers held in pieces_ are owned by this Literal.
  bool owns_buffers_;

  // LiteralView must access and manipulate Pieces of other Literals.
  friend class LiteralView;
};  // class Literal

// Streams 'literal' to 'out'. Only declared here; the output format is
// determined by the out-of-line definition.
std::ostream& operator<<(std::ostream& out, const Literal& literal);

// A read-only view of a Literal. A LiteralView contains pointers to buffers
// owned by the viewed Literal.
//
// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
// and mutable) similar to (Mutable)ArraySlice.
class LiteralView : public Literal {
 public:
  // Create and return a view of the given literal rooted at the given shape
  // index within the given literal. A factory is used rather than a public
  // constructor because only const LiteralViews are supported. It's still
  // possible to create non-const LiteralViews via the copy constructors, but
  // the factory method makes it a bit less likely. Implementing literal slices
  // will fix this undesirable situation (b/71550060).
  static const LiteralView Create(const Literal& literal,
                                  const ShapeIndex& view_root = {});

  // Unlike Literal (whose copy operations are deleted), a LiteralView is
  // copyable.
  LiteralView(const LiteralView& other);
  LiteralView& operator=(const LiteralView& other);

  virtual ~LiteralView();

 private:
  // Private so that views are normally obtained through Create().
  LiteralView(const Literal& literal, const ShapeIndex& view_root);

  // Helper for the copy constructor and copy assignment operator.
  void CopyFrom(const LiteralView& other);
};

// Read-only typed view of this piece's buffer. CHECK-fails if the piece is not
// array-shaped or if NativeT does not correspond to the piece's element type.
// The slice length is ShapeUtil::ElementsIn(subshape()).
template <typename NativeT>
tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
  CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
  CHECK_EQ(subshape().element_type(),
           primitive_util::NativeToPrimitiveType<NativeT>())
      << "Attempting to access "
      << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
      << " type, but literal element type is "
      << PrimitiveType_Name(subshape().element_type());
  return tensorflow::gtl::ArraySlice<NativeT>(
      reinterpret_cast<const NativeT*>(buffer()),
      ShapeUtil::ElementsIn(subshape()));
}

// Mutable counterpart of the accessor above; performs the same CHECKs and
// covers the same number of elements.
template <typename NativeT>
tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
  CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
  CHECK_EQ(subshape().element_type(),
           primitive_util::NativeToPrimitiveType<NativeT>())
      << "Attempting to access "
      << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
      << " type, but literal element type is "
      << PrimitiveType_Name(subshape().element_type());
  return tensorflow::gtl::MutableArraySlice<NativeT>(
      reinterpret_cast<NativeT*>(buffer()), ShapeUtil::ElementsIn(subshape()));
}

// Returns the element at 'multi_index'. CHECK-fails unless this piece is a
// dense array; the multi-index is translated to a position in the buffer with
// IndexUtil::MultidimensionalIndexToLinearIndex.
template <typename NativeT>
NativeT Literal::Piece::Get(
    tensorflow::gtl::ArraySlice<int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
      subshape(), multi_index)];
}

836 template <typename NativeT> 837 void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index, 838 NativeT value) { 839 CHECK(LayoutUtil::IsDenseArray(subshape())); 840 data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex( 841 subshape(), multi_index)] = value; 842 } 843 844 template <typename NativeT> 845 tensorflow::gtl::ArraySlice<NativeT> Literal::data( 846 const ShapeIndex& shape_index) const { 847 return piece(shape_index).data<NativeT>(); 848 } 849 850 template <typename NativeT> 851 tensorflow::gtl::MutableArraySlice<NativeT> Literal::data( 852 const ShapeIndex& shape_index) { 853 return piece(shape_index).data<NativeT>(); 854 } 855 856 template <typename NativeT> 857 inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index, 858 const ShapeIndex& shape_index) const { 859 return piece(shape_index).Get<NativeT>(multi_index); 860 } 861 862 template <typename NativeT> 863 inline NativeT Literal::Get( 864 tensorflow::gtl::ArraySlice<int64> multi_index) const { 865 return root_piece().Get<NativeT>(multi_index); 866 } 867 868 template <typename NativeT> 869 inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index, 870 const ShapeIndex& shape_index, NativeT value) { 871 return piece(shape_index).Set<NativeT>(multi_index, value); 872 } 873 874 template <typename NativeT> 875 inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index, 876 NativeT value) { 877 return root_piece().Set<NativeT>(multi_index, value); 878 } 879 880 template <typename NativeT> 881 /* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) { 882 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape( 883 primitive_util::NativeToPrimitiveType<NativeT>(), {})); 884 literal->Set({}, value); 885 return literal; 886 } 887 888 template <typename NativeT> 889 /* static */ std::unique_ptr<Literal> Literal::CreateR1( 890 tensorflow::gtl::ArraySlice<NativeT> values) { 891 auto literal = MakeUnique<Literal>( 892 
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(), 893 {static_cast<int64>(values.size())})); 894 literal->PopulateR1(values); 895 return literal; 896 } 897 898 template <typename NativeT> 899 /* static */ std::unique_ptr<Literal> Literal::CreateR2WithLayout( 900 std::initializer_list<std::initializer_list<NativeT>> values, 901 const Layout& layout) { 902 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout( 903 primitive_util::NativeToPrimitiveType<NativeT>(), 904 {static_cast<int64>(values.size()), 905 static_cast<int64>(values.begin()->size())}, 906 AsInt64Slice(layout.minor_to_major()))); 907 literal->PopulateR2(values); 908 return literal; 909 } 910 911 template <typename NativeT> 912 /* static */ std::unique_ptr<Literal> Literal::CreateR2( 913 std::initializer_list<std::initializer_list<NativeT>> values) { 914 return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); 915 } 916 917 template <typename NativeT> 918 /* static */ std::unique_ptr<Literal> Literal::CreateR3WithLayout( 919 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> 920 values, 921 const Layout& layout) { 922 const int64 d0 = values.size(); 923 const int64 d1 = values.begin()->size(); 924 const int64 d2 = values.begin()->begin()->size(); 925 Array3D<NativeT> tmp(d0, d1, d2); 926 int64 i0 = 0; 927 for (auto d1_values : values) { 928 int64 i1 = 0; 929 for (auto d2_values : d1_values) { 930 int64 i2 = 0; 931 for (auto value : d2_values) { 932 tmp(i0, i1, i2) = value; 933 ++i2; 934 } 935 ++i1; 936 } 937 ++i0; 938 } 939 return CreateR3FromArray3DWithLayout(tmp, layout); 940 } 941 942 template <typename NativeT> 943 /* static */ std::unique_ptr<Literal> Literal::CreateR3( 944 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> 945 values) { 946 return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); 947 } 948 949 template <typename NativeT> 950 /* static */ std::unique_ptr<Literal> 
Literal::CreateR4WithLayout( 951 std::initializer_list<std::initializer_list< 952 std::initializer_list<std::initializer_list<NativeT>>>> 953 values, 954 const Layout& layout) { 955 const int64 d0 = values.size(); 956 const int64 d1 = values.begin()->size(); 957 const int64 d2 = values.begin()->begin()->size(); 958 const int64 d3 = values.begin()->begin()->begin()->size(); 959 Array4D<NativeT> tmp(d0, d1, d2, d3); 960 int64 i0 = 0; 961 for (auto d1_values : values) { 962 int64 i1 = 0; 963 for (auto d2_values : d1_values) { 964 int64 i2 = 0; 965 for (auto d3_values : d2_values) { 966 int64 i3 = 0; 967 for (auto value : d3_values) { 968 tmp(i0, i1, i2, i3) = value; 969 ++i3; 970 } 971 ++i2; 972 } 973 ++i1; 974 } 975 ++i0; 976 } 977 return CreateR4FromArray4DWithLayout(tmp, layout); 978 } 979 980 template <typename NativeT> 981 /* static */ std::unique_ptr<Literal> Literal::CreateSparse( 982 tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices, 983 tensorflow::gtl::ArraySlice<NativeT> values, bool sort) { 984 int64 num_elements = values.size(); 985 int64 rank = dimensions.size(); 986 CHECK_EQ(num_elements, indices.index_count()); 987 CHECK_EQ(rank, indices.rank()); 988 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithSparseLayout( 989 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions, 990 indices.max_indices())); 991 literal->PopulateSparse(indices, values, sort); 992 return literal; 993 } 994 995 template <typename NativeT> 996 /* static */ std::unique_ptr<Literal> Literal::CreateR4( 997 std::initializer_list<std::initializer_list< 998 std::initializer_list<std::initializer_list<NativeT>>>> 999 values) { 1000 return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); 1001 } 1002 1003 template <typename NativeT> 1004 /* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout( 1005 const Array<NativeT>& values, const Layout& layout) { 1006 auto literal = 
MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout( 1007 primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(), 1008 AsInt64Slice(layout.minor_to_major()))); 1009 literal->PopulateFromArray(values); 1010 return literal; 1011 } 1012 1013 template <typename NativeT> 1014 /* static */ std::unique_ptr<Literal> Literal::CreateFromArray( 1015 const Array<NativeT>& values) { 1016 return CreateFromArrayWithLayout( 1017 values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); 1018 } 1019 1020 template <typename NativeT> 1021 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout( 1022 const Array2D<NativeT>& values, const Layout& layout) { 1023 return CreateFromArrayWithLayout(values, layout); 1024 } 1025 1026 template <typename NativeT> 1027 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D( 1028 const Array2D<NativeT>& values) { 1029 return CreateFromArray(values); 1030 } 1031 1032 template <typename NativeT> 1033 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout( 1034 const Array3D<NativeT>& values, const Layout& layout) { 1035 return CreateFromArrayWithLayout(values, layout); 1036 } 1037 1038 template <typename NativeT> 1039 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D( 1040 const Array3D<NativeT>& values) { 1041 return CreateFromArray(values); 1042 } 1043 1044 template <typename NativeT> 1045 /* static */ std::unique_ptr<Literal> Literal::CreateR3Projected( 1046 std::initializer_list<std::initializer_list<NativeT>> values, 1047 int64 projection) { 1048 int64 dim0_size = projection; 1049 int64 dim1_size = values.size(); 1050 int64 dim2_size = values.begin()->size(); 1051 1052 Array3D<NativeT> array(dim0_size, dim1_size, dim2_size); 1053 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { 1054 int64 dim1 = 0; 1055 for (auto inner_list : values) { 1056 int64 dim2 = 0; 1057 for (auto value : inner_list) { 1058 array(dim0, dim1, dim2) = value; 1059 ++dim2; 1060 
} 1061 CHECK_EQ(dim2_size, dim2); 1062 ++dim1; 1063 } 1064 CHECK_EQ(dim1_size, dim1); 1065 } 1066 return CreateR3FromArray3D(array); 1067 } 1068 1069 template <typename NativeT> 1070 /* static */ std::unique_ptr<Literal> Literal::CreateR4Projected( 1071 std::initializer_list<std::initializer_list<NativeT>> values, 1072 int64 projection_p, int64 projection_z) { 1073 int64 dim0_size = projection_p; 1074 int64 dim1_size = projection_z; 1075 int64 dim2_size = values.size(); 1076 int64 dim3_size = values.begin()->size(); 1077 1078 Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size); 1079 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { 1080 for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { 1081 int64 dim2 = 0; 1082 for (auto inner_list : values) { 1083 int64 dim3 = 0; 1084 for (auto value : inner_list) { 1085 array(dim0, dim1, dim2, dim3) = value; 1086 ++dim3; 1087 } 1088 CHECK_EQ(dim3_size, dim3); 1089 ++dim2; 1090 } 1091 CHECK_EQ(dim2_size, dim2); 1092 } 1093 } 1094 return CreateR4FromArray4D(array); 1095 } 1096 1097 template <typename NativeT> 1098 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D( 1099 const Array4D<NativeT>& values) { 1100 return CreateFromArray(values); 1101 } 1102 1103 template <typename NativeT> 1104 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout( 1105 const Array4D<NativeT>& values, const Layout& layout) { 1106 return CreateFromArrayWithLayout(values, layout); 1107 } 1108 1109 template <typename NativeT> 1110 NativeT Literal::GetFirstElement() const { 1111 return data<NativeT>().at(0); 1112 } 1113 1114 template <typename NativeT> 1115 NativeT Literal::GetSparseElement(int64 sparse_element_number, 1116 const ShapeIndex& shape_index) const { 1117 CHECK( 1118 LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); 1119 return data<NativeT>(shape_index)[sparse_element_number]; 1120 } 1121 1122 template <typename NativeT> 1123 void Literal::AppendSparseElement( 1124 
tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value, 1125 const ShapeIndex& shape_index) { 1126 Piece& p = piece(shape_index); 1127 const Shape& subshape = p.subshape(); 1128 CHECK(LayoutUtil::IsSparseArray(subshape)); 1129 int64 rank = ShapeUtil::Rank(subshape); 1130 CHECK_EQ(multi_index.size(), rank); 1131 int64 last_element = p.sparse_indices()->index_count(); 1132 CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); 1133 p.sparse_indices()->Append(multi_index); 1134 CHECK_LT(last_element, p.data<NativeT>().size()); 1135 p.data<NativeT>()[last_element] = value; 1136 } 1137 1138 // Returns an identity matrix (rank 2) with the given row and column count. 1139 template <typename NativeT> 1140 /* static */ std::unique_ptr<Literal> Literal::MakeIdentityR2(int64 size) { 1141 Array2D<NativeT> array(size, size, 0); 1142 for (int64 i = 0; i < size; ++i) { 1143 array(i, i) = 1; 1144 } 1145 return CreateR2FromArray2D(array); 1146 } 1147 1148 template <typename NativeT> 1149 void Literal::EachCell( 1150 std::function<void(tensorflow::gtl::ArraySlice<int64> indices, 1151 NativeT value)> 1152 per_cell) const { 1153 if (ShapeUtil::HasZeroElements(shape())) { 1154 return; 1155 } 1156 std::vector<int64> indices(ShapeUtil::Rank(shape()), 0); 1157 do { 1158 per_cell(indices, Get<NativeT>(indices)); 1159 } while (IndexUtil::BumpIndices(shape(), &indices)); 1160 } 1161 1162 template <typename NativeT> 1163 inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) { 1164 CHECK(ShapeUtil::IsArray(shape())); 1165 CHECK_EQ(ShapeUtil::Rank(shape()), 1); 1166 CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); 1167 CHECK_EQ(shape().element_type(), 1168 primitive_util::NativeToPrimitiveType<NativeT>()); 1169 for (int64 i = 0; i < values.size(); ++i) { 1170 Set({i}, values[i]); 1171 } 1172 } 1173 1174 template <typename NativeT> 1175 void Literal::PopulateR2( 1176 std::initializer_list<std::initializer_list<NativeT>> values) { 1177 
CHECK(ShapeUtil::IsArray(shape())); 1178 CHECK_EQ(ShapeUtil::Rank(shape()), 2); 1179 CHECK_EQ(shape().element_type(), 1180 primitive_util::NativeToPrimitiveType<NativeT>()); 1181 1182 const int64 dim0_size = values.size(); 1183 const int64 dim1_size = values.begin()->size(); 1184 CHECK_EQ(dim0_size, shape().dimensions(0)); 1185 CHECK_EQ(dim1_size, shape().dimensions(1)); 1186 1187 int64 dim0 = 0; 1188 for (auto inner_list : values) { 1189 int64 dim1 = 0; 1190 for (auto value : inner_list) { 1191 Set({dim0, dim1}, value); 1192 ++dim1; 1193 } 1194 CHECK_EQ(dim1_size, dim1); 1195 ++dim0; 1196 } 1197 } 1198 1199 template <typename NativeT> 1200 void Literal::PopulateFromArray(const Array<NativeT>& values) { 1201 CHECK(ShapeUtil::IsArray(shape())); 1202 CHECK_EQ(shape().element_type(), 1203 primitive_util::NativeToPrimitiveType<NativeT>()); 1204 CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); 1205 for (int dim = 0; dim < values.num_dimensions(); ++dim) { 1206 CHECK_EQ(values.dim(dim), shape().dimensions(dim)); 1207 } 1208 values.Each([this](tensorflow::gtl::ArraySlice<int64> indices, 1209 NativeT value) { this->Set(indices, value); }); 1210 } 1211 1212 template <typename NativeT> 1213 void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) { 1214 PopulateFromArray(values); 1215 } 1216 1217 template <typename NativeT> 1218 void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) { 1219 PopulateFromArray(values); 1220 } 1221 1222 template <typename NativeT> 1223 void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) { 1224 PopulateFromArray(values); 1225 } 1226 1227 template <typename NativeT> 1228 void Literal::PopulateSparse(SparseIndexArray indices, 1229 tensorflow::gtl::ArraySlice<NativeT> values, 1230 bool sort) { 1231 CHECK(LayoutUtil::IsSparseArray(shape())); 1232 int rank = ShapeUtil::Rank(shape()); 1233 CHECK_EQ(indices.rank(), rank); 1234 int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); 1235 
CHECK_LE(indices.max_indices(), max_elements); 1236 int64 num_elements = values.size(); 1237 CHECK_LE(num_elements, max_elements); 1238 CHECK_EQ(num_elements, indices.index_count()); 1239 auto root_data = root_piece().data<NativeT>(); 1240 root_data.remove_suffix(max_elements - values.size()); 1241 std::copy(values.begin(), values.end(), root_data.begin()); 1242 *this->root_piece().sparse_indices() = std::move(indices); 1243 if (sort) { 1244 auto root_data = this->root_piece().data<NativeT>(); 1245 root_data.remove_suffix(root_data.size() - num_elements); 1246 this->root_piece().sparse_indices()->SortWithValues(root_data); 1247 } 1248 DCHECK(this->root_piece().sparse_indices()->Validate(shape())); 1249 } 1250 1251 template <typename NativeT, typename FnType> 1252 Status Literal::Populate(const FnType& generator) { 1253 const Shape& this_shape = shape(); 1254 const int64 rank = ShapeUtil::Rank(this_shape); 1255 TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); 1256 TF_RET_CHECK(this_shape.element_type() == 1257 primitive_util::NativeToPrimitiveType<NativeT>()); 1258 tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>(); 1259 if (rank > 0) { 1260 StrideConfig stride_config(this_shape, this_shape, 1261 AsInt64Slice(this_shape.dimensions())); 1262 DimensionVector minor_scan_indexes(rank, 0); 1263 int64 minor_dimension_size = 1264 ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); 1265 1266 auto init_function = [&](const std::vector<int64>& indexes) { 1267 const int64 index = 1268 IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); 1269 std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); 1270 for (int64 i = 0; i < minor_dimension_size; ++i) { 1271 minor_scan_indexes[stride_config.minor_dimension] = i; 1272 literal_data.at(index + i) = generator(minor_scan_indexes); 1273 } 1274 return true; 1275 }; 1276 ShapeUtil::ForEachIndex(this_shape, stride_config.base, 1277 stride_config.dimensions, 
stride_config.step, 1278 init_function); 1279 } else { 1280 // For scalars. 1281 literal_data.at(0) = generator({}); 1282 } 1283 return Status::OK(); 1284 } 1285 1286 template <typename NativeT> 1287 void Literal::PopulateWithValue(NativeT value) { 1288 CHECK(ShapeUtil::IsArray(shape())); 1289 CHECK_EQ(shape().element_type(), 1290 primitive_util::NativeToPrimitiveType<NativeT>()); 1291 for (NativeT& element : data<NativeT>()) { 1292 element = value; 1293 } 1294 } 1295 1296 template <typename NativeT> 1297 /* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout( 1298 tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) { 1299 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout( 1300 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions)); 1301 literal->PopulateWithValue(value); 1302 return literal; 1303 } 1304 1305 template <typename NativeT> 1306 std::unique_ptr<Literal> Literal::Replicate(int64 times) const { 1307 DimensionVector bounds = {times}; 1308 bounds.reserve(shape().dimensions_size() + 1); 1309 for (int64 bound : shape().dimensions()) { 1310 bounds.push_back(bound); 1311 } 1312 auto literal = 1313 MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds)); 1314 int64 elements = ShapeUtil::ElementsIn(literal->shape()); 1315 if (elements == 0) { 1316 return literal; 1317 } 1318 1319 DimensionVector output_indices(bounds.size(), 0); 1320 tensorflow::gtl::ArraySlice<int64> input_indices = output_indices; 1321 input_indices.remove_prefix(1); 1322 1323 bool done = false; 1324 while (!done) { 1325 const auto element = Get<NativeT>(input_indices); 1326 literal->Set<NativeT>(output_indices, element); 1327 1328 done = true; 1329 for (int n = 0; n < output_indices.size(); ++n) { 1330 ++output_indices[n]; 1331 if (output_indices[n] < bounds[n]) { 1332 done = false; 1333 break; 1334 } 1335 output_indices[n] = 0; 1336 } 1337 } 1338 return literal; 1339 } 1340 1341 } // namespace xla 1342 1343 
#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ 1344