Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Utilities for dealing with Literal protobufs.
     17 
     18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
     19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
     20 
     21 #include <functional>
     22 #include <initializer_list>
     23 #include <iterator>
     24 #include <memory>
     25 #include <ostream>
     26 #include <string>
     27 #include <type_traits>
     28 #include <vector>
     29 
     30 #include "tensorflow/compiler/xla/array2d.h"
     31 #include "tensorflow/compiler/xla/array3d.h"
     32 #include "tensorflow/compiler/xla/array4d.h"
     33 #include "tensorflow/compiler/xla/index_util.h"
     34 #include "tensorflow/compiler/xla/layout_util.h"
     35 #include "tensorflow/compiler/xla/primitive_util.h"
     36 #include "tensorflow/compiler/xla/ptr_util.h"
     37 #include "tensorflow/compiler/xla/shape_tree.h"
     38 #include "tensorflow/compiler/xla/shape_util.h"
     39 #include "tensorflow/compiler/xla/sparse_index_array.h"
     40 #include "tensorflow/compiler/xla/status_macros.h"
     41 #include "tensorflow/compiler/xla/types.h"
     42 #include "tensorflow/compiler/xla/util.h"
     43 #include "tensorflow/compiler/xla/xla_data.pb.h"
     44 #include "tensorflow/core/lib/core/bitmap.h"
     45 #include "tensorflow/core/lib/core/status.h"
     46 #include "tensorflow/core/lib/core/stringpiece.h"
     47 #include "tensorflow/core/lib/gtl/array_slice.h"
     48 #include "tensorflow/core/platform/logging.h"
     49 #include "tensorflow/core/platform/macros.h"
     50 #include "tensorflow/core/platform/protobuf.h"
     51 #include "tensorflow/core/platform/types.h"
     52 
     53 namespace xla {
     54 
     55 // Class representing literal values in XLA.
     56 //
     57 // TODO(b/67651157): The methods in this class should be reduced to a minimal
     58 // set of methods which construct Literals and accessors methods. Other methods
     59 // which perform computation on Literals (Reshape, Slice, etc) should be moved
     60 // elsewhere, and perhaps combined with evaluator code which operates on
     61 // Literals.
     62 class Literal {
     63  public:
     64   Literal() : Literal(ShapeUtil::MakeNil()) {}  // Uses the MakeNil() shape.
     65 
     66   // Create a literal of the given shape. The literal is allocated sufficient
     67   // memory to hold the shape. Memory is uninitialized.
     68   explicit Literal(const Shape& shape);
     69   virtual ~Literal();
     70 
     71   // Literals are moveable, but not copyable. To copy a literal use
     72   // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
     73   // of literals which can be expensive.
     74   Literal(const Literal& other) = delete;
     75   Literal& operator=(const Literal& other) = delete;
     76   Literal(Literal&& other);
     77   Literal& operator=(Literal&& other);
     78 
     79   // Literals are equal if they have compatible shapes and the same data
     80   // values. Layout is not compared.
     81   bool operator==(const Literal& other) const;
     82   bool operator!=(const Literal& other) const { return !operator==(other); }
     83 
     84   // Serialize to and from a proto.
     85   static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
     86       const LiteralProto& proto);
     87   LiteralProto ToProto() const;
     88 
     89   // Returns a const reference to the shape of the literal.
     90   const Shape& shape() const { return shape_; }
     91 
     92   // TODO(b/67651157): Remove this accessor. It returns a mutable pointer to the
     93   // literal's shape; mutating the shape can produce malformed Literals.
     94   Shape* mutable_shape_do_not_use() { return &shape_; }
     95 
     96   // Returns a (Mutable)ArraySlice view of the array for this literal for the
     97   // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
     98   // given ShapeIndex is not array. See primitive_util.h for the mapping from
     99   // XLA type to native type.
    100   template <typename NativeT>
    101   tensorflow::gtl::ArraySlice<NativeT> data(
    102       const ShapeIndex& shape_index = {}) const;
    103   template <typename NativeT>
    104   tensorflow::gtl::MutableArraySlice<NativeT> data(
    105       const ShapeIndex& shape_index = {});
    106 
    107   // Returns a pointer to the sparse index array. Returns nullptr if the literal
    108   // is not a sparse array.
    109   const SparseIndexArray* sparse_indices(
    110       const ShapeIndex& shape_index = {}) const;
    111   SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
    112 
    113   // Returns a pointer to (or size of) the underlying buffer holding the array
    114   // at the given shape index. CHECKs if the subshape of the literal at the
    115   // given ShapeIndex is not array.
    116   const void* untyped_data(const ShapeIndex& shape_index = {}) const;
    117   void* untyped_data(const ShapeIndex& shape_index = {});
    118   int64 size_bytes(const ShapeIndex& shape_index = {}) const;
    119 
    120   // Creates a new literal of a given rank. To minimize ambiguity (for users
    121   // and the compiler) these CreateR[0-2] methods should explicitly specify the
    122   // native type. For example:
    123   //
    124   //  CreateR1<float>({1.0, 42.0});
    125   //  CreateR2<uint32>({{1, 2}, {3, 4}});
    126   //
    127   // The variants not ending with WithLayout use the default XLA layout for the
    128   // literal's linear representation in memory.
    129   template <typename NativeT>
    130   static std::unique_ptr<Literal> CreateR0(NativeT value);
    131   template <typename NativeT>
    132   static std::unique_ptr<Literal> CreateR1(
    133       tensorflow::gtl::ArraySlice<NativeT> values);
    134   static std::unique_ptr<Literal> CreateR1(
    135       const tensorflow::core::Bitmap& values);
    136   template <typename NativeT>
    137   static std::unique_ptr<Literal> CreateR2(
    138       std::initializer_list<std::initializer_list<NativeT>> values);
    139   template <typename NativeT>
    140   static std::unique_ptr<Literal> CreateR2WithLayout(
    141       std::initializer_list<std::initializer_list<NativeT>> values,
    142       const Layout& layout);
    143   template <typename NativeT>
    144   static std::unique_ptr<Literal> CreateR3(
    145       std::initializer_list<
    146           std::initializer_list<std::initializer_list<NativeT>>>
    147           values);
    148   template <typename NativeT>
    149   static std::unique_ptr<Literal> CreateR3WithLayout(
    150       std::initializer_list<
    151           std::initializer_list<std::initializer_list<NativeT>>>
    152           values,
    153       const Layout& layout);
    154   template <typename NativeT>
    155   static std::unique_ptr<Literal> CreateR4(
    156       std::initializer_list<std::initializer_list<
    157           std::initializer_list<std::initializer_list<NativeT>>>>
    158           values);
    159   template <typename NativeT>
    160   static std::unique_ptr<Literal> CreateR4WithLayout(
    161       std::initializer_list<std::initializer_list<
    162           std::initializer_list<std::initializer_list<NativeT>>>>
    163           values,
    164       const Layout& layout);
    165 
    166   // Returns this literal's data as a string. This literal must be a rank-1 U8
    167   // array.
    168   string GetR1U8AsString() const;
    169 
    170   // Creates a literal with a sparse layout and the given indices and values.
    171   // The shape is initialized from the given dimensions.  The minor dimension of
    172   // the indices array must equal the rank of the shape (i.e. size of the
    173   // dimensions array). The major dimension of the indices array must equal the
    174   // number of elements in the values array. The maximum number of elements in
    175   // the array is taken from the max_indices() value of the index array.
    176   //
    177   // XLA assumes that sparse literals are in sorted order for all operations. If
    178   // the `sort` argument is true, then the indices and values will be sorted
    179   // while copying them into the literal. If you have ensured that the indices
    180   // and values are already sorted, then you may set the `sort` argument to
    181   // false to skip the sorting step.
    182   //
    183   // For example:
    184   //
    185   //   CreateSparse(
    186   //     {12, 12, 12},
    187   //     SparseIndexArray(10, 3,
    188   //                      Array2D{
    189   //                        {0, 1, 2},
    190   //                        {3, 4, 5},
    191   //                        {6, 7, 8},
    192   //                        {9, 10, 11},
    193   //                      }),
    194   //     {1.0, 2.0, 3.0, 4.0})
    195   //
    196   // This creates an array with shape F64[12,12,12]sparse{10}, that has the
    197   // following non-zero values:
    198   //
    199   //     [0,  1,  2]: 1.0
    200   //     [3,  4,  5]: 2.0
    201   //     [6,  7,  8]: 3.0
    202   //     [9, 10, 11]: 4.0
    203   //
    204   template <typename NativeT>
    205   static std::unique_ptr<Literal> CreateSparse(
    206       tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
    207       tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
    208 
    209   // Populates a literal with a sparse layout with the given indices and values.
    210   // Each index in the indices array is CHECKed against the dimensions in the
    211   // literal's shape.  If sort is true, then the indices and values will be
    212   // sorted.  If sort is false, then the indices and values are assumed to
    213   // already be in sorted order.  See CreateSparse for an example of how data
    214   // are populated.
    215   template <typename NativeT>
    216   void PopulateSparse(SparseIndexArray indices,
    217                       tensorflow::gtl::ArraySlice<NativeT> values,
    218                       bool sort = true);
    219 
    220   // Creates a new Literal object with the shape specified as parameter.
    221   // The content of the literal values is the default value of the primitive
    222   // type of literal itself (0 for numeric types, and false for predicates).
    223   static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
    224 
    225   // Creates a new Literal object with its values having the primitive_type
    226   // type, and with dimensions defined by the dimensions parameter.
    227   // The content of the literal values is the default value of the primitive
    228   // type of literal itself (0 for numeric types, and false for predicates).
    229   static std::unique_ptr<Literal> CreateFromDimensions(
    230       PrimitiveType primitive_type,
    231       tensorflow::gtl::ArraySlice<int64> dimensions);
    232 
    233   // Copy values from 'src_literal' rooted at 'src_shape_index' into this
    234   // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
    235   // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
    236   // rooted at 'src_shape_index', but need not be arrays.
    237   Status CopyFrom(const Literal& src_literal,
    238                   const ShapeIndex& dest_shape_index = {},
    239                   const ShapeIndex& src_shape_index = {});
    240 
    241   // Similar to CopyFrom, but with move semantics. The subshape of this literal
    242   // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
    243   // (layouts and shapes must match), but need not be arrays. The memory
    244   // allocated in this literal for the subshape at dest_shape_index is
    245   // deallocated, and the respective buffers are replaced with those in
    246   // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
    247   Status MoveFrom(Literal&& src_literal,
    248                   const ShapeIndex& dest_shape_index = {});
    249 
    250   // Copies the values from src_literal, starting at src_base shape indexes,
    251   // to this literal, starting at dest_base, where the copy size in each
    252   // dimension is specified by copy_size.
    253   // The src_literal and this literal must have the same primitive type,
    254   // src_base+copy_size must fit the source literal dimensions, as well as
    255   // dest_base+copy_size must fit the destination literal dimensions.
    256   // Note: if either src_literal or this literal contains dimensions with zero
    257   // elements, then copy_size must be 0 in these dimensions and the
    258   // corresponding base indices must be 0.
    259   // This literal and 'src_literal' must be arrays.
    260   Status CopySliceFrom(const Literal& src_literal,
    261                        tensorflow::gtl::ArraySlice<int64> src_base,
    262                        tensorflow::gtl::ArraySlice<int64> dest_base,
    263                        tensorflow::gtl::ArraySlice<int64> copy_size);
    264 
    265   // Returns a vector containing the tuple elements of this Literal as separate
    266   // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
    267   // elements are moved into the new Literals; no data is copied. Upon return
    268   // this Literal is set to a nil shape (empty tuple)
    269   std::vector<Literal> DecomposeTuple();
    270 
    271   // This operation is the inverse of DecomposeTuple. The given elements are
    272   // moved into the tuple elements of a new tuple-shaped Literal which is
    273   // returned. Upon return, each of the Literals in 'elements' is set to a nil
    274   // shape (empty tuple).
    275   static Literal MoveIntoTuple(
    276       tensorflow::gtl::MutableArraySlice<Literal> elements);
    277 
    278   // Creates a new value that has the equivalent value as this literal, but
    279   // conforms to new_layout; e.g. a literal matrix that was in {0, 1}
    280   // minor-to-major dimension layout can be re-layed-out as {1, 0}
    281   // minor-to-major dimension layout and the value in the cell at any given
    282   // logical index (i0, i1) will be the same.
    283   //
    284   // For tuple shaped literals, shape_index should be used to select the inner
    285   // array that the new layout applies to.
    286   //
    287   // Note: this is useful when the client wants to ensure that a value placed in
    288   // the XLA allocation tracker has a particular layout; for efficiency
    289   // purposes or avoiding unimplemented operation/layout combinations.
    290   std::unique_ptr<Literal> Relayout(const Layout& new_layout,
    291                                     const ShapeIndex& shape_index = {}) const;
    292 
    293   // An overload of Relayout which changes the layout of the entire shape rather
    294   // than being limited to a single array within the shape.
    295   std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
    296 
    297   // Creates a new literal by reshaping this literal to have the given
    298   // dimensions. The total number of elements must not change. The
    299   // implementation currently only supports monotonic dim0-major layouts.
    300   // This literal must be an array.
    301   StatusOr<std::unique_ptr<Literal>> Reshape(
    302       tensorflow::gtl::ArraySlice<int64> dimensions) const;
    303 
    304   // Creates a new literal by reordering the dimensions of this literal.
    305   // The given `permutation` must be a permutation of the dimension numbers
    306   // in the original literal, and it specifies the order of the new dimensions
    307   // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
    308   // For example, a transpose call on a literal of shape [3 x 8 x 4] and
    309   // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
    310   // This literal must be an array.
    311   std::unique_ptr<Literal> Transpose(
    312       tensorflow::gtl::ArraySlice<int64> permutation) const;
    313 
    314   // Creates a sub-array from this literal by extracting the indices
    315   // [start_index, limit_index) of each dimension. The result literal has the
    316   // same rank and layout as for the given literal. The number of indices in
    317   // start_indices and limit_indices must be the rank of the literal, and the
    318   // indices follow the order of the dimensions.
    319   // This literal must be an array.
    320   std::unique_ptr<Literal> Slice(
    321       tensorflow::gtl::ArraySlice<int64> start_indices,
    322       tensorflow::gtl::ArraySlice<int64> limit_indices) const;
    323 
    324   // Creates a literal with a prepended dimension with bound "times"; e.g. a
    325   // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
    326   // literal replicated four times.
    327   // This literal must be an array.
    328   template <typename NativeT>
    329   std::unique_ptr<Literal> Replicate(int64 times) const;
    330 
    331   // Converts this literal to another primitive type. Returns an error if the
    332   // conversion is not possible. This literal must be array-shaped.
    333   StatusOr<std::unique_ptr<Literal>> Convert(
    334       PrimitiveType primitive_dest_type) const;
    335 
    336   // Creates a scalar literal value zero of the given primitive type.
    337   static Literal Zero(PrimitiveType primitive_type);
    338 
    339   // Creates a scalar literal value one of the given primitive type.
    340   static Literal One(PrimitiveType primitive_type);
    341 
    342   // Creates a scalar literal value containing the minimum value of the given
    343   // primitive type. For floating-point types, returns -inf.
    344   static Literal MinValue(PrimitiveType primitive_type);
    345 
    346   // Creates a scalar literal value containing the maximum value of the given
    347   // primitive type. For floating-point types, returns inf.
    348   static Literal MaxValue(PrimitiveType primitive_type);
    349 
    350   // Creates a literal of the given shape where each element is `value`.
    351   template <typename NativeT>
    352   static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
    353       tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
    354 
    355   // Creates a new literal from an Array type. The variants not ending with
    356   // WithLayout use the default XLA layout for the literal's linear
    357   // representation in memory.
    358   template <typename NativeT>
    359   static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
    360   template <typename NativeT>
    361   static std::unique_ptr<Literal> CreateFromArrayWithLayout(
    362       const Array<NativeT>& values, const Layout& layout);
    363   template <typename NativeT>
    364   static std::unique_ptr<Literal> CreateR2FromArray2D(
    365       const Array2D<NativeT>& values);
    366   template <typename NativeT>
    367   static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
    368       const Array2D<NativeT>& values, const Layout& layout);
    369   template <typename NativeT>
    370   static std::unique_ptr<Literal> CreateR3FromArray3D(
    371       const Array3D<NativeT>& values);
    372   template <typename NativeT>
    373   static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
    374       const Array3D<NativeT>& values, const Layout& layout);
    375   template <typename NativeT>
    376   static std::unique_ptr<Literal> CreateR4FromArray4D(
    377       const Array4D<NativeT>& values);
    378   template <typename NativeT>
    379   static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
    380       const Array4D<NativeT>& values, const Layout& layout);
    381 
    382   // Creates a new vector of U8s literal value from a string.
    383   static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value);
    384 
    385   // Creates a linspace-populated literal with the given number of rows and
    386   // columns.
    387   static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
    388                                                       int64 rows, int64 cols);
    389 
    390   // Creates a literal that projects the (x, y) dimensions given in values into
    391   // the z dimension given by "projection".
    392   template <typename NativeT>
    393   static std::unique_ptr<Literal> CreateR3Projected(
    394       std::initializer_list<std::initializer_list<NativeT>> values,
    395       int64 projection);
    396 
    397   // Creates a literal that projects the (x, y) dimensions given in values into
    398   // the z and p dimensions given.
    399   template <typename NativeT>
    400   static std::unique_ptr<Literal> CreateR4Projected(
    401       std::initializer_list<std::initializer_list<NativeT>> values,
    402       int64 projection_p, int64 projection_z);
    403 
    404   // Clones this literal into a new Literal, or new std::unique_ptr<Literal>.
    405   Literal Clone() const;
    406   std::unique_ptr<Literal> CloneToUnique() const;
    407 
    408   // Gets or sets an element in the literal at the given index. The multi_index
    409   // is CHECKed against the dimension sizes.
    410   template <typename NativeT>
    411   NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
    412               const ShapeIndex& shape_index) const;
    413   template <typename NativeT>
    414   void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
    415            const ShapeIndex& shape_index, NativeT value);
    416 
    417   // Overloads of Get and Set for array literals. CHECKs if the literal is not
    418   // array-shaped and dense.
    419   template <typename NativeT>
    420   NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
    421   template <typename NativeT>
    422   void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
    423 
    424   // Returns the multi-index of the element in a sparse literal at the given
    425   // sparse element number.  The sparse element number is the position within
    426   // the sparse array's list of (index, value) pairs, and is checked against the
    427   // total number of (index, value) pairs in the sparse array.
    428   tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
    429       int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
    430 
    431   // Returns the value of the element in a sparse literal at the given sparse
    432   // element number.  The sparse element number is the position within the
    433   // sparse array's list of (index, value) pairs, and is checked against the
    434   // total number of (index, value) pairs in the sparse array.
    435   template <typename NativeT>
    436   NativeT GetSparseElement(int64 sparse_element_number,
    437                            const ShapeIndex& shape_index = {}) const;
    438 
    439   // Appends the given element to the literal.  If the elements are not appended
    440   // in sorted order, then SortSparseElements should be called before calling
    441   // other methods.  This literal must have a sparse layout.
    442   template <typename NativeT>
    443   void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
    444                            NativeT value, const ShapeIndex& shape_index = {});
    445 
    446   // Sorts the elements in a sparse array.
    447   void SortSparseElements(const ShapeIndex& shape_index = {});
    448 
    449   // Returns the element value at index (0, ..., 0), however many zeroes are
    450   // required for that index.
    451   template <typename NativeT>
    452   NativeT GetFirstElement() const;
    453 
    454   // As Get(), but determines the correct type and converts the value
    455   // into text.
    456   string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
    457                      const ShapeIndex& shape_index = {}) const;
    458 
    459   // As GetSparseElement(), but determines the correct type and converts the
    460   // value into text.
    461   string GetSparseElementAsString(int64 sparse_element_number,
    462                                   const ShapeIndex& shape_index = {}) const;
    463 
    464   // As Get(), but determines the correct type and converts the value into
    465   // int64.  This literal must be an array.
    466   StatusOr<int64> GetIntegralAsS64(
    467       tensorflow::gtl::ArraySlice<int64> multi_index) const;
    468 
    469   // Returns an identity matrix (rank 2) with the given row and column count.
    470   template <typename NativeT>
    471   static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
    472 
    473   // Returns a tuple literal composed of given literals. Data is copied from the
    474   // given elements into the returned literal.
    475   static std::unique_ptr<Literal> MakeTuple(
    476       tensorflow::gtl::ArraySlice<const Literal*> elements);
    477 
    478   // As above, but intended to be invoked with move semantics; i.e.
    479   //
    480   //  std::vector<std::unique_ptr<Literal>> elements = ...;
    481   //  auto result = Literal::MakeTupleOwned(std::move(elements));
    482   //
    483   // This would have been declared as an overload, but there is ambiguity
    484   // in invocation between the above signature and this one.
    485   static std::unique_ptr<Literal> MakeTupleOwned(
    486       std::vector<std::unique_ptr<Literal>> elements);
    487 
    488   // This overload lets you pass a braced list of unique_ptr<Literal>s to
    489   // MakeTupleOwned:
    490   //
    491   //   Literal::MakeTupleOwned(Literal::CreateR1(...), ...).
    492   //
    493   // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
    494   // overload doesn't work because std::initializer_list's elements are always
    495   // const.
    496   //
    497   // The arguments to this function must all be unique_ptr<Literal>.
    498   template <typename... Ts>
    499   static std::unique_ptr<Literal> MakeTupleOwned(
    500       std::unique_ptr<Ts>... elements) {
    501     // Park the pack in an array so move iterators can feed the vector.
    502     std::array<std::unique_ptr<Literal>, sizeof...(Ts)> staged{
    503         std::move(elements)...};
    504     return MakeTupleOwned(std::vector<std::unique_ptr<Literal>>(
    505         std::make_move_iterator(staged.begin()),
    506         std::make_move_iterator(staged.end())));
    507   }
    508 
    509   // Returns a string representation of the literal value.
    510   // Warning: this function can take minutes for multi-million element Literals.
    511   string ToString(bool print_layout = false) const;
    512 
    513   // Invokes the "per cell" callback for each element in the provided
    514   // literal with the element's indices and a string representation of
    515   // the element's value.
    516   //
    517   // This function is useful if you want a polymorphic representation
    518   // of the tensor's elements (turning it to a string for something
    519   // like representation in a protobuf).
    520   //
    521   // This literal must have a dense layout.
    522   void EachCellAsString(
    523       const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
    524                                const string& value)>& per_cell) const;
    525   template <typename NativeT>
    526   void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
    527                                    NativeT value)>
    528                     per_cell) const;
    529 
    530   // Populate this literal with the given values. Examples:
    531   //
    532   //   // Populate with floats.
    533   //   Array2D<float> float_values = ...
    534   //   literal.PopulateR2FromArray2D(float_values);
    535   //
    536   //   // Populate with int32s.
    537   //   literal.PopulateR2<int32>({{1, 2}, {3, 4}});
    538   //
    539   // The shape and element type of this literal must match given values. For
    540   // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
    541   // array of S32.
    542   template <typename NativeT>
    543   void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
    544   void PopulateR1(const tensorflow::core::Bitmap& values);
    545   template <typename NativeT>
    546   void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
    547   template <typename NativeT>
    548   void PopulateFromArray(const Array<NativeT>& values);
    549   template <typename NativeT>
    550   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
    551   template <typename NativeT>
    552   void PopulateR3FromArray3D(const Array3D<NativeT>& values);
    553   template <typename NativeT>
    554   void PopulateR4FromArray4D(const Array4D<NativeT>& values);
    555 
    556   // Populates literal values by calling the generator function for every cell
    557   // in this literal object.
    558   //
    559   // generator must be a callable of the type
    560   // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
    561   //
    562   // This literal must have a dense layout.
    563   template <typename NativeT, typename FnType>
    564   Status Populate(const FnType& generator);
    565 
    566   // Fills this literal with the given value.
    567   template <typename NativeT>
    568   void PopulateWithValue(NativeT value);
    569 
    570   // Returns whether every element in this literal is equal to value.
    571   //
    572   // value is an int8 because we expect this to be called with small
    573   // compile-time constants (0, -1, etc.) and so that whatever value you pass
    574   // can be represented exactly by floating-point types as small as 16 bits.
    575   //
    576   // If value doesn't fit in this literal's type, returns false.  Values of 1/0
    577   // are considered equal to true/false; other values are not considered equal
    578   // to true. Also if this literal is not array-shaped false is returned.
    579   bool IsAll(int8 value) const;
    580 
    581   // Like IsAll(int8), except we check whether the literal is
    582   // equal to a particular floating-point number.
    583   //
    584   // If the literal is not a floating-point value, this always returns false.
    585   //
    586   // This casts value to the type of literal, then compares using ==.  The usual
    587   // admonishments about floating-point equality checks apply.  We expect you to
    588   // use this to check for values that can be expressed precisely as a float,
    589   // e.g. -0.5.  Also if this literal is not array-shaped false is returned.
    590   bool IsAllFloat(float value) const;
    591 
    592   // Like IsAll(int8), except we check whether the literal is equal to a
    593   // particular complex number.
    594   //
    595   // If the literal is not a complex value, this always returns false.
    596   //
    597   // This casts value to the type of literal, then compares using ==.  The usual
    598   // admonishments about floating-point equality checks apply.  We expect you to
    599   // use this to check for complex values that can be expressed precisely as
    600   // float pairs e.g. (-0.5, 1.0).
    601   //
    602   // This literal must have a dense layout.
    603   bool IsAllComplex(complex64 value) const;
    604 
    605   // Returns whether this literal is zero at the specified index. This literal
    606   // must be an array with a dense layout.
    607   bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
    608 
    609   // Return the count of the elements in the array at the given shape index in
    610   // this literal.
    611   int64 element_count(const ShapeIndex& index = {}) const {
    612     return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
    613   }
    614 
    615   // Return the count of the elements in the sparse array in this literal.
    616   // This function takes no ShapeIndex; presumably it refers to the array at
    617   // the root of the shape (TODO confirm against the definition). The count is
          // expected to be no larger than LayoutUtil::MaxSparseElements() of that
          // array's layout.
    618   int64 sparse_element_count() const;
    619 
    620  protected:
    621   // 'allocate_arrays' indicates whether to allocate memory for the arrays in
    622   // the shape. If false, buffer pointers inside of the Literal::Pieces are set
    623   // to nullptr.
    624   Literal(const Shape& shape, bool allocate_arrays);
    625 
    626   // Internal template helper for the Literal::CopySliceFrom(), matching its
    627   // arguments one by one.
    628   template <typename NativeT>
    629   Status CopySliceFromInternal(const Literal& src_literal,
    630                                tensorflow::gtl::ArraySlice<int64> src_base,
    631                                tensorflow::gtl::ArraySlice<int64> dest_base,
    632                                tensorflow::gtl::ArraySlice<int64> copy_size);
    633 
    634   // Utility structure which is used to create the optimal configuration for
    635   // a ShapeUtil::ForEachIndex() scan across two literals.
    636   struct StrideConfig {
    637     StrideConfig(const Shape& source_shape, const Shape& dest_shape,
    638                  tensorflow::gtl::ArraySlice<int64> dimensions);
    639 
    640     // The dimensions of the stride operation. Essentially every dimension
    641     // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
    642     // steps.
            //
            // NOTE(review): `dimensions` is stored as an ArraySlice, presumably a
            // non-owning view of the constructor argument; if so, the caller's
            // storage must outlive this StrideConfig. Confirm against callers.
    643     tensorflow::gtl::ArraySlice<int64> dimensions;
    644     DimensionVector base;
    645     DimensionVector step;
            // The dimension number handled by the inner loop of the scan.
            // Literal::Populate() reads the size of this dimension from the shape
            // and varies the index at this position while filling cells.
    646     int64 minor_dimension = 0;
    647     // The size of the strides for source and destination. One of the two
    648     // (the one looping through its most minor dimension) will be 1, while
    649     // the other will be the stride size at the dimension matching the other
    650     // shape most minor dimension being scanned.
    651     int64 dest_stride = 1;
    652     int64 source_stride = 1;
    653     // The size of the inner loop on the most minor dimension.
    654     int64 minor_loop_size = 1;
    655   };
    656 
    657   // A data structure representing a subshape at a particular ShapeIndex within
    658   // the literal. For array-shaped ShapeIndexes, this data structure holds the
    659   // pointer to the memory allocated for the array data.
          //
          // A Piece declares no destructor and only stores raw pointers. Whether the
          // buffers are released is decided by the containing Literal (see
          // Literal::owns_buffers_ and Literal::DeallocateBuffers()).
    660   class Piece {
    661    public:
    662     // Return the buffer holding the array data for this piece as an array
    663     // slice. This piece must be array-shaped.
            // The definitions also CHECK that NativeT corresponds to the element
            // type of the subshape.
    664     template <typename NativeT>
    665     tensorflow::gtl::ArraySlice<NativeT> data() const;
    666     template <typename NativeT>
    667     tensorflow::gtl::MutableArraySlice<NativeT> data();
    668 
    669     // Return the buffer holding the array data for this piece as a void*. This
    670     // piece must be array-shaped.
    671     void* untyped_data();
    672     const void* untyped_data() const;
    673 
    674     // Gets or sets an element in the array at the given index. The multi_index
    675     // is CHECKed against the dimension sizes of the array.  This piece must be
    676     // array-shaped.
            // The definitions additionally CHECK that the subshape is a dense array.
    677     template <typename NativeT>
    678     NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
    679     template <typename NativeT>
    680     void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
    681 
    682     // Gets/sets the buffer holding the array data.
            // Note that buffer() is a const member yet returns a mutable char*.
    683     char* buffer() const { return buffer_; }
    684     void set_buffer(char* buffer) { buffer_ = buffer; }
    685 
    686     // The array of multi-indices that provide the locations of non-zero
    687     // elements in a sparse array.  Only used if
    688     // LayoutUtil::IsSparseArray(shape()) is true.
    689     SparseIndexArray* sparse_indices() const { return sparse_indices_; }
    690     void set_sparse_indices(SparseIndexArray* sparse_indices) {
    691       sparse_indices_ = sparse_indices;
    692     }
    693 
    694     // Gets or sets the subshape of this piece. This reference points to a
    695     // subshape within the shape in the containing Literal (Literal::shape_).
            // subshape() dereferences subshape_ unconditionally, so set_subshape()
            // must have been called with a non-null pointer first.
    696     const Shape& subshape() const { return *subshape_; }
    697     void set_subshape(const Shape* subshape) { subshape_ = subshape; }
    698 
    699     // Returns the size in bytes of the buffer holding the array data.
    700     int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
    701 
    702     // Returns the number of elements in this piece's array.
    703     int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
    704 
    705     // Copy the data from 'src' into this piece's buffer. Shapes of this piece
    706     // and src must be compatible.
    707     Status CopyFrom(const Piece& src);
    708 
    709     // Returns true if this piece and 'other' contain the same data. This piece
    710     // and 'other' must be array-shaped and compatible.
    711     bool EqualElements(const Piece& other) const;
    712 
    713     // Writes the shape and data (if array-shaped) into the given proto.
    714     void WriteToProto(LiteralProto* proto) const;
    715 
    716     // Copies the data from the given proto into this piece. The shape of this
    717     // piece must be equal (not just compatible) to the shape of the proto.
    718     Status CopyFromProto(const LiteralProto& proto);
    719 
    720     // Sorts the elements in a sparse array.
    721     void SortSparseElements();
    722 
    723    private:
    724     // Recursive helper for EqualElements.
    725     template <typename NativeT>
    726     bool EqualElementsInternal(const Piece& other,
    727                                std::vector<int64>* multi_index) const;
    728 
    729     // Helper for SortSparseElements that has the element type as a template
    730     // parameter.
    731     template <typename NativeT>
    732     void SortSparseElementsInternal();
    733 
    734     // For array-shaped pieces, this is the buffer holding the literal data.
    735     char* buffer_ = nullptr;
    736 
    737     // For sparse arrays, this is the array of indices.
    738     SparseIndexArray* sparse_indices_ = nullptr;
    739 
    740     // The shape of piece. This points into the shape of the containing Literal
    741     // (Literal::shape_).
    742     const Shape* subshape_ = nullptr;
    743   };
    744 
    745   // Returns the piece at the given ShapeIndex.
    746   Piece& piece(const ShapeIndex& shape_index) {
    747     return *pieces_.mutable_element(shape_index);
    748   }
    749   const Piece& piece(const ShapeIndex& shape_index) const {
    750     return pieces_.element(shape_index);
    751   }
    752 
    753   // Returns the piece at the root of the shape (empty ShapeIndex).
    754   Piece& root_piece() { return piece({}); }
    755   const Piece& root_piece() const { return piece({}); }
    756 
    757   // Deallocate the buffers held by this literal (if the literal owns the
    758   // buffer).
    759   void DeallocateBuffers();
    760 
    761   Shape shape_;
    762   ShapeTree<Piece> pieces_;
    763 
    764   // Whether the buffers held in pieces_ are owned by this Literal.
    765   bool owns_buffers_;
    766 
    767   // LiteralView must access and manipulate Pieces of other Literals.
    768   friend class LiteralView;
    769 };  // class Literal
    770 
    771 std::ostream& operator<<(std::ostream& out, const Literal& literal);
    772 
    773 // A read-only view of a Literal. A LiteralView contains pointers to buffers
    774 // owned by the viewed Literal.
    775 //
        // Because the buffers stay owned by the viewed Literal, a view is presumably
        // only usable while that Literal is alive (NOTE(review): confirm against the
        // definitions, which are not part of this header).
        //
    776 // TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
    777 // and mutable) similar to (Mutable)ArraySlice.
    778 class LiteralView : public Literal {
    779  public:
    780   // Create and return a view of the given literal rooted at the given shape
    781   // index within the given literal. A factory is used rather than a public
    782   // constructor because only const LiteralViews are supported. It's still
    783   // possible to create non-const LiteralViews via the copy constructors, but
    784   // the factory method makes it a bit less likely. Implementing literal slices
    785   // will fix this undesirable situation (b/71550060).
    786   static const LiteralView Create(const Literal& literal,
    787                                   const ShapeIndex& view_root = {});
    788 
          // Views are copyable. Both copy operations are implemented on top of the
          // private CopyFrom() helper declared below. No move operations are
          // declared.
    789   LiteralView(const LiteralView& other);
    790   LiteralView& operator=(const LiteralView& other);
    791 
    792   virtual ~LiteralView();
    793 
    794  private:
          // Only reachable through Create(); see the factory comment above.
    795   LiteralView(const Literal& literal, const ShapeIndex& view_root);
    796 
    797   // Helper for the copy constructor and copy assignment operator.
    798   void CopyFrom(const LiteralView& other);
    799 };
    800 
    801 template <typename NativeT>
    802 tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
    803   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
    804   CHECK_EQ(subshape().element_type(),
    805            primitive_util::NativeToPrimitiveType<NativeT>())
    806       << "Attempting to access "
    807       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
    808       << " type, but literal element type is "
    809       << PrimitiveType_Name(subshape().element_type());
    810   return tensorflow::gtl::ArraySlice<NativeT>(
    811       reinterpret_cast<const NativeT*>(buffer()),
    812       ShapeUtil::ElementsIn(subshape()));
    813 }
    814 
    815 template <typename NativeT>
    816 tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
    817   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
    818   CHECK_EQ(subshape().element_type(),
    819            primitive_util::NativeToPrimitiveType<NativeT>())
    820       << "Attempting to access "
    821       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
    822       << " type, but literal element type is "
    823       << PrimitiveType_Name(subshape().element_type());
    824   return tensorflow::gtl::MutableArraySlice<NativeT>(
    825       reinterpret_cast<NativeT*>(buffer()), ShapeUtil::ElementsIn(subshape()));
    826 }
    827 
    828 template <typename NativeT>
    829 NativeT Literal::Piece::Get(
    830     tensorflow::gtl::ArraySlice<int64> multi_index) const {
    831   CHECK(LayoutUtil::IsDenseArray(subshape()));
    832   return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
    833       subshape(), multi_index)];
    834 }
    835 
    836 template <typename NativeT>
    837 void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
    838                          NativeT value) {
    839   CHECK(LayoutUtil::IsDenseArray(subshape()));
    840   data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
    841       subshape(), multi_index)] = value;
    842 }
    843 
    844 template <typename NativeT>
    845 tensorflow::gtl::ArraySlice<NativeT> Literal::data(
    846     const ShapeIndex& shape_index) const {
    847   return piece(shape_index).data<NativeT>();
    848 }
    849 
    850 template <typename NativeT>
    851 tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
    852     const ShapeIndex& shape_index) {
    853   return piece(shape_index).data<NativeT>();
    854 }
    855 
    856 template <typename NativeT>
    857 inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
    858                             const ShapeIndex& shape_index) const {
    859   return piece(shape_index).Get<NativeT>(multi_index);
    860 }
    861 
    862 template <typename NativeT>
    863 inline NativeT Literal::Get(
    864     tensorflow::gtl::ArraySlice<int64> multi_index) const {
    865   return root_piece().Get<NativeT>(multi_index);
    866 }
    867 
    868 template <typename NativeT>
    869 inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
    870                          const ShapeIndex& shape_index, NativeT value) {
    871   return piece(shape_index).Set<NativeT>(multi_index, value);
    872 }
    873 
    874 template <typename NativeT>
    875 inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
    876                          NativeT value) {
    877   return root_piece().Set<NativeT>(multi_index, value);
    878 }
    879 
    880 template <typename NativeT>
    881 /* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) {
    882   auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(
    883       primitive_util::NativeToPrimitiveType<NativeT>(), {}));
    884   literal->Set({}, value);
    885   return literal;
    886 }
    887 
    888 template <typename NativeT>
    889 /* static */ std::unique_ptr<Literal> Literal::CreateR1(
    890     tensorflow::gtl::ArraySlice<NativeT> values) {
    891   auto literal = MakeUnique<Literal>(
    892       ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
    893                            {static_cast<int64>(values.size())}));
    894   literal->PopulateR1(values);
    895   return literal;
    896 }
    897 
    898 template <typename NativeT>
    899 /* static */ std::unique_ptr<Literal> Literal::CreateR2WithLayout(
    900     std::initializer_list<std::initializer_list<NativeT>> values,
    901     const Layout& layout) {
    902   auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
    903       primitive_util::NativeToPrimitiveType<NativeT>(),
    904       {static_cast<int64>(values.size()),
    905        static_cast<int64>(values.begin()->size())},
    906       AsInt64Slice(layout.minor_to_major())));
    907   literal->PopulateR2(values);
    908   return literal;
    909 }
    910 
    911 template <typename NativeT>
    912 /* static */ std::unique_ptr<Literal> Literal::CreateR2(
    913     std::initializer_list<std::initializer_list<NativeT>> values) {
    914   return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
    915 }
    916 
    917 template <typename NativeT>
    918 /* static */ std::unique_ptr<Literal> Literal::CreateR3WithLayout(
    919     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
    920         values,
    921     const Layout& layout) {
    922   const int64 d0 = values.size();
    923   const int64 d1 = values.begin()->size();
    924   const int64 d2 = values.begin()->begin()->size();
    925   Array3D<NativeT> tmp(d0, d1, d2);
    926   int64 i0 = 0;
    927   for (auto d1_values : values) {
    928     int64 i1 = 0;
    929     for (auto d2_values : d1_values) {
    930       int64 i2 = 0;
    931       for (auto value : d2_values) {
    932         tmp(i0, i1, i2) = value;
    933         ++i2;
    934       }
    935       ++i1;
    936     }
    937     ++i0;
    938   }
    939   return CreateR3FromArray3DWithLayout(tmp, layout);
    940 }
    941 
    942 template <typename NativeT>
    943 /* static */ std::unique_ptr<Literal> Literal::CreateR3(
    944     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
    945         values) {
    946   return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
    947 }
    948 
    949 template <typename NativeT>
    950 /* static */ std::unique_ptr<Literal> Literal::CreateR4WithLayout(
    951     std::initializer_list<std::initializer_list<
    952         std::initializer_list<std::initializer_list<NativeT>>>>
    953         values,
    954     const Layout& layout) {
    955   const int64 d0 = values.size();
    956   const int64 d1 = values.begin()->size();
    957   const int64 d2 = values.begin()->begin()->size();
    958   const int64 d3 = values.begin()->begin()->begin()->size();
    959   Array4D<NativeT> tmp(d0, d1, d2, d3);
    960   int64 i0 = 0;
    961   for (auto d1_values : values) {
    962     int64 i1 = 0;
    963     for (auto d2_values : d1_values) {
    964       int64 i2 = 0;
    965       for (auto d3_values : d2_values) {
    966         int64 i3 = 0;
    967         for (auto value : d3_values) {
    968           tmp(i0, i1, i2, i3) = value;
    969           ++i3;
    970         }
    971         ++i2;
    972       }
    973       ++i1;
    974     }
    975     ++i0;
    976   }
    977   return CreateR4FromArray4DWithLayout(tmp, layout);
    978 }
    979 
    980 template <typename NativeT>
    981 /* static */ std::unique_ptr<Literal> Literal::CreateSparse(
    982     tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
    983     tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
    984   int64 num_elements = values.size();
    985   int64 rank = dimensions.size();
    986   CHECK_EQ(num_elements, indices.index_count());
    987   CHECK_EQ(rank, indices.rank());
    988   auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
    989       primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
    990       indices.max_indices()));
    991   literal->PopulateSparse(indices, values, sort);
    992   return literal;
    993 }
    994 
    995 template <typename NativeT>
    996 /* static */ std::unique_ptr<Literal> Literal::CreateR4(
    997     std::initializer_list<std::initializer_list<
    998         std::initializer_list<std::initializer_list<NativeT>>>>
    999         values) {
   1000   return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
   1001 }
   1002 
   1003 template <typename NativeT>
   1004 /* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
   1005     const Array<NativeT>& values, const Layout& layout) {
   1006   auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
   1007       primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
   1008       AsInt64Slice(layout.minor_to_major())));
   1009   literal->PopulateFromArray(values);
   1010   return literal;
   1011 }
   1012 
   1013 template <typename NativeT>
   1014 /* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
   1015     const Array<NativeT>& values) {
   1016   return CreateFromArrayWithLayout(
   1017       values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
   1018 }
   1019 
   1020 template <typename NativeT>
   1021 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
   1022     const Array2D<NativeT>& values, const Layout& layout) {
   1023   return CreateFromArrayWithLayout(values, layout);
   1024 }
   1025 
   1026 template <typename NativeT>
   1027 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
   1028     const Array2D<NativeT>& values) {
   1029   return CreateFromArray(values);
   1030 }
   1031 
   1032 template <typename NativeT>
   1033 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
   1034     const Array3D<NativeT>& values, const Layout& layout) {
   1035   return CreateFromArrayWithLayout(values, layout);
   1036 }
   1037 
   1038 template <typename NativeT>
   1039 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
   1040     const Array3D<NativeT>& values) {
   1041   return CreateFromArray(values);
   1042 }
   1043 
   1044 template <typename NativeT>
   1045 /* static */ std::unique_ptr<Literal> Literal::CreateR3Projected(
   1046     std::initializer_list<std::initializer_list<NativeT>> values,
   1047     int64 projection) {
   1048   int64 dim0_size = projection;
   1049   int64 dim1_size = values.size();
   1050   int64 dim2_size = values.begin()->size();
   1051 
   1052   Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
   1053   for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
   1054     int64 dim1 = 0;
   1055     for (auto inner_list : values) {
   1056       int64 dim2 = 0;
   1057       for (auto value : inner_list) {
   1058         array(dim0, dim1, dim2) = value;
   1059         ++dim2;
   1060       }
   1061       CHECK_EQ(dim2_size, dim2);
   1062       ++dim1;
   1063     }
   1064     CHECK_EQ(dim1_size, dim1);
   1065   }
   1066   return CreateR3FromArray3D(array);
   1067 }
   1068 
   1069 template <typename NativeT>
   1070 /* static */ std::unique_ptr<Literal> Literal::CreateR4Projected(
   1071     std::initializer_list<std::initializer_list<NativeT>> values,
   1072     int64 projection_p, int64 projection_z) {
   1073   int64 dim0_size = projection_p;
   1074   int64 dim1_size = projection_z;
   1075   int64 dim2_size = values.size();
   1076   int64 dim3_size = values.begin()->size();
   1077 
   1078   Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
   1079   for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
   1080     for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
   1081       int64 dim2 = 0;
   1082       for (auto inner_list : values) {
   1083         int64 dim3 = 0;
   1084         for (auto value : inner_list) {
   1085           array(dim0, dim1, dim2, dim3) = value;
   1086           ++dim3;
   1087         }
   1088         CHECK_EQ(dim3_size, dim3);
   1089         ++dim2;
   1090       }
   1091       CHECK_EQ(dim2_size, dim2);
   1092     }
   1093   }
   1094   return CreateR4FromArray4D(array);
   1095 }
   1096 
   1097 template <typename NativeT>
   1098 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
   1099     const Array4D<NativeT>& values) {
   1100   return CreateFromArray(values);
   1101 }
   1102 
   1103 template <typename NativeT>
   1104 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
   1105     const Array4D<NativeT>& values, const Layout& layout) {
   1106   return CreateFromArrayWithLayout(values, layout);
   1107 }
   1108 
   1109 template <typename NativeT>
   1110 NativeT Literal::GetFirstElement() const {
   1111   return data<NativeT>().at(0);
   1112 }
   1113 
   1114 template <typename NativeT>
   1115 NativeT Literal::GetSparseElement(int64 sparse_element_number,
   1116                                   const ShapeIndex& shape_index) const {
   1117   CHECK(
   1118       LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
   1119   return data<NativeT>(shape_index)[sparse_element_number];
   1120 }
   1121 
   1122 template <typename NativeT>
   1123 void Literal::AppendSparseElement(
   1124     tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
   1125     const ShapeIndex& shape_index) {
   1126   Piece& p = piece(shape_index);
   1127   const Shape& subshape = p.subshape();
   1128   CHECK(LayoutUtil::IsSparseArray(subshape));
   1129   int64 rank = ShapeUtil::Rank(subshape);
   1130   CHECK_EQ(multi_index.size(), rank);
   1131   int64 last_element = p.sparse_indices()->index_count();
   1132   CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
   1133   p.sparse_indices()->Append(multi_index);
   1134   CHECK_LT(last_element, p.data<NativeT>().size());
   1135   p.data<NativeT>()[last_element] = value;
   1136 }
   1137 
   1138 // Returns an identity matrix (rank 2) with the given row and column count.
   1139 template <typename NativeT>
   1140 /* static */ std::unique_ptr<Literal> Literal::MakeIdentityR2(int64 size) {
   1141   Array2D<NativeT> array(size, size, 0);
   1142   for (int64 i = 0; i < size; ++i) {
   1143     array(i, i) = 1;
   1144   }
   1145   return CreateR2FromArray2D(array);
   1146 }
   1147 
   1148 template <typename NativeT>
   1149 void Literal::EachCell(
   1150     std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
   1151                        NativeT value)>
   1152         per_cell) const {
   1153   if (ShapeUtil::HasZeroElements(shape())) {
   1154     return;
   1155   }
   1156   std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
   1157   do {
   1158     per_cell(indices, Get<NativeT>(indices));
   1159   } while (IndexUtil::BumpIndices(shape(), &indices));
   1160 }
   1161 
   1162 template <typename NativeT>
   1163 inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
   1164   CHECK(ShapeUtil::IsArray(shape()));
   1165   CHECK_EQ(ShapeUtil::Rank(shape()), 1);
   1166   CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
   1167   CHECK_EQ(shape().element_type(),
   1168            primitive_util::NativeToPrimitiveType<NativeT>());
   1169   for (int64 i = 0; i < values.size(); ++i) {
   1170     Set({i}, values[i]);
   1171   }
   1172 }
   1173 
   1174 template <typename NativeT>
   1175 void Literal::PopulateR2(
   1176     std::initializer_list<std::initializer_list<NativeT>> values) {
   1177   CHECK(ShapeUtil::IsArray(shape()));
   1178   CHECK_EQ(ShapeUtil::Rank(shape()), 2);
   1179   CHECK_EQ(shape().element_type(),
   1180            primitive_util::NativeToPrimitiveType<NativeT>());
   1181 
   1182   const int64 dim0_size = values.size();
   1183   const int64 dim1_size = values.begin()->size();
   1184   CHECK_EQ(dim0_size, shape().dimensions(0));
   1185   CHECK_EQ(dim1_size, shape().dimensions(1));
   1186 
   1187   int64 dim0 = 0;
   1188   for (auto inner_list : values) {
   1189     int64 dim1 = 0;
   1190     for (auto value : inner_list) {
   1191       Set({dim0, dim1}, value);
   1192       ++dim1;
   1193     }
   1194     CHECK_EQ(dim1_size, dim1);
   1195     ++dim0;
   1196   }
   1197 }
   1198 
   1199 template <typename NativeT>
   1200 void Literal::PopulateFromArray(const Array<NativeT>& values) {
   1201   CHECK(ShapeUtil::IsArray(shape()));
   1202   CHECK_EQ(shape().element_type(),
   1203            primitive_util::NativeToPrimitiveType<NativeT>());
   1204   CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
   1205   for (int dim = 0; dim < values.num_dimensions(); ++dim) {
   1206     CHECK_EQ(values.dim(dim), shape().dimensions(dim));
   1207   }
   1208   values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
   1209                      NativeT value) { this->Set(indices, value); });
   1210 }
   1211 
   1212 template <typename NativeT>
   1213 void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
   1214   PopulateFromArray(values);
   1215 }
   1216 
   1217 template <typename NativeT>
   1218 void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
   1219   PopulateFromArray(values);
   1220 }
   1221 
   1222 template <typename NativeT>
   1223 void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
   1224   PopulateFromArray(values);
   1225 }
   1226 
   1227 template <typename NativeT>
   1228 void Literal::PopulateSparse(SparseIndexArray indices,
   1229                              tensorflow::gtl::ArraySlice<NativeT> values,
   1230                              bool sort) {
   1231   CHECK(LayoutUtil::IsSparseArray(shape()));
   1232   int rank = ShapeUtil::Rank(shape());
   1233   CHECK_EQ(indices.rank(), rank);
   1234   int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
   1235   CHECK_LE(indices.max_indices(), max_elements);
   1236   int64 num_elements = values.size();
   1237   CHECK_LE(num_elements, max_elements);
   1238   CHECK_EQ(num_elements, indices.index_count());
   1239   auto root_data = root_piece().data<NativeT>();
   1240   root_data.remove_suffix(max_elements - values.size());
   1241   std::copy(values.begin(), values.end(), root_data.begin());
   1242   *this->root_piece().sparse_indices() = std::move(indices);
   1243   if (sort) {
   1244     auto root_data = this->root_piece().data<NativeT>();
   1245     root_data.remove_suffix(root_data.size() - num_elements);
   1246     this->root_piece().sparse_indices()->SortWithValues(root_data);
   1247   }
   1248   DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
   1249 }
   1250 
   1251 template <typename NativeT, typename FnType>
   1252 Status Literal::Populate(const FnType& generator) {
   1253   const Shape& this_shape = shape();
   1254   const int64 rank = ShapeUtil::Rank(this_shape);
   1255   TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
   1256   TF_RET_CHECK(this_shape.element_type() ==
   1257                primitive_util::NativeToPrimitiveType<NativeT>());
   1258   tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
   1259   if (rank > 0) {
   1260     StrideConfig stride_config(this_shape, this_shape,
   1261                                AsInt64Slice(this_shape.dimensions()));
   1262     DimensionVector minor_scan_indexes(rank, 0);
   1263     int64 minor_dimension_size =
   1264         ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
   1265 
   1266     auto init_function = [&](const std::vector<int64>& indexes) {
   1267       const int64 index =
   1268           IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
   1269       std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
   1270       for (int64 i = 0; i < minor_dimension_size; ++i) {
   1271         minor_scan_indexes[stride_config.minor_dimension] = i;
   1272         literal_data.at(index + i) = generator(minor_scan_indexes);
   1273       }
   1274       return true;
   1275     };
   1276     ShapeUtil::ForEachIndex(this_shape, stride_config.base,
   1277                             stride_config.dimensions, stride_config.step,
   1278                             init_function);
   1279   } else {
   1280     // For scalars.
   1281     literal_data.at(0) = generator({});
   1282   }
   1283   return Status::OK();
   1284 }
   1285 
   1286 template <typename NativeT>
   1287 void Literal::PopulateWithValue(NativeT value) {
   1288   CHECK(ShapeUtil::IsArray(shape()));
   1289   CHECK_EQ(shape().element_type(),
   1290            primitive_util::NativeToPrimitiveType<NativeT>());
   1291   for (NativeT& element : data<NativeT>()) {
   1292     element = value;
   1293   }
   1294 }
   1295 
   1296 template <typename NativeT>
   1297 /* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
   1298     tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
   1299   auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
   1300       primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
   1301   literal->PopulateWithValue(value);
   1302   return literal;
   1303 }
   1304 
   1305 template <typename NativeT>
   1306 std::unique_ptr<Literal> Literal::Replicate(int64 times) const {
   1307   DimensionVector bounds = {times};
   1308   bounds.reserve(shape().dimensions_size() + 1);
   1309   for (int64 bound : shape().dimensions()) {
   1310     bounds.push_back(bound);
   1311   }
   1312   auto literal =
   1313       MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
   1314   int64 elements = ShapeUtil::ElementsIn(literal->shape());
   1315   if (elements == 0) {
   1316     return literal;
   1317   }
   1318 
   1319   DimensionVector output_indices(bounds.size(), 0);
   1320   tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
   1321   input_indices.remove_prefix(1);
   1322 
   1323   bool done = false;
   1324   while (!done) {
   1325     const auto element = Get<NativeT>(input_indices);
   1326     literal->Set<NativeT>(output_indices, element);
   1327 
   1328     done = true;
   1329     for (int n = 0; n < output_indices.size(); ++n) {
   1330       ++output_indices[n];
   1331       if (output_indices[n] < bounds[n]) {
   1332         done = false;
   1333         break;
   1334       }
   1335       output_indices[n] = 0;
   1336     }
   1337   }
   1338   return literal;
   1339 }
   1340 
   1341 }  // namespace xla
   1342 
   1343 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
   1344