Home | History | Annotate | Download | only in xla
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/literal_util.h"
     17 
     18 #include <algorithm>
     19 #include <cstring>
     20 #include <functional>
     21 #include <limits>
     22 #include <numeric>
     23 #include <vector>
     24 
     25 #include "tensorflow/compiler/xla/index_util.h"
     26 #include "tensorflow/compiler/xla/shape_util.h"
     27 #include "tensorflow/compiler/xla/status_macros.h"
     28 #include "tensorflow/compiler/xla/types.h"
     29 #include "tensorflow/compiler/xla/util.h"
     30 #include "tensorflow/core/lib/core/casts.h"
     31 #include "tensorflow/core/lib/core/errors.h"
     32 #include "tensorflow/core/lib/strings/str_util.h"
     33 #include "tensorflow/core/lib/strings/strcat.h"
     34 #include "tensorflow/core/lib/strings/stringprintf.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 #include "tensorflow/core/platform/types.h"
     37 
     38 using tensorflow::strings::Printf;
     39 using tensorflow::strings::StrCat;
     40 
     41 namespace xla {
     42 
     43 namespace {
     44 
// True iff the host byte order is little-endian, detected via the
// GCC/Clang-predefined __BYTE_ORDER__ macros at compile time.
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
     46 
     47 // Converts between little and big endian, assuming elements in the array are 16
     48 // bits long.
     49 void ConvertEndianShort(char* bytes, int64 size) {
     50   CHECK_EQ(size / 2, 0);
     51   for (int64 i = 0; i < size; i += 2) {
     52     std::swap(bytes[i], bytes[i + 1]);
     53   }
     54 }
     55 
     56 }  // namespace
     57 
     58 std::ostream& operator<<(std::ostream& out, const Literal& literal) {
     59   out << literal.ToString();
     60   return out;
     61 }
     62 
// Configures a strided copy between two array shapes: picks which side's
// minor-most dimension drives the tight inner loop, and computes the stride
// the other side must advance by per inner-loop element.
Literal::StrideConfig::StrideConfig(
    const Shape& source_shape, const Shape& dest_shape,
    tensorflow::gtl::ArraySlice<int64> dimensions)
    : dimensions(dimensions),
      base(dimensions.size(), 0),
      step(dimensions.size(), 1) {
  if (!dimensions.empty()) {
    // Selects the shape with the largest minor dimension as the one upon
    // which to run the tight stride loop.
    if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
        dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
      // Source's minor dimension wins: the inner loop walks source
      // contiguously and strides through the destination.
      minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
    } else {
      // Destination's minor dimension wins: the inner loop walks destination
      // contiguously and strides through the source.
      minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
      source_stride =
          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
    }
    // The inner loop covers the full extent of the chosen minor dimension, so
    // index enumeration steps over that dimension in one block.
    minor_loop_size = dimensions[minor_dimension];
    step[minor_dimension] = minor_loop_size;
  }
}
     85 
// Constructs a literal of `shape` with buffers allocated for every array
// piece (delegates to the two-argument constructor).
Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}
     88 
// Constructs a literal of `shape`. When `allocate_arrays` is true, allocates
// a raw buffer for each array-shaped piece (plus a SparseIndexArray for
// sparse layouts); otherwise leaves buffers null so they can be moved in
// later. The shape must carry a layout.
Literal::Literal(const Shape& shape, bool allocate_arrays)
    : shape_(shape), pieces_(shape), owns_buffers_(true) {
  CHECK(LayoutUtil::HasLayout(shape));
  for (auto& pair : pieces_) {
    const ShapeIndex& index = pair.first;
    Piece& piece = pair.second;

    // Each piece must point at its subshape within shape_ (the member copy,
    // not the constructor argument, so the pointer stays valid).
    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
    const Shape& subshape = piece.subshape();
    if (ShapeUtil::IsArray(subshape)) {
      if (allocate_arrays) {
        piece.set_buffer(new char[piece.size_bytes()]);
        if (LayoutUtil::IsSparseArray(subshape)) {
          // Sparse arrays additionally carry an index array sized for the
          // layout's maximum sparse element count.
          piece.set_sparse_indices(new SparseIndexArray(
              LayoutUtil::MaxSparseElements(subshape.layout()),
              ShapeUtil::Rank(subshape)));
        }
      } else {
        piece.set_buffer(nullptr);
      }
    }
  }
}
    112 
// Releases all owned per-piece buffers and sparse index arrays.
Literal::~Literal() { DeallocateBuffers(); }
    114 
// Frees each piece's buffer and sparse index array, but only when this
// literal owns its storage (i.e. it is not a view).
void Literal::DeallocateBuffers() {
  if (owns_buffers_) {
    for (auto& pair : pieces_) {
      Piece& piece = pair.second;
      if (piece.buffer() != nullptr) {
        delete[] piece.buffer();
        // NOTE(review): sparse indices are freed only when a buffer is also
        // present; this assumes no piece ever holds sparse indices without a
        // buffer -- confirm against the allocation/move paths.
        delete piece.sparse_indices();
      }
    }
  }
}
    126 
// Move constructor: steals `other`'s shape, pieces, and buffer ownership,
// leaving `other` as a valid nil-shaped literal whose destructor is a no-op.
Literal::Literal(Literal&& other) {
  shape_ = std::move(other.shape_);
  pieces_ = std::move(other.pieces_);
  // We need to iterate through the pieces to set the subshape pointer
  // properly. It must refer to subshapes within shape_.
  for (auto& pair : pieces_) {
    const ShapeIndex& index = pair.first;
    Piece& piece = pair.second;
    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
  }
  owns_buffers_ = other.owns_buffers_;

  // Reset `other` to the nil shape so it no longer references the moved
  // buffers.
  other.shape_ = ShapeUtil::MakeNil();
  other.pieces_ = ShapeTree<Piece>(other.shape_);
  other.piece({}).set_subshape(&other.shape_);
}
    143 
    144 Literal& Literal::operator=(Literal&& other) {
    145   DeallocateBuffers();
    146   shape_ = std::move(other.shape_);
    147   pieces_ = std::move(other.pieces_);
    148   // We need to iterate through the pieces to set the subshape pointer
    149   // properly. It must refer to subshapes within shape_.
    150   for (auto& pair : pieces_) {
    151     const ShapeIndex& index = pair.first;
    152     Piece& piece = pair.second;
    153     piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
    154   }
    155   owns_buffers_ = other.owns_buffers_;
    156 
    157   other.shape_ = ShapeUtil::MakeNil();
    158   other.pieces_ = ShapeTree<Piece>(other.shape_);
    159   other.piece({}).set_subshape(&other.shape_);
    160   return *this;
    161 }
    162 
    163 std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
    164   auto literal = MakeUnique<Literal>(shape);
    165   for (auto& pair : literal->pieces_) {
    166     Piece& piece = pair.second;
    167     if (ShapeUtil::IsArray(piece.subshape())) {
    168       memset(piece.untyped_data(), 0, piece.size_bytes());
    169     }
    170   }
    171   return literal;
    172 }
    173 
// Returns the sparse index array of the piece at `shape_index` (read-only).
const SparseIndexArray* Literal::sparse_indices(
    const ShapeIndex& shape_index) const {
  return piece(shape_index).sparse_indices();
}
    178 
// Mutable overload: returns the sparse index array at `shape_index`.
SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
  return piece(shape_index).sparse_indices();
}
    182 
// Creates a zero-initialized array literal of `primitive_type` with the given
// dimensions.
/* static */ std::unique_ptr<Literal> Literal::CreateFromDimensions(
    PrimitiveType primitive_type,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
}
    188 
// Typed implementation of CopySliceFrom: copies a `copy_size`-extent block of
// elements from `src_literal` starting at `src_base` into this literal
// starting at `dest_base`, handling differing layouts via strided copies.
template <typename NativeT>
Status Literal::CopySliceFromInternal(
    const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
    tensorflow::gtl::ArraySlice<int64> dest_base,
    tensorflow::gtl::ArraySlice<int64> copy_size) {
  TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
  TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());

  // Maps a multi-dimensional index to the linear buffer offset for `shape`.
  auto linear_index = [](const Shape& shape,
                         tensorflow::gtl::ArraySlice<int64> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
      ShapeUtil::Rank(shape()) == 0) {
    // If any of the two shapes are scalars, we can just call the StridedCopy()
    // directly, and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::HasZeroElements(shape()) &&
             !ShapeUtil::HasZeroElements(src_literal.shape())) {
    // Perform copy if neither src nor dest has dimensions with zero element,
    // otherwise it's a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from minor, stepping in copy size blocks, then within
    // the index enumeration functor, do a strided copy advancing source index
    // by one (walking through the minor dimension), and destination index by
    // proper stride size at the matching dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    Literal::StrideConfig stride_config(src_literal.shape(), shape(),
                                        copy_size);

    auto copy_proc = [&](const std::vector<int64>& indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64>());

      int64 src_index = linear_index(src_literal.shape(), src_indexes);
      int64 dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to workaround MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return Status::OK();
}
    250 
// Splits a tuple-shaped literal into one Literal per top-level tuple element,
// transferring buffer ownership (no element data is copied). Afterwards this
// literal is reset to the nil shape.
std::vector<Literal> Literal::DecomposeTuple() {
  CHECK(ShapeUtil::IsTuple(shape()));
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    // Create the element shell without allocating buffers; ownership of the
    // existing buffers is moved in below.
    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                               /*allocate_arrays=*/false));
    Literal& element = elements.back();
    for (auto& pair : element.pieces_) {
      const ShapeIndex& index = pair.first;
      Piece& dest_piece = pair.second;
      // The corresponding source piece lives at {i} followed by the element's
      // own shape index.
      ShapeIndex src_index = {i};
      for (int64 j : index) {
        src_index.push_back(j);
      }
      Piece& src_piece = piece(src_index);

      // Move the respective buffer and sparse indices over to the element
      // Literal.
      dest_piece.set_buffer(src_piece.buffer());
      src_piece.set_buffer(nullptr);
      dest_piece.set_sparse_indices(src_piece.sparse_indices());
      src_piece.set_sparse_indices(nullptr);
    }
  }
  // Set this literal to be nil-shaped.
  *this = Literal();
  return elements;
}
    279 
    280 /* static */ Literal Literal::MoveIntoTuple(
    281     tensorflow::gtl::MutableArraySlice<Literal> elements) {
    282   std::vector<Shape> element_shapes;
    283   for (const Literal& element : elements) {
    284     element_shapes.push_back(element.shape());
    285   }
    286   Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
    287                   /*allocate_arrays=*/false);
    288   for (int i = 0; i < elements.size(); ++i) {
    289     TF_CHECK_OK(
    290         literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
    291   }
    292   return literal;
    293 }
    294 
    295 namespace {
    296 
// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
                         tensorflow::gtl::ArraySlice<NativeT> src,
                         const Shape& dest_shape, const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::HasZeroElements(dest_shape)) {
    return;
  }
  std::vector<int64> index(ShapeUtil::Rank(dest_shape));
  // Enumerate every multi-dimensional index, converting it to each shape's
  // own linear offset so that differing layouts are remapped correctly. The
  // do/while ensures the single element of a rank-0 shape is still copied.
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, &index));
}
    313 
    314 }  // namespace
    315 
    316 Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
    317   if (ShapeUtil::Equal(subshape(), src.subshape())) {
    318     // If the layouts are equal it's faster just to memcpy.
    319     memcpy(buffer(), src.buffer(), src.size_bytes());
    320   } else {
    321     TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
    322     std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0);
    323     switch (subshape().element_type()) {
    324 #define COPY_ELEMENTS(XLA_T, NATIVE_T)                                    \
    325   case (XLA_T):                                                           \
    326     CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
    327                                   subshape(), src.subshape());            \
    328     break;
    329       COPY_ELEMENTS(U8, uint8);
    330       COPY_ELEMENTS(U16, uint16);
    331       COPY_ELEMENTS(U32, uint32);
    332       COPY_ELEMENTS(U64, uint64);
    333       COPY_ELEMENTS(S8, int8);
    334       COPY_ELEMENTS(S16, int16);
    335       COPY_ELEMENTS(S32, int32);
    336       COPY_ELEMENTS(S64, int64);
    337       COPY_ELEMENTS(F16, half);
    338       COPY_ELEMENTS(BF16, bfloat16);
    339       COPY_ELEMENTS(F32, float);
    340       COPY_ELEMENTS(F64, double);
    341       COPY_ELEMENTS(C64, complex64);
    342       COPY_ELEMENTS(PRED, bool);
    343 #undef COPY_ELEMENTS
    344       default:
    345         return Unimplemented(
    346             "Unhandled primitive type %s",
    347             PrimitiveType_Name(subshape().element_type()).c_str());
    348     }
    349   }
    350   return Status::OK();
    351 }
    352 
// Copies data from the subtree of `src_literal` rooted at `src_shape_index`
// into the subtree of this literal rooted at `dest_shape_index`. The two
// subshapes must be compatible. Only array pieces inside the destination
// subtree are copied; other pieces are untouched.
Status Literal::CopyFrom(const Literal& src_literal,
                         const ShapeIndex& dest_shape_index,
                         const ShapeIndex& src_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  const Shape& src_subshape =
      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
  if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
    return InvalidArgument(
        "Destination subshape incompatible with source subshape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape).c_str(),
        ShapeUtil::HumanString(src_subshape).c_str());
  }

  for (auto& pair : pieces_) {
    const ShapeIndex& index = pair.first;
    Piece& piece = pair.second;
    if (!ShapeUtil::IsArray(piece.subshape())) {
      continue;
    }

    // Determine if this index is in the part of this literal that we want to
    // copy over from src_literal.
    bool in_subtree_to_copy = true;
    for (int i = 0; i < dest_shape_index.size(); ++i) {
      if (index[i] != dest_shape_index[i]) {
        in_subtree_to_copy = false;
        break;
      }
    }
    if (!in_subtree_to_copy) {
      continue;
    }

    // Construct the index of the corresponding piece in the source literal.
    ShapeIndex src_piece_index = src_shape_index;
    for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
      src_piece_index.push_back(index[i]);
    }

    TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index)));
  }
  return Status::OK();
}
    397 
// Moves the buffers of `src_literal` into this literal at `dest_shape_index`.
// The destination subshape must equal the source shape exactly (layout
// included), and both literals must own their buffers. `src_literal` is left
// as a valid nil-shaped literal.
Status Literal::MoveFrom(Literal&& src_literal,
                         const ShapeIndex& dest_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
    return InvalidArgument(
        "Destination subshape not equal to source shape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape).c_str(),
        ShapeUtil::HumanString(src_literal.shape()).c_str());
  }

  if (!(owns_buffers_ && src_literal.owns_buffers_)) {
    return InvalidArgument(
        "Source and destination literals must both own their buffers (ie, not "
        "be views)");
  }

  for (auto& pair : src_literal.pieces_) {
    const ShapeIndex& src_index = pair.first;
    Piece& src_piece = pair.second;
    if (!ShapeUtil::IsArray(src_piece.subshape())) {
      continue;
    }

    // The destination piece lives at dest_shape_index followed by the
    // source piece's own index.
    ShapeIndex dest_index = dest_shape_index;
    for (int64 i : src_index) {
      dest_index.push_back(i);
    }
    Piece& dest_piece = piece(dest_index);
    // Free the destination's existing storage before taking over the
    // source's buffer and sparse indices.
    delete[] dest_piece.buffer();
    dest_piece.set_buffer(src_piece.buffer());
    delete dest_piece.sparse_indices();
    dest_piece.set_sparse_indices(src_piece.sparse_indices());
  }

  // Leave the source as a nil literal so its destructor frees nothing.
  src_literal.shape_ = ShapeUtil::MakeNil();
  src_literal.pieces_ = ShapeTree<Piece>(src_literal.shape_);
  src_literal.piece({}).set_subshape(&src_literal.shape_);
  return Status::OK();
}
    438 
    439 Status Literal::CopySliceFrom(const Literal& src_literal,
    440                               tensorflow::gtl::ArraySlice<int64> src_base,
    441                               tensorflow::gtl::ArraySlice<int64> dest_base,
    442                               tensorflow::gtl::ArraySlice<int64> copy_size) {
    443   TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
    444   TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
    445       << ShapeUtil::HumanString(src_literal.shape());
    446   TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));
    447 
    448   switch (shape().element_type()) {
    449     case U8:
    450       return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
    451                                           copy_size);
    452     case U16:
    453       return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
    454                                            copy_size);
    455     case U32:
    456       return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
    457                                            copy_size);
    458     case U64:
    459       return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
    460                                            copy_size);
    461     case S8:
    462       return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
    463                                          copy_size);
    464     case S16:
    465       return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
    466                                           copy_size);
    467     case S32:
    468       return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
    469                                           copy_size);
    470     case S64:
    471       return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
    472                                           copy_size);
    473     case F16:
    474       return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
    475                                          copy_size);
    476     case BF16:
    477       return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
    478                                              copy_size);
    479     case F32:
    480       return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
    481                                           copy_size);
    482     case F64:
    483       return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
    484                                            copy_size);
    485     case C64:
    486       return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
    487                                               copy_size);
    488     case PRED:
    489       return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
    490                                          copy_size);
    491     default:
    492       break;
    493   }
    494   return Unimplemented("Unhandled primitive type %d", shape().element_type());
    495 }
    496 
// Returns a rank-0 (scalar) literal holding the zero value of
// `primitive_type`. Dies via LOG(FATAL) for types that have no zero value
// (TUPLE, OPAQUE) or are not yet implemented (U16/S16).
/* static */ Literal Literal::Zero(PrimitiveType primitive_type) {
  switch (primitive_type) {
    case U8:
      return std::move(*Literal::CreateR0<uint8>(0));
    case U32:
      return std::move(*Literal::CreateR0<uint32>(0));
    case U64:
      return std::move(*Literal::CreateR0<uint64>(0));
    case S8:
      return std::move(*Literal::CreateR0<int8>(0));
    case S32:
      return std::move(*Literal::CreateR0<int32>(0));
    case S64:
      return std::move(*Literal::CreateR0<int64>(0));
    case F16:
      return std::move(*Literal::CreateR0<half>(static_cast<half>(0.0f)));
    case BF16:
      return std::move(
          *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
    case F32:
      return std::move(*Literal::CreateR0<float>(0));
    case F64:
      return std::move(*Literal::CreateR0<double>(0));
    case C64:
      return std::move(*Literal::CreateR0<complex64>(0));
    case PRED:
      return std::move(*Literal::CreateR0<bool>(false));
    case S16:
    case U16:
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case TUPLE:
      LOG(FATAL) << "tuple element type cannot take on value of 0";
    case OPAQUE:
      LOG(FATAL) << "opaque element type cannot take on value of 0";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
    535 
// Returns a rank-0 (scalar) literal holding the one value of
// `primitive_type` (true for PRED). Dies via LOG(FATAL) for types that have
// no one value (TUPLE, OPAQUE) or are not yet implemented (U16/S16).
/* static */ Literal Literal::One(PrimitiveType primitive_type) {
  switch (primitive_type) {
    case U8:
      return std::move(*Literal::CreateR0<uint8>(1));
    case U32:
      return std::move(*Literal::CreateR0<uint32>(1));
    case U64:
      return std::move(*Literal::CreateR0<uint64>(1));
    case S8:
      return std::move(*Literal::CreateR0<int8>(1));
    case S32:
      return std::move(*Literal::CreateR0<int32>(1));
    case S64:
      return std::move(*Literal::CreateR0<int64>(1));
    case F16:
      return std::move(*Literal::CreateR0<half>(static_cast<half>(1.0f)));
    case BF16:
      return std::move(
          *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
    case F32:
      return std::move(*Literal::CreateR0<float>(1));
    case F64:
      return std::move(*Literal::CreateR0<double>(1));
    case C64:
      return std::move(*Literal::CreateR0<complex64>(1));
    case PRED:
      return std::move(*Literal::CreateR0<bool>(true));
    case S16:
    case U16:
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case TUPLE:
      LOG(FATAL) << "tuple element type cannot take on value of 1";
    case OPAQUE:
      LOG(FATAL) << "opaque element type cannot take on value of 1";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
    574 
// Returns a scalar literal holding the minimum value of `primitive_type`:
// numeric_limits::min() for integers, negative infinity for floating-point
// types, and false for PRED. Dies via LOG(FATAL) for types with no minimum.
/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) {
  switch (primitive_type) {
    case U8:
      return std::move(
          *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
    case U32:
      return std::move(
          *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
    case U64:
      return std::move(
          *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
    case S8:
      return std::move(
          *Literal::CreateR0<int8>(std::numeric_limits<int8>::min()));
    case S32:
      return std::move(
          *Literal::CreateR0<int32>(std::numeric_limits<int32>::min()));
    case S64:
      return std::move(
          *Literal::CreateR0<int64>(std::numeric_limits<int64>::min()));
    case F32:
      return std::move(
          *Literal::CreateR0<float>(-std::numeric_limits<float>::infinity()));
    case F64:
      return std::move(
          *Literal::CreateR0<double>(-std::numeric_limits<double>::infinity()));
    case C64:
      LOG(FATAL) << "C64 element type has no minimum value";
    case PRED:
      return std::move(*Literal::CreateR0<bool>(false));
    case S16:
    case U16:
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case F16:
      // Half-width floats: build -infinity from the float infinity.
      return std::move(*Literal::CreateR0<half>(
          static_cast<half>(-std::numeric_limits<float>::infinity())));
    case BF16:
      return std::move(*Literal::CreateR0<bfloat16>(
          static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
    case TUPLE:
      LOG(FATAL) << "tuple element type has no minimum value";
    case OPAQUE:
      LOG(FATAL) << "opaque element type has no minimum value";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
    622 
// Returns a scalar literal holding the maximum value of `primitive_type`:
// numeric_limits::max() for integers, positive infinity for floating-point
// types, and true for PRED. Dies via LOG(FATAL) for types with no maximum.
/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) {
  switch (primitive_type) {
    case U8:
      return std::move(
          *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
    case U32:
      return std::move(
          *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
    case U64:
      return std::move(
          *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
    case S8:
      return std::move(
          *Literal::CreateR0<int8>(std::numeric_limits<int8>::max()));
    case S32:
      return std::move(
          *Literal::CreateR0<int32>(std::numeric_limits<int32>::max()));
    case S64:
      return std::move(
          *Literal::CreateR0<int64>(std::numeric_limits<int64>::max()));
    case F32:
      return std::move(
          *Literal::CreateR0<float>(std::numeric_limits<float>::infinity()));
    case F64:
      return std::move(
          *Literal::CreateR0<double>(std::numeric_limits<double>::infinity()));
    case PRED:
      return std::move(*Literal::CreateR0<bool>(true));
    case S16:
    case U16:
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case F16:
      // Half-width floats: build +infinity from the float infinity.
      return std::move(*Literal::CreateR0<half>(
          static_cast<half>(std::numeric_limits<float>::infinity())));
    case BF16:
      return std::move(*Literal::CreateR0<bfloat16>(
          static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
    case TUPLE:
      LOG(FATAL) << "tuple element type has no maximum value";
    case OPAQUE:
      LOG(FATAL) << "opaque element type has no maximum value";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
    668 
// Creates a rank-1 PRED literal with one element per bit in `values`.
/* static */ std::unique_ptr<Literal> Literal::CreateR1(
    const tensorflow::core::Bitmap& values) {
  auto literal = MakeUnique<Literal>(
      ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
  literal->PopulateR1(values);
  return literal;
}
    676 
    677 void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
    678   CHECK(ShapeUtil::IsArray(shape()));
    679   CHECK_EQ(ShapeUtil::Rank(shape()), 1);
    680   CHECK_EQ(element_count(), values.bits());
    681   CHECK_EQ(shape().element_type(), PRED);
    682   for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
    683     Set({i}, values.get(i));
    684   }
    685 }
    686 
    687 /* static */ std::unique_ptr<Literal> Literal::CreateR1U8(
    688     tensorflow::StringPiece value) {
    689   auto literal = MakeUnique<Literal>(
    690       ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
    691   for (int i = 0; i < value.size(); ++i) {
    692     literal->Set<uint8>({i}, value[i]);
    693   }
    694   return literal;
    695 }
    696 
    697 /* static */ std::unique_ptr<Literal> Literal::CreateR2F32Linspace(float from,
    698                                                                    float to,
    699                                                                    int64 rows,
    700                                                                    int64 cols) {
    701   auto value = MakeLinspaceArray2D(from, to, rows, cols);
    702   return CreateR2FromArray2D(*value);
    703 }
    704 
// Returns a copy of this literal in which the subshape at `shape_index` uses
// `new_layout`; the data is re-laid-out via CopyFrom. The layout must be
// valid for that subshape.
std::unique_ptr<Literal> Literal::Relayout(
    const Layout& new_layout, const ShapeIndex& shape_index) const {
  // Create new shape with 'new_layout' set at the given shape index.
  Shape new_shape = shape();
  Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
  TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
  *subshape->mutable_layout() = new_layout;
  auto result = MakeUnique<Literal>(new_shape);
  TF_CHECK_OK(result->CopyFrom(*this));
  return result;
}
    716 
    717 std::unique_ptr<Literal> Literal::Relayout(
    718     const Shape& shape_with_layout) const {
    719   CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
    720       << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
    721       << " not compatible with literal shape "
    722       << ShapeUtil::HumanString(shape());
    723   std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
    724   ShapeUtil::ForEachSubshape(
    725       result->shape(),
    726       [this, &result](const Shape& subshape, const ShapeIndex& index) {
    727         if (ShapeUtil::IsArray(subshape)) {
    728           TF_CHECK_OK(result->CopyFrom(*this,
    729                                        /*dest_shape_index=*/index,
    730                                        /*src_shape_index=*/index));
    731         }
    732       });
    733   return result;
    734 }
    735 
    736 StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
    737     tensorflow::gtl::ArraySlice<int64> dimensions) const {
    738   if (!ShapeUtil::IsArray(shape())) {
    739     return InvalidArgument("Reshape does not support tuples.");
    740   }
    741   std::unique_ptr<Literal> output;
    742   if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
    743     output =
    744         Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
    745   } else {
    746     output = CloneToUnique();
    747   }
    748   // Because the layout is monotonic, we can simply reuse the same sequence of
    749   // values without changing their order.
    750   output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions);
    751 
    752   int64 elements_before = ShapeUtil::ElementsIn(shape());
    753   int64 elements_after = ShapeUtil::ElementsIn(output->shape());
    754   if (elements_before != elements_after) {
    755     return InvalidArgument(
    756         "Shapes before and after Literal::Reshape have different numbers "
    757         "of elements: %s vs %s.",
    758         ShapeUtil::HumanString(shape()).c_str(),
    759         ShapeUtil::HumanString(output->shape()).c_str());
    760   }
    761   return std::move(output);
    762 }
    763 
// Returns a copy of this array literal with its dimensions permuted: result
// dimension i corresponds to source dimension permutation[i]. The transpose
// is performed without visiting individual elements — the permuted shape is
// given a layout under which the original flat value sequence is already in
// the right order, so a single memcpy of the raw bytes suffices.
std::unique_ptr<Literal> Literal::Transpose(
    tensorflow::gtl::ArraySlice<int64> permutation) const {
  CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
  CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  std::vector<int64> inverse_permutation = InversePermutation(permutation);
  Shape permuted_shape =
      ShapeUtil::PermuteDimensions(inverse_permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
  // most minor.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor to major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
  // vector of the affine layout.
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  // With the affine layout in place, the physical byte order is identical,
  // so one flat copy moves every element to its transposed position.
  std::unique_ptr<Literal> new_literal = CreateFromShape(permuted_shape);
  DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()),
            ShapeUtil::ByteSizeOf(shape()));
  std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(),
              root_piece().size_bytes());
  return new_literal;
}
    802 
    803 std::unique_ptr<Literal> Literal::Slice(
    804     tensorflow::gtl::ArraySlice<int64> start_indices,
    805     tensorflow::gtl::ArraySlice<int64> limit_indices) const {
    806   CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
    807 
    808   DimensionVector result_dimensions;
    809   for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
    810     CHECK_GE(start_indices[dnum], 0);
    811     CHECK_LE(limit_indices[dnum], shape().dimensions(dnum));
    812     int64 dimension = limit_indices[dnum] - start_indices[dnum];
    813     CHECK_GT(dimension, 0);
    814     result_dimensions.push_back(dimension);
    815   }
    816   const auto result_shape =
    817       ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
    818                                      LayoutUtil::MinorToMajor(shape()));
    819 
    820   auto result_literal = MakeUnique<Literal>(result_shape);
    821 
    822   DimensionVector new_indices(ShapeUtil::Rank(result_shape));
    823   switch (result_shape.element_type()) {
    824     case F32:
    825       result_literal->EachCell<float>(
    826           [&](tensorflow::gtl::ArraySlice<int64> indices, float /*value*/) {
    827             for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
    828               new_indices[i] = indices[i] + start_indices[i];
    829             }
    830             float value = Get<float>(new_indices);
    831             result_literal->Set<float>(indices, value);
    832           });
    833       return result_literal;
    834     case C64:
    835       result_literal->EachCell<complex64>(
    836           [&](tensorflow::gtl::ArraySlice<int64> indices, complex64 /*value*/) {
    837             for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
    838               new_indices[i] = indices[i] + start_indices[i];
    839             }
    840             complex64 value = Get<complex64>(new_indices);
    841             result_literal->Set<complex64>(indices, value);
    842           });
    843       return result_literal;
    844     case S32:
    845       result_literal->EachCell<int32>(
    846           [&](tensorflow::gtl::ArraySlice<int64> indices, int32 /*value*/) {
    847             for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
    848               new_indices[i] = indices[i] + start_indices[i];
    849             }
    850             int32 value = Get<int32>(new_indices);
    851             result_literal->Set<int32>(indices, value);
    852           });
    853       return result_literal;
    854     case U32:
    855       result_literal->EachCell<uint32>(
    856           [&](tensorflow::gtl::ArraySlice<int64> indices, uint32 /*value*/) {
    857             for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
    858               new_indices[i] = indices[i] + start_indices[i];
    859             }
    860             uint32 value = Get<uint32>(new_indices);
    861             result_literal->Set<uint32>(indices, value);
    862           });
    863       return result_literal;
    864     default:
    865       LOG(FATAL) << "not yet implemented: "
    866                  << PrimitiveType_Name(result_shape.element_type());
    867   }
    868 }
    869 
    870 Literal Literal::Clone() const {
    871   Literal result(shape());
    872   TF_CHECK_OK(result.CopyFrom(*this));
    873   return result;
    874 }
    875 
    876 std::unique_ptr<Literal> Literal::CloneToUnique() const {
    877   auto result = MakeUnique<Literal>(shape());
    878   TF_CHECK_OK(result->CopyFrom(*this));
    879   return result;
    880 }
    881 
// Returns a human-readable rendering of the element at `multi_index` inside
// the dense array piece at `shape_index`. PRED prints as "true"/"false",
// BF16 is widened to float before printing, and C64 prints as
// "(real, imag)". Any element type not listed below is a fatal error.
string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
                            const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsDenseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return Get<bool>(multi_index, shape_index) ? "true" : "false";
    case S8:
      return StrCat(Get<int8>(multi_index, shape_index));
    case S16:
      return StrCat(Get<int16>(multi_index, shape_index));
    case S32:
      return StrCat(Get<int32>(multi_index, shape_index));
    case S64:
      return StrCat(Get<int64>(multi_index, shape_index));
    case U8:
      return StrCat(Get<uint8>(multi_index, shape_index));
    case U16:
      return StrCat(Get<uint16>(multi_index, shape_index));
    case U32:
      return StrCat(Get<uint32>(multi_index, shape_index));
    case U64:
      return StrCat(Get<uint64>(multi_index, shape_index));
    case F16:
      return StrCat(Get<half>(multi_index, shape_index));
    case F32:
      return StrCat(Get<float>(multi_index, shape_index));
    case BF16:
      // Widen to float for printing.
      return StrCat(
          static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
    case F64:
      return StrCat(Get<double>(multi_index, shape_index));
    case C64: {
      complex64 c = Get<complex64>(multi_index, shape_index);
      return StrCat("(", c.real(), ", ", c.imag(), ")");
    }
    default:
      LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
  }
}
    922 
// Returns a human-readable rendering of sparse element
// `sparse_element_number` inside the sparse array piece at `shape_index`.
// Formatting matches GetAsString: PRED as "true"/"false", BF16 widened to
// float, C64 as "(real, imag)". Types invalid for sparse arrays are fatal.
string Literal::GetSparseElementAsString(int64 sparse_element_number,
                                         const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsSparseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return GetSparseElement<bool>(sparse_element_number, shape_index)
                 ? "true"
                 : "false";
    case S8:
      return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
    case S16:
      return StrCat(
          GetSparseElement<int16>(sparse_element_number, shape_index));
    case S32:
      return StrCat(
          GetSparseElement<int32>(sparse_element_number, shape_index));
    case S64:
      return StrCat(
          GetSparseElement<int64>(sparse_element_number, shape_index));
    case U8:
      return StrCat(
          GetSparseElement<uint8>(sparse_element_number, shape_index));
    case U16:
      return StrCat(
          GetSparseElement<uint16>(sparse_element_number, shape_index));
    case U32:
      return StrCat(
          GetSparseElement<uint32>(sparse_element_number, shape_index));
    case U64:
      return StrCat(
          GetSparseElement<uint64>(sparse_element_number, shape_index));
    case F16:
      return StrCat(GetSparseElement<half>(sparse_element_number, shape_index));
    case F32:
      return StrCat(
          GetSparseElement<float>(sparse_element_number, shape_index));
    case BF16:
      // Widen to float for printing.
      return StrCat(static_cast<float>(
          GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
    case F64:
      return StrCat(
          GetSparseElement<double>(sparse_element_number, shape_index));
    case C64: {
      complex64 c =
          GetSparseElement<complex64>(sparse_element_number, shape_index);
      return StrCat("(", c.real(), ", ", c.imag(), ")");
    }
    default:
      LOG(FATAL) << "Invalid element type for sparse arrays: "
                 << PrimitiveType_Name(subshape.element_type());
  }
}
    976 
// Returns the element at `multi_index` widened to int64. Only integral
// element types (PRED, U8, S32, S64, U32, U64) are supported; any other
// type yields a FailedPrecondition error.
//
// NOTE(review): a U64 value above INT64_MAX is narrowed into the signed
// return type by implicit conversion — confirm callers never rely on such
// values round-tripping.
StatusOr<int64> Literal::GetIntegralAsS64(
    tensorflow::gtl::ArraySlice<int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      return Get<bool>(multi_index);
    case U8:
      return Get<uint8>(multi_index);
    case S32:
      return Get<int32>(multi_index);
    case S64:
      return Get<int64>(multi_index);
    case U32:
      return Get<uint32>(multi_index);
    case U64:
      return Get<uint64>(multi_index);
    default:
      return FailedPrecondition(
          "Array element type is not integral: %s",
          PrimitiveType_Name(shape().element_type()).c_str());
  }
}
    999 
   1000 tensorflow::gtl::ArraySlice<int64> Literal::GetSparseIndex(
   1001     int64 sparse_element_number, const ShapeIndex& shape_index) const {
   1002   const Piece& p = piece(shape_index);
   1003   CHECK_GE(sparse_element_number, 0);
   1004   CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
   1005   return p.sparse_indices()->At(sparse_element_number);
   1006 }
   1007 
// Sorts the sparse elements of the piece at `shape_index` by index; see
// Piece::SortSparseElements for the per-type implementation.
void Literal::SortSparseElements(const ShapeIndex& shape_index) {
  piece(shape_index).SortSparseElements();
}
   1011 
// Sorts this piece's sparse elements by index, dispatching on the element
// type to the typed implementation. Element types not listed here cannot
// appear in a sparse array and are a fatal error.
void Literal::Piece::SortSparseElements() {
  switch (subshape().element_type()) {
    case PRED:
      SortSparseElementsInternal<bool>();
      break;
    case S8:
      SortSparseElementsInternal<int8>();
      break;
    case U8:
      SortSparseElementsInternal<uint8>();
      break;
    case S16:
      SortSparseElementsInternal<int16>();
      break;
    case U16:
      SortSparseElementsInternal<uint16>();
      break;
    case S32:
      SortSparseElementsInternal<int32>();
      break;
    case U32:
      SortSparseElementsInternal<uint32>();
      break;
    case S64:
      SortSparseElementsInternal<int64>();
      break;
    case U64:
      SortSparseElementsInternal<uint64>();
      break;
    case F32:
      SortSparseElementsInternal<float>();
      break;
    case F64:
      SortSparseElementsInternal<double>();
      break;
    case C64:
      SortSparseElementsInternal<complex64>();
      break;
    case F16:
      SortSparseElementsInternal<half>();
      break;
    case BF16:
      SortSparseElementsInternal<bfloat16>();
      break;
    default:
      LOG(FATAL) << "Element type not valid for sparse array: "
                 << PrimitiveType_Name(subshape().element_type());
  }
}
   1061 
   1062 template <typename NativeT>
   1063 void Literal::Piece::SortSparseElementsInternal() {
   1064   CHECK(LayoutUtil::IsSparseArray(subshape()));
   1065   int64 num_elements = sparse_indices()->index_count();
   1066   auto values = data<NativeT>();
   1067   CHECK_LE(num_elements, values.size());
   1068   sparse_indices()->SortWithValues(
   1069       tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements));
   1070 }
   1071 
namespace {

// Recursively renders the piece of `literal` at `shape_index` into `pieces`,
// which the caller joins into the final string. Tuples, sparse arrays, and
// dense arrays of rank 0 through 5 each get a dedicated formatting path;
// higher ranks fall back to a flat element dump. When `print_layout` is
// true, shapes are printed with their layouts.
void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
                    bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);

  auto shape_to_string = [print_layout](const Shape& shape) {
    if (print_layout) {
      return ShapeUtil::HumanStringWithLayout(shape);
    } else {
      return ShapeUtil::HumanString(shape);
    }
  };

  // TODO(b/32894291): refactor this code to reduce code duplication.
  if (ShapeUtil::IsTuple(subshape)) {
    // Tuples render as "shape (\n elem0,\n elem1, ...\n)" with each element
    // produced by a recursive call.
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" (\n");
    std::vector<string> tuple_pieces;
    for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
      ShapeIndex element_index = shape_index;
      element_index.push_back(i);
      std::vector<string> element_pieces;
      ToStringHelper(literal, element_index, print_layout, &element_pieces);
      tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, ""));
    }
    pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
    pieces->push_back("\n)");
    return;
  }

  if (LayoutUtil::IsSparseArray(subshape)) {
    // Sparse arrays render as "{index: value, ...}" for rank 1 and
    // "{[i, j, ...]: value, ...}" for higher ranks.
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back("{");
    int64 rank = ShapeUtil::Rank(subshape);
    int64 num_elements = literal.sparse_element_count();
    for (int64 i = 0; i < num_elements; ++i) {
      if (i > 0) {
        pieces->push_back(", ");
      }
      if (rank == 1) {
        pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
        pieces->push_back(": ");
      } else {
        pieces->push_back("[");
        pieces->push_back(
            tensorflow::str_util::Join(literal.GetSparseIndex(i), ", "));
        pieces->push_back("]: ");
      }
      pieces->push_back(literal.GetSparseElementAsString(i));
    }
    pieces->push_back("}");
    return;
  }

  CHECK(LayoutUtil::IsDenseArray(subshape));

  // Renders one dense element. Non-PRED elements are preceded by ", " except
  // at the start of a row (last index == 0).
  auto element_to_string =
      [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
    PrimitiveType element_type = subshape.element_type();
    if (element_type == PRED) {
      // We display predicates in a densely packed form.
      return literal.Get<bool>(indices, shape_index) ? "1" : "0";
    }
    return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
           literal.GetAsString(indices, shape_index);
  };

  if (ShapeUtil::Rank(subshape) == 0) {
    pieces->push_back(literal.GetAsString({}, shape_index));
  } else if (ShapeUtil::Rank(subshape) == 1) {
    pieces->push_back("{");
    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
      pieces->push_back(element_to_string({i0}));
    }
    pieces->push_back("}");
  } else if (ShapeUtil::Rank(subshape) == 2) {
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" {\n");
    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
      pieces->push_back("  { ");
      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
        pieces->push_back(element_to_string({i0, i1}));
      }
      pieces->push_back(" ");
      pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n");
    }
    pieces->push_back("}");
  } else if (ShapeUtil::Rank(subshape) == 3) {
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" {\n");
    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
      pieces->push_back(i0 > 0 ? ",\n{" : "{");
      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
        pieces->push_back(i1 > 0 ? ",\n  { " : " { ");
        for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
          pieces->push_back(element_to_string({i0, i1, i2}));
        }
        pieces->push_back(" }");
      }
      pieces->push_back(" }");
    }
    pieces->push_back("\n}");
  } else if (ShapeUtil::Rank(subshape) == 4) {
    // Ranks 4 and 5 annotate each nesting level with its index as a comment.
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" {\n");
    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
      pieces->push_back(Printf("  {  /*i0=%lld*/\n", i0));
      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
        pieces->push_back(Printf("    {  /*i1=%lld*/\n", i1));
        for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
          pieces->push_back("      {");
          for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
            pieces->push_back(element_to_string({i0, i1, i2, i3}));
          }
          pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n");
        }
        pieces->push_back(i1 == subshape.dimensions(1) - 1 ? "    }\n"
                                                           : "    },\n");
      }
      pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "  }\n" : "  },\n");
    }
    pieces->push_back("}");
  } else if (ShapeUtil::Rank(subshape) == 5) {
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" {\n");
    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
      pieces->push_back(Printf("  {  /*i0=%lld*/\n", i0));
      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
        pieces->push_back(Printf("    {  /*i1=%lld*/\n", i1));
        for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
          pieces->push_back(Printf("      {  /*i2=%lld*/\n", i2));
          for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
            pieces->push_back("        {");
            for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) {
              pieces->push_back(element_to_string({i0, i1, i2, i3, i4}));
            }
            pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n"
                                                               : "},\n");
          }
          pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "      }\n"
                                                             : "      },\n");
        }
        pieces->push_back(i1 == subshape.dimensions(1) - 1 ? "    }\n"
                                                           : "    },\n");
      }
      pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "  }\n" : "  },\n");
    }
    pieces->push_back("}");
  } else {
    // Rank > 5: no nesting; dump all elements space-separated.
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" {");
    literal.EachCellAsString(
        [&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
          pieces->push_back(" ");
          pieces->push_back(value);
        });
    pieces->push_back("}");
  }
}

}  // namespace
   1234 
// Returns the number of elements stored in this sparse literal.
int64 Literal::sparse_element_count() const {
  CHECK(LayoutUtil::IsSparseArray(shape()));
  return sparse_indices()->index_count();
}
   1239 
   1240 string Literal::ToString(bool print_layout) const {
   1241   std::vector<string> pieces;
   1242   ToStringHelper(*this, {}, print_layout, &pieces);
   1243   return tensorflow::str_util::Join(pieces, "");
   1244 }
   1245 
   1246 /* static */ std::unique_ptr<Literal> Literal::MakeTuple(
   1247     tensorflow::gtl::ArraySlice<const Literal*> elements) {
   1248   std::vector<Shape> element_shapes;
   1249   for (const Literal* element : elements) {
   1250     element_shapes.push_back(element->shape());
   1251   }
   1252   auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
   1253   for (int i = 0; i < elements.size(); ++i) {
   1254     TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
   1255   }
   1256   return literal;
   1257 }
   1258 
   1259 /* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned(
   1260     std::vector<std::unique_ptr<Literal>> elements) {
   1261   std::vector<Shape> element_shapes;
   1262   element_shapes.reserve(elements.size());
   1263   for (const auto& element : elements) {
   1264     element_shapes.push_back(element->shape());
   1265   }
   1266   auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
   1267   for (int64 i = 0; i < elements.size(); ++i) {
   1268     TF_CHECK_OK(
   1269         literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
   1270   }
   1271   return literal;
   1272 }
   1273 
   1274 void Literal::EachCellAsString(
   1275     const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
   1276                              const string& value)>& per_cell) const {
   1277   if (ShapeUtil::HasZeroElements(shape())) {
   1278     return;
   1279   }
   1280   std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
   1281       shape(), /*linear_index=*/0);
   1282   do {
   1283     per_cell(indices, GetAsString(indices));
   1284   } while (IndexUtil::BumpIndices(shape(), &indices));
   1285 }
   1286 
   1287 namespace {
   1288 template <typename NativeSrcT, typename NativeDestT>
   1289 std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
   1290   CHECK(ShapeUtil::IsArray(src_literal.shape()));
   1291   auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
   1292       src_literal.shape(),
   1293       primitive_util::NativeToPrimitiveType<NativeDestT>()));
   1294   auto src_data = src_literal.data<NativeSrcT>();
   1295   auto dest_data = result_literal->template data<NativeDestT>();
   1296   int64 num_elements = src_literal.element_count();
   1297 
   1298   for (int64 i = 0; i < num_elements; ++i) {
   1299     dest_data[i] = static_cast<NativeDestT>(src_data[i]);
   1300   }
   1301   return result_literal;
   1302 }
   1303 
   1304 template <PrimitiveType primitive_src_type>
   1305 std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
   1306   CHECK(ShapeUtil::IsArray(src_literal.shape()));
   1307   auto result_literal = MakeUnique<Literal>(
   1308       ShapeUtil::ChangeElementType(src_literal.shape(), C64));
   1309   using NativeSrcT =
   1310       typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
   1311   tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
   1312       src_literal.data<NativeSrcT>();
   1313   tensorflow::gtl::MutableArraySlice<complex64> dest_data =
   1314       result_literal->data<complex64>();
   1315   int64 num_elements = src_literal.element_count();
   1316   for (int64 i = 0; i < num_elements; ++i) {
   1317     dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
   1318   }
   1319   return result_literal;
   1320 }
   1321 
// Bridges a compile-time (src, dest) primitive-type pair to the
// corresponding native-type conversion. Callers guarantee `src_literal`
// actually has element type primitive_src_type (checked here).
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  return ConvertBetweenNativeTypes<
      typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
      typename primitive_util::PrimitiveTypeToNative<
          primitive_dest_type>::type>(src_literal);
}
   1330 
// Converts `src_literal` (of compile-time-known element type
// primitive_src_type) to the runtime destination type
// `primitive_dest_type`. Destination types outside the macro-generated list
// below (plus C64) return an InvalidArgument error.
template <PrimitiveType primitive_src_type>
StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
    const Literal& src_literal, PrimitiveType primitive_dest_type) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
  case (type):                       \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F16)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
    CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
    case C64:
      return ConvertToC64<primitive_src_type>(src_literal);
    // Other types are not yet supported.
    default:
      return InvalidArgument(
          "Unimplemented: Convert from type %s to type %s",
          PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
          PrimitiveType_Name(primitive_dest_type).c_str());
  }
}
   1360 
   1361 }  // namespace
   1362 
// Returns a copy of this array literal converted to `primitive_dest_type`.
// Dispatches on the runtime source element type to the templated
// ConvertIfDestTypeMatches; source types outside the macro-generated list
// produce an InvalidArgument error. Tuples are rejected.
StatusOr<std::unique_ptr<Literal>> Literal::Convert(
    PrimitiveType primitive_dest_type) const {
  TF_RET_CHECK(ShapeUtil::IsArray(shape()));
  switch (shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
  case (type):                             \
    return ConvertIfDestTypeMatches<(type)>(*this, primitive_dest_type);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
      // Other types are not yet supported.
    default:
      return InvalidArgument("Unimplemented: Convert from type %s to type %s",
                             PrimitiveType_Name(shape().element_type()).c_str(),
                             PrimitiveType_Name(primitive_dest_type).c_str());
  }
}
   1389 
// Recursively enumerates every multi-dimensional index of this piece's
// shape and compares the corresponding elements of `this` and `other`.
// `multi_index` holds the index prefix built so far; once its length equals
// the rank, a single element is compared with operator==.
template <typename NativeT>
bool Literal::Piece::EqualElementsInternal(
    const Literal::Piece& other, std::vector<int64>* multi_index) const {
  if (multi_index->size() == ShapeUtil::Rank(subshape())) {
    // Full index assembled: compare one element.
    return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
  }
  for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
    multi_index->push_back(i);
    if (!EqualElementsInternal<NativeT>(other, multi_index)) {
      return false;
    }
    multi_index->pop_back();
  }
  return true;
}
   1405 
// Returns true if this piece and `other` (which must have compatible
// subshapes) contain exactly equal elements. Dispatches on the element type;
// comparison uses operator==, so floating-point NaN elements never compare
// equal. Types without a case here are fatal.
bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
  DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));

  std::vector<int64> multi_index;
  switch (subshape().element_type()) {
    case PRED:
      return EqualElementsInternal<bool>(other, &multi_index);
    case U8:
      return EqualElementsInternal<uint8>(other, &multi_index);
    case S32:
      return EqualElementsInternal<int32>(other, &multi_index);
    case S64:
      return EqualElementsInternal<int64>(other, &multi_index);
    case U32:
      return EqualElementsInternal<uint32>(other, &multi_index);
    case U64:
      return EqualElementsInternal<uint64>(other, &multi_index);
    case F32:
      return EqualElementsInternal<float>(other, &multi_index);
    case F64:
      return EqualElementsInternal<double>(other, &multi_index);
    case F16:
      return EqualElementsInternal<half>(other, &multi_index);
    case BF16:
      return EqualElementsInternal<bfloat16>(other, &multi_index);
    case C64:
      return EqualElementsInternal<complex64>(other, &multi_index);
    default:
      LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}
   1438 
   1439 bool Literal::operator==(const Literal& other) const {
   1440   if (!ShapeUtil::Compatible(shape(), other.shape())) {
   1441     return false;
   1442   }
   1443   for (const auto& pair : pieces_) {
   1444     const ShapeIndex& index = pair.first;
   1445     const Piece& piece = pair.second;
   1446     if (!ShapeUtil::IsArray(piece.subshape())) {
   1447       continue;
   1448     }
   1449 
   1450     const Piece& other_piece = other.piece(index);
   1451     if (!piece.EqualElements(other_piece)) {
   1452       return false;
   1453     }
   1454   }
   1455   return true;
   1456 }
   1457 
   1458 namespace {
   1459 
   1460 template <typename NativeT>
   1461 static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
   1462                                   NativeT value) {
   1463   for (int64 i = 0; i < data.size(); ++i) {
   1464     if (data[i] != value) {
   1465       return false;
   1466     }
   1467   }
   1468   return true;
   1469 }
   1470 
   1471 }  // namespace
   1472 
   1473 bool Literal::IsAll(int8 value) const {
   1474   for (const auto& pair : pieces_) {
   1475     const Piece& piece = pair.second;
   1476     if (!ShapeUtil::IsArray(piece.subshape())) {
   1477       continue;
   1478     }
   1479 
   1480     auto piece_is_all = [&]() {
   1481       switch (shape().element_type()) {
   1482         case U8:
   1483           if (value >= 0) {
   1484             return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
   1485           }
   1486           return false;
   1487         case U32:
   1488           if (value >= 0) {
   1489             return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
   1490           }
   1491           return false;
   1492         case U64:
   1493           if (value >= 0) {
   1494             return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
   1495           }
   1496           return false;
   1497         case S8:
   1498           return AllElementsEqualValue<int8>(piece.data<int8>(), value);
   1499         case S32:
   1500           return AllElementsEqualValue<int32>(piece.data<int32>(), value);
   1501         case S64:
   1502           return AllElementsEqualValue<int64>(piece.data<int64>(), value);
   1503         case F32:
   1504           return AllElementsEqualValue<float>(piece.data<float>(), value);
   1505         case F64:
   1506           return AllElementsEqualValue<double>(piece.data<double>(), value);
   1507         case F16:
   1508           return AllElementsEqualValue<half>(piece.data<half>(),
   1509                                              static_cast<half>(value));
   1510         case BF16:
   1511           return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
   1512                                                  static_cast<bfloat16>(value));
   1513         case PRED:
   1514           if (value == 0) {
   1515             return AllElementsEqualValue<bool>(piece.data<bool>(), false);
   1516           }
   1517           if (value == 1) {
   1518             return AllElementsEqualValue<bool>(piece.data<bool>(), true);
   1519           }
   1520           return false;
   1521         default:
   1522           return false;
   1523       }
   1524       return false;
   1525     };
   1526 
   1527     if (!piece_is_all()) {
   1528       return false;
   1529     }
   1530   }
   1531   return true;
   1532 }
   1533 
   1534 bool Literal::IsAllFloat(float value) const {
   1535   for (const auto& pair : pieces_) {
   1536     const Piece& piece = pair.second;
   1537     if (!ShapeUtil::IsArray(piece.subshape())) {
   1538       continue;
   1539     }
   1540 
   1541     auto piece_is_all = [&]() {
   1542       switch (shape().element_type()) {
   1543         case F32:
   1544           return AllElementsEqualValue<float>(piece.data<float>(), value);
   1545         case F64:
   1546           return AllElementsEqualValue<double>(piece.data<double>(), value);
   1547         case F16:
   1548           return AllElementsEqualValue<half>(piece.data<half>(),
   1549                                              static_cast<half>(value));
   1550         case BF16:
   1551           return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
   1552                                                  static_cast<bfloat16>(value));
   1553         default:
   1554           return false;
   1555       }
   1556     };
   1557     if (!piece_is_all()) {
   1558       return false;
   1559     }
   1560   }
   1561   return true;
   1562 }
   1563 
   1564 bool Literal::IsAllComplex(complex64 value) const {
   1565   switch (shape().element_type()) {
   1566     case C64:
   1567       return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
   1568                                               value);
   1569     default:
   1570       return false;
   1571   }
   1572 }
   1573 
   1574 bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
   1575   CHECK(ShapeUtil::IsArray(shape()));
   1576   switch (shape().element_type()) {
   1577     case U8:
   1578       return Get<uint8>(indices) == 0;
   1579     case U32:
   1580       return Get<uint32>(indices) == 0;
   1581     case U64:
   1582       return Get<uint64>(indices) == 0;
   1583     case S8:
   1584       return Get<int8>(indices) == 0;
   1585     case S32:
   1586       return Get<int32>(indices) == 0;
   1587     case S64:
   1588       return Get<int64>(indices) == 0;
   1589     case F32:
   1590       return Get<float>(indices) == 0.0f;
   1591     case F64:
   1592       return Get<double>(indices) == 0.0;
   1593     case C64:
   1594       return Get<complex64>(indices) == complex64(0.0f, 0.0f);
   1595     case F16:
   1596       return Get<half>(indices) == static_cast<half>(0.0f);
   1597     case BF16:
   1598       return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
   1599     case PRED:
   1600       return Get<bool>(indices) == false;
   1601     default:
   1602       LOG(FATAL) << "Input literal must be an array.";
   1603   }
   1604 }
   1605 
   1606 namespace {
   1607 
// Replaces the contents of proto repeated field `dest` with a copy of the
// elements of `src`.
template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
                         const tensorflow::gtl::ArraySlice<NativeT> src) {
  *dest = RepeatedFieldT(src.begin(), src.end());
}
   1613 
   1614 }  // namespace
   1615 
// Serializes this piece into `proto`: the shape always, plus — for array
// pieces — the element data copied into the proto field matching the
// element type. Tuple pieces record only their shape.
void Literal::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case U8:
      // U8 elements are packed into a single bytes field, one byte each.
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
      break;
    case F16:
      // 16-bit floats are serialized as a raw byte string in little-endian
      // order; on big-endian hosts the bytes are swapped after the copy so
      // the proto encoding is host-independent.
      *proto->mutable_f16s() = string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(const_cast<char*>(proto->mutable_f16s()->data()),
                           proto->f16s().size());
      }
      break;
    case BF16:
      // Same raw-bytes, little-endian encoding as F16.
      *proto->mutable_bf16s() = string(
          reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(const_cast<char*>(proto->mutable_bf16s()->data()),
                           proto->bf16s().size());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      // Each complex value is flattened to two consecutive floats:
      // (real, imag).
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case TUPLE:
      // Nothing to do but assign the shape which is done above.
      return;
    default:
      LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
  }
}
   1673 
   1674 const void* Literal::Piece::untyped_data() const {
   1675   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
   1676   return buffer();
   1677 }
   1678 
   1679 void* Literal::Piece::untyped_data() {
   1680   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
   1681   return buffer();
   1682 }
   1683 
   1684 namespace {
   1685 
   1686 template <typename RepeatedFieldT, typename NativeT>
   1687 Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
   1688                              const RepeatedFieldT& src) {
   1689   if (dest.size() != src.size()) {
   1690     return InvalidArgument(
   1691         "Expected %lu elements in LiteralProto repeated field, has %d",
   1692         dest.size(), src.size());
   1693   }
   1694   std::copy(src.begin(), src.end(), dest.begin());
   1695   return Status::OK();
   1696 }
   1697 
   1698 }  // namespace
   1699 
// Deserializes this piece's element data from `proto`, whose shape must
// already equal subshape(). Per-element-type decoding mirrors WriteToProto.
Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in Literal::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
  TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    case U8: {
      // U8 is stored as a single bytes field rather than a repeated field.
      auto u8_data = data<uint8>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
      break;
    case F16: {
      // F16 data is a raw little-endian byte string; copy it in, then
      // byte-swap in place on big-endian hosts.
      const string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;

    case BF16: {
      // Same raw-bytes, little-endian decoding as F16.
      const string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    case C64: {
      // Complex values are flattened in the proto as (real, imag) float
      // pairs, so the repeated field holds twice as many floats as there
      // are elements.
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64 i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
    } break;
    case TUPLE:
      // Tuple pieces carry no element data; Literal::CreateFromProto handles
      // them before dispatching here.
      LOG(FATAL) << "Should not be called on tuple shapes: "
                 << ShapeUtil::HumanString(subshape());
      break;
    default:
      LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
  }
  return Status::OK();
}
   1766 
   1767 LiteralProto Literal::ToProto() const {
   1768   LiteralProto proto;
   1769   for (const auto& pair : pieces_) {
   1770     const ShapeIndex& index = pair.first;
   1771     const Piece& piece = pair.second;
   1772 
   1773     LiteralProto* proto_piece = &proto;
   1774     for (int64 i : index) {
   1775       while (proto_piece->tuple_literals_size() <= i) {
   1776         proto_piece->add_tuple_literals();
   1777       }
   1778       proto_piece = proto_piece->mutable_tuple_literals(i);
   1779     }
   1780     piece.WriteToProto(proto_piece);
   1781   }
   1782 
   1783   if (LayoutUtil::IsSparseArray(shape())) {
   1784     CopyToRepeatedField(proto.mutable_sparse_indices(),
   1785                         sparse_indices()->data());
   1786   }
   1787 
   1788   return proto;
   1789 }
   1790 
/* static */
// Deserializes `proto` into a freshly allocated Literal. The proto must
// carry a shape with a layout; the nested tuple_literals structure is
// validated against the shape before any array data is copied.
StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
    const LiteralProto& proto) {
  if (!proto.has_shape()) {
    return InvalidArgument("LiteralProto has no shape");
  }
  if (!LayoutUtil::HasLayout(proto.shape())) {
    return InvalidArgument("LiteralProto has no layout");
  }

  auto literal = MakeUnique<Literal>(proto.shape());

  for (auto& pair : literal->pieces_) {
    const ShapeIndex& index = pair.first;
    Piece& piece = pair.second;
    // Walk down the proto to the node corresponding to this piece's index.
    const LiteralProto* proto_element = &proto;
    for (int64 i : index) {
      TF_RET_CHECK(i < proto_element->tuple_literals_size());
      proto_element = &proto_element->tuple_literals(i);
    }

    if (ShapeUtil::IsTuple(piece.subshape())) {
      // Tuple pieces carry no data; just check the child count matches.
      if (proto_element->tuple_literals_size() !=
          ShapeUtil::TupleElementCount(piece.subshape())) {
        return InvalidArgument(
            "Expected %lld tuple elements in LiteralProto, has %d",
            ShapeUtil::TupleElementCount(piece.subshape()),
            proto_element->tuple_literals_size());
      }
      continue;
    }

    TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape()));
    TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element));
  }
  return std::move(literal);
}
   1828 
   1829 const void* Literal::untyped_data(const ShapeIndex& shape_index) const {
   1830   return piece(shape_index).untyped_data();
   1831 }
   1832 
   1833 void* Literal::untyped_data(const ShapeIndex& shape_index) {
   1834   return piece(shape_index).untyped_data();
   1835 }
   1836 
   1837 int64 Literal::size_bytes(const ShapeIndex& shape_index) const {
   1838   return piece(shape_index).size_bytes();
   1839 }
   1840 
   1841 string Literal::GetR1U8AsString() const {
   1842   CHECK(ShapeUtil::IsArray(shape()));
   1843   CHECK_EQ(ShapeUtil::Rank(shape()), 1);
   1844   CHECK_EQ(shape().element_type(), U8);
   1845   return string(tensorflow::bit_cast<const char*>(data<uint8>().data()),
   1846                 ShapeUtil::ElementsIn(shape()));
   1847 }
   1848 
   1849 /* static */ const LiteralView LiteralView::Create(
   1850     const Literal& literal, const ShapeIndex& view_root) {
   1851   return LiteralView(literal, view_root);
   1852 }
   1853 
   1854 LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
   1855   shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root);
   1856   pieces_ = ShapeTree<Piece>(shape_);
   1857   owns_buffers_ = false;
   1858   for (auto& pair : pieces_) {
   1859     const ShapeIndex& index = pair.first;
   1860     Piece& piece = pair.second;
   1861 
   1862     ShapeIndex src_index = view_root;
   1863     for (int64 i : index) {
   1864       src_index.push_back(i);
   1865     }
   1866     const Piece& src_piece = literal.piece(src_index);
   1867     piece.set_buffer(src_piece.buffer());
   1868     piece.set_sparse_indices(src_piece.sparse_indices());
   1869     piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
   1870   }
   1871 }
   1872 
   1873 LiteralView::~LiteralView() {}
   1874 
// Copy constructor: delegates to CopyFrom, which re-points each piece's
// subshape at this object's own shape_ copy.
LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); }
   1876 
// Copy assignment: delegates to CopyFrom, which re-points each piece's
// subshape at this object's own shape_ copy.
LiteralView& LiteralView::operator=(const LiteralView& other) {
  CopyFrom(other);
  return *this;
}
   1881 
   1882 void LiteralView::CopyFrom(const LiteralView& other) {
   1883   // We can't use the default copy-constructor/copy-assignment because
   1884   // Piece::subshape_ points to subshapes within the Shape of the owning
   1885   // Literal/LiteralView.
   1886   shape_ = other.shape();
   1887   pieces_ = other.pieces_;
   1888   for (auto& pair : pieces_) {
   1889     const ShapeIndex& index = pair.first;
   1890     Piece& piece = pair.second;
   1891     piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
   1892   }
   1893   owns_buffers_ = false;
   1894 }
   1895 
   1896 }  // namespace xla
   1897