Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/literal.h"
     17 
     18 #include <vector>
     19 
     20 #include "absl/base/casts.h"
     21 #include "absl/memory/memory.h"
     22 #include "absl/strings/match.h"
     23 #include "absl/strings/str_cat.h"
     24 #include "tensorflow/compiler/tf2xla/shape_util.h"
     25 #include "tensorflow/compiler/xla/array3d.h"
     26 #include "tensorflow/compiler/xla/array4d.h"
     27 #include "tensorflow/compiler/xla/layout_util.h"
     28 #include "tensorflow/compiler/xla/literal_util.h"
     29 #include "tensorflow/compiler/xla/shape_util.h"
     30 #include "tensorflow/compiler/xla/test.h"
     31 #include "tensorflow/compiler/xla/types.h"
     32 #include "tensorflow/core/lib/core/status_test_util.h"
     33 #include "tensorflow/core/platform/macros.h"
     34 #include "tensorflow/core/platform/types.h"
     35 
     36 namespace xla {
     37 namespace {
     38 
     39 using ::testing::ElementsAre;
     40 using ::testing::HasSubstr;
     41 
     42 class LiteralUtilTest : public ::testing::Test {
     43  protected:
     44   LiteralUtilTest() {
     45     Array4D<float> arr4d({
     46         // clang-format off
     47       {  // i0=0
     48           {  // i1=0
     49               {1, 2, 3},  // i2=0
     50               {4, 5, 6},  // i2=1
     51               {7, 8, 9},  // i2=2
     52           },
     53           {  // i1=1
     54               {11, 12, 13},
     55               {14, 15, 16},
     56               {17, 18, 19},
     57           },
     58       },
     59       {  // i0=1
     60           {  // i1=0
     61               {101, 102, 103},
     62               {104, 105, 106},
     63               {107, 108, 109},
     64           },
     65           {  // i1=1
     66               {201, 202, 203},  // i2=0
     67               {204, 205, 206},  // i2=1
     68               {207, 208, 209},  // i2=2
     69           },
     70       },
     71         // clang-format on
     72     });
     73 
     74     layout_r2_dim0major_ = LayoutUtil::MakeLayout({1, 0});
     75     layout_r2_dim0minor_ = LayoutUtil::MakeLayout({0, 1});
     76     layout_r3_dim0major_ = LayoutUtil::MakeLayout({2, 1, 0});
     77     layout_r3_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2});
     78     layout_r4_dim0major_ = LayoutUtil::MakeLayout({3, 2, 1, 0});
     79     layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3});
     80 
     81     literal_r4_2x2x3x3_dim0major_ =
     82         LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
     83                                                           layout_r4_dim0major_);
     84     literal_r4_2x2x3x3_dim0minor_ =
     85         LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
     86                                                           layout_r4_dim0minor_);
     87   }
     88 
     89   Layout layout_r2_dim0major_;
     90   Layout layout_r2_dim0minor_;
     91   Layout layout_r3_dim0major_;
     92   Layout layout_r3_dim0minor_;
     93   Layout layout_r4_dim0major_;
     94   Layout layout_r4_dim0minor_;
     95   Literal literal_r4_2x2x3x3_dim0major_;
     96   Literal literal_r4_2x2x3x3_dim0minor_;
     97 };
     98 
     99 TEST_F(LiteralUtilTest, LiteralScalarToString) {
    100   auto true_lit = LiteralUtil::CreateR0<bool>(true);
    101   EXPECT_EQ("pred[] true", true_lit.ToString());
    102 
    103   auto false_lit = LiteralUtil::CreateR0<bool>(false);
    104   EXPECT_EQ("pred[] false", false_lit.ToString());
    105 
    106   auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
    107   EXPECT_EQ("u32[] 42", u32_lit.ToString());
    108 
    109   auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
    110   EXPECT_EQ("s32[] -999", s32_lit.ToString());
    111 
    112   auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
    113   EXPECT_EQ("f32[] 3.14", f32_lit.ToString());
    114 
    115   auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
    116   EXPECT_EQ("f16[] 0.5", f16_lit.ToString());
    117 
    118   auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
    119   EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString());
    120 
    121   auto c128_lit = LiteralUtil::CreateR0<complex128>({3.14f, 2.78f});
    122   EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString());
    123 
    124   auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
    125   EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString());
    126 
    127   // 3.14 will be rounded to 3.14062 in bfloat16 format.
    128   auto bf16_lit_truncated =
    129       LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
    130   ASSERT_EQ("bf16[] 3.14062", bf16_lit_truncated.ToString());
    131 
    132   auto bf16_lit_truncated2 =
    133       LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
    134   EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString());
    135 }
    136 
    137 TEST_F(LiteralUtilTest, LiteralVectorToString) {
    138   auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
    139   EXPECT_EQ("pred[3] {1, 0, 1}", pred_vec.ToString());
    140 }
    141 
    142 TEST_F(LiteralUtilTest, R2ToString) {
    143   const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
    144   const string expected = R"(s32[3,2] {
    145   { 1, 2 },
    146   { 3, 4 },
    147   { 5, 6 }
    148 })";
    149   EXPECT_EQ(expected, literal.ToString());
    150 }
    151 
    152 TEST_F(LiteralUtilTest, R3ToString) {
    153   const auto literal =
    154       LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
    155   const string expected = R"(s32[3,2,1] {
    156 {
    157   {1},
    158   {2}
    159 },
    160 {
    161   {3},
    162   {4}
    163 },
    164 {
    165   {5},
    166   {6}
    167 }
    168 })";
    169   EXPECT_EQ(expected, literal.ToString());
    170 }
    171 
    172 TEST_F(LiteralUtilTest, R6ToString) {
    173   const auto literal =
    174       LiteralUtil::CreateFromDimensions(S32, {2, 2, 1, 1, 1, 2});
    175   const string expected = R"(s32[2,2,1,1,1,2] {
    176 { /*i0=0*/
    177 { /*i1=0*/
    178 { /*i2=0*/
    179 { /*i3=0*/
    180   { 0, 0 }
    181 }
    182 }
    183 },
    184 { /*i1=1*/
    185 { /*i2=0*/
    186 { /*i3=0*/
    187   { 0, 0 }
    188 }
    189 }
    190 }
    191 },
    192 { /*i0=1*/
    193 { /*i1=0*/
    194 { /*i2=0*/
    195 { /*i3=0*/
    196   { 0, 0 }
    197 }
    198 }
    199 },
    200 { /*i1=1*/
    201 { /*i2=0*/
    202 { /*i3=0*/
    203   { 0, 0 }
    204 }
    205 }
    206 }
    207 }
    208 })";
    209   EXPECT_EQ(expected, literal.ToString());
    210 }
    211 
    212 TEST_F(LiteralUtilTest, TupleToString) {
    213   auto scalar = LiteralUtil::CreateR0<float>(1.0);
    214   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    215   auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
    216   const string expected = R"((
    217 f32[] 1,
    218 f32[2,2] {
    219   { 1, 2 },
    220   { 3, 4 }
    221 }
    222 ))";
    223   EXPECT_EQ(expected, tuple.ToString());
    224 }
    225 
    226 TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
    227   // clang-format off
    228   Array3D<float> array_3d({
    229     {{1.0f, 2.0f},
    230      {3.0f, 4.0f},
    231      {5.0f, 6.0f}},
    232     {{7.0f, 8.0f},
    233      {9.0f, 10.0f},
    234      {11.0f, 12.0f}},
    235   });
    236   // clang-format on
    237 
    238   auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
    239   EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2));
    240   string result = literal.ToString();
    241   const string expected = R"(f32[2,3,2] {
    242 {
    243   { 1, 2 },
    244   { 3, 4 },
    245   { 5, 6 }
    246 },
    247 {
    248   { 7, 8 },
    249   { 9, 10 },
    250   { 11, 12 }
    251 }
    252 })";
    253   EXPECT_EQ(expected, result);
    254 }
    255 
    256 TEST_F(LiteralUtilTest, CreateSparse) {
    257   std::vector<int64> dimensions = {8, 8, 8};
    258   Array2D<int64> indices = {
    259       {3, 4, 5},
    260       {1, 2, 3},
    261       {2, 3, 4},
    262       {3, 5, 6},
    263   };
    264   std::vector<int64> values = {7, 8, 9, 10};
    265   auto literal = LiteralUtil::CreateSparse<int64>(
    266       dimensions, SparseIndexArray(indices.n1() + 3, indices), values);
    267 
    268   Array2D<int64> expected_indices = {
    269       {1, 2, 3},
    270       {2, 3, 4},
    271       {3, 4, 5},
    272       {3, 5, 6},
    273   };
    274   std::vector<int64> expected_values = {8, 9, 7, 10};
    275 
    276   EXPECT_EQ(literal.sparse_indices()->data(),
    277             absl::Span<const int64>(expected_indices.data(),
    278                                     expected_indices.num_elements()));
    279   EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
    280 
    281   // Serialize then deserialize and verify the resulting literal.
    282   TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto,
    283                           Literal::CreateFromProto(literal.ToProto()));
    284 
    285   EXPECT_EQ(literal_from_proto.sparse_indices()->data(),
    286             absl::Span<const int64>(expected_indices.data(),
    287                                     expected_indices.num_elements()));
    288   EXPECT_EQ(literal_from_proto.data<int64>(),
    289             absl::Span<const int64>(expected_values));
    290 }
    291 
    292 TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
    293   // clang-format off
    294   auto literal = LiteralUtil::CreateR4Projected<float>({
    295     {1, 2},
    296     {1001, 1002},
    297     {2001, 2002},
    298   }, /*projection_p=*/1, /*projection_z=*/2);
    299   // clang-format on
    300   EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2));
    301   string result = literal.ToString();
    302   const string expected = R"(f32[1,2,3,2] {
    303 { /*i0=0*/
    304 { /*i1=0*/
    305   { 1, 2 },
    306   { 1001, 1002 },
    307   { 2001, 2002 }
    308 },
    309 { /*i1=1*/
    310   { 1, 2 },
    311   { 1001, 1002 },
    312   { 2001, 2002 }
    313 }
    314 }
    315 })";
    316   EXPECT_EQ(expected, result);
    317 }
    318 
    319 TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
    320   EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(),
    321               ElementsAre(2, 2, 3, 3));
    322   string result = literal_r4_2x2x3x3_dim0major_.ToString();
    323   const string expected = R"(f32[2,2,3,3] {
    324 { /*i0=0*/
    325 { /*i1=0*/
    326   { 1, 2, 3 },
    327   { 4, 5, 6 },
    328   { 7, 8, 9 }
    329 },
    330 { /*i1=1*/
    331   { 11, 12, 13 },
    332   { 14, 15, 16 },
    333   { 17, 18, 19 }
    334 }
    335 },
    336 { /*i0=1*/
    337 { /*i1=0*/
    338   { 101, 102, 103 },
    339   { 104, 105, 106 },
    340   { 107, 108, 109 }
    341 },
    342 { /*i1=1*/
    343   { 201, 202, 203 },
    344   { 204, 205, 206 },
    345   { 207, 208, 209 }
    346 }
    347 }
    348 })";
    349   EXPECT_EQ(expected, result);
    350 }
    351 
    352 TEST_F(LiteralUtilTest, EachCellR2F32) {
    353   // clang-format off
    354   auto literal = LiteralUtil::CreateR2<float>({
    355     {3.1f, 4.2f},
    356     {9.3f, 12.4f},
    357   });
    358   // clang-format on
    359   std::vector<std::tuple<int64, int64, string>> seen;
    360   literal.EachCellAsString(
    361       [&seen](absl::Span<const int64> indices, const string& value) {
    362         seen.emplace_back(indices[0], indices[1], value);
    363       });
    364 
    365   using Elem = std::tuple<int64, int64, string>;
    366   std::vector<Elem> expected = {Elem(0, 0, "3.1"), Elem(0, 1, "4.2"),
    367                                 Elem(1, 0, "9.3"), Elem(1, 1, "12.4")};
    368   EXPECT_EQ(expected, seen);
    369 }
    370 
    371 TEST_F(LiteralUtilTest, ScalarEquality) {
    372   // Test equality with scalars.
    373   auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
    374   auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
    375 
    376   EXPECT_EQ(f32_42, f32_42);
    377   EXPECT_EQ(f32_42, f32_42_clone);
    378 
    379   auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
    380   EXPECT_NE(f32_42, f32_123);
    381 
    382   auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
    383   EXPECT_NE(f32_42, f64_42);
    384 }
    385 
    386 TEST_F(LiteralUtilTest, NonScalarEquality) {
    387   // Test equality with nonscalars.
    388   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    389   auto matrix_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    390   auto matrix_different =
    391       LiteralUtil::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
    392   auto vector_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
    393   auto scalar = LiteralUtil::CreateR0<float>(1.0);
    394   Literal nil(ShapeUtil::MakeNil());
    395 
    396   EXPECT_EQ(matrix, matrix);
    397   EXPECT_EQ(matrix, matrix_clone);
    398   EXPECT_NE(matrix, matrix_different);
    399   EXPECT_NE(matrix, vector_literal);
    400   EXPECT_NE(matrix, scalar);
    401   EXPECT_NE(matrix, nil);
    402   EXPECT_EQ(nil, nil);
    403 }
    404 
    405 TEST_F(LiteralUtilTest, TokenEquality) {
    406   auto token0 = LiteralUtil::CreateToken();
    407   auto token1 = LiteralUtil::CreateToken();
    408   auto scalar = LiteralUtil::CreateR0<float>(1.0);
    409 
    410   EXPECT_EQ(token0, token1);
    411   EXPECT_NE(token0, scalar);
    412 
    413   EXPECT_EQ(LiteralUtil::MakeTuple({&token0}),
    414             LiteralUtil::MakeTuple({&token0}));
    415   EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}),
    416             LiteralUtil::MakeTuple({&token1, &scalar}));
    417   EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}),
    418             LiteralUtil::MakeTuple({&scalar, &token1}));
    419 }
    420 
    421 TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
    422   // Test equality with literals which have different layouts.
    423   Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
    424   colmajor.Set<float>({0, 0}, 1.0);
    425   colmajor.Set<float>({0, 1}, 2.0);
    426   colmajor.Set<float>({1, 0}, 3.0);
    427   colmajor.Set<float>({1, 1}, 4.0);
    428 
    429   Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
    430   rowmajor.Set<float>({0, 0}, 1.0);
    431   rowmajor.Set<float>({0, 1}, 2.0);
    432   rowmajor.Set<float>({1, 0}, 3.0);
    433   rowmajor.Set<float>({1, 1}, 4.0);
    434 
    435   EXPECT_EQ(rowmajor, colmajor);
    436 }
    437 
    438 TEST_F(LiteralUtilTest, TupleEquality) {
    439   // Test equality with tuples.
    440   auto scalar = LiteralUtil::CreateR0<float>(1.0);
    441   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    442   auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix});
    443 
    444   // Tuple with the same elements. One element is shared with the original
    445   // tuple, the other is a clone of the element in the original tuple.
    446   auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
    447   auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix});
    448   EXPECT_EQ(tuple1, tuple2);
    449 
    450   // Tuple with elements reversed.
    451   auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar});
    452   EXPECT_NE(tuple1, reversed_tuple);
    453 
    454   // Tuple with different value.
    455   auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
    456   auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix});
    457   EXPECT_NE(tuple1, different_tuple);
    458 }
    459 
    460 TEST_F(LiteralUtilTest, C64Equality) {
    461   // Test equality with tuples.
    462   auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
    463 
    464   // Tuple with the same elements. One element is shared with the original
    465   // tuple, the other is a clone of the element in the original tuple.
    466   auto vector_clone =
    467       LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
    468   EXPECT_EQ(vector, vector_clone);
    469 
    470   auto vector_reversed =
    471       LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
    472   EXPECT_NE(vector, vector_reversed);
    473 }
    474 
    475 TEST_F(LiteralUtilTest, C128Equality) {
    476   // Test equality with tuples.
    477   auto vector = LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
    478 
    479   // Tuple with the same elements. One element is shared with the original
    480   // tuple, the other is a clone of the element in the original tuple.
    481   auto vector_clone =
    482       LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
    483   EXPECT_EQ(vector, vector_clone);
    484 
    485   auto vector_reversed =
    486       LiteralUtil::CreateR1<complex128>({{3.0, 4.0}, {1.0, 2.0}});
    487   EXPECT_NE(vector, vector_reversed);
    488 }
    489 
    490 TEST_F(LiteralUtilTest, IsAllTuple) {
    491   auto element1 = LiteralUtil::CreateR0<float>(0.0);
    492   auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
    493   auto tuple = LiteralUtil::MakeTuple({&element1, &element1});
    494 
    495   // Tuples should always return false for IsAll.
    496   EXPECT_FALSE(tuple.IsAll(0));
    497   EXPECT_FALSE(tuple.IsAll(1));
    498 }
    499 
    500 // Verifies that CreateFromShape works for tuples.
    501 TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
    502   auto scalar = LiteralUtil::CreateR0<float>(0.0);
    503   auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
    504   auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
    505 
    506   auto x = Literal::CreateFromShape(tuple.shape());
    507   EXPECT_EQ(tuple, x);
    508 }
    509 
    510 TEST_F(LiteralUtilTest, IsAll) {
    511   EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
    512   EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
    513   EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
    514   EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
    515   EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
    516   EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
    517   EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));
    518 
    519   // We shouldn't reinterpret int8_min as an unsigned type and then decide that
    520   // it is equal to 255.
    521   auto int8_min = std::numeric_limits<int8>::min();
    522   EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255).IsAll(int8_min));
    523 
    524   EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
    525   EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));
    526 
    527   EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
    528   EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));
    529 
    530   EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}).IsAll(8));
    531   EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}).IsAll(8));
    532   EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}).IsAll(8));
    533 
    534   half h8(8.0f);
    535   half h9(9.0f);
    536   EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
    537   EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
    538   EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));
    539 
    540   bfloat16 b8(8.0f);
    541   bfloat16 b9(9.0f);
    542 
    543   EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
    544   EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
    545   EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));
    546 
    547   // 9.001 will be truncated to 9.0
    548   bfloat16 b91(9.001f);
    549   bfloat16 b90(9.00f);
    550   EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9.0));
    551 
    552   complex64 c8_9 = {8, 9};
    553   EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));
    554 
    555   auto uint64_max = std::numeric_limits<uint64>::max();
    556   EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
    557                    {{uint64_max, uint64_max}, {uint64_max, uint64_max}})
    558                    .IsAll(-1));
    559 }
    560 
    561 TEST_F(LiteralUtilTest, IsAllFloat) {
    562   // IsAllFloat always returns false when the literal is not floating-point.
    563   EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
    564   EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllFloat(0));
    565   EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllFloat(0));
    566   EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));
    567 
    568   EXPECT_TRUE(LiteralUtil::CreateR0<float>(0).IsAllFloat(0));
    569   EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5).IsAllFloat(.5));
    570   EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.5));
    571   EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.49));
    572   EXPECT_FALSE(
    573       LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
    574   EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
    575                   .IsAllFloat(.5));
    576 
    577   EXPECT_TRUE(LiteralUtil::CreateR0<double>(0).IsAllFloat(0));
    578   EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5).IsAllFloat(.5));
    579   EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.5));
    580   EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.49));
    581   EXPECT_FALSE(
    582       LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
    583 }
    584 
    585 TEST_F(LiteralUtilTest, IsAllComplex) {
    586   // IsAllComplex always returns false when the literal is not complex.
    587   EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
    588   EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllComplex(0));
    589   EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllComplex(0));
    590   EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
    591   EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
    592   EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));
    593 
    594   complex64 c8_9 = {8, 9};
    595   complex64 c7_9 = {7, 9};
    596   EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})
    597                   .IsAllComplex({8.0f, 9.0f}));
    598   EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})
    599                    .IsAllComplex({8.0f, 9.0f}));
    600   EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}})
    601                    .IsAllComplex({8.0f, 9.0f}));
    602 }
    603 
    604 TEST_F(LiteralUtilTest, IsAllFirst) {
    605   // IsAllComplex always returns false when the literal is not complex.
    606   EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
    607   EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
    608   EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2}).IsAllFirst());
    609   EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5}).IsAllFirst());
    610   EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2}).IsAllFirst());
    611   EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5}).IsAllFirst());
    612   EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2}).IsAllFirst());
    613   EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5}).IsAllFirst());
    614   EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2}).IsAllFirst());
    615 
    616   complex64 c8_9 = {8, 9};
    617   complex64 c7_9 = {7, 9};
    618   EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
    619   EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
    620 }
    621 
    622 TEST_F(LiteralUtilTest, IsZero) {
    623   auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
    624   auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
    625   EXPECT_TRUE(scalar_zero.IsZero({}));
    626   EXPECT_FALSE(scalar_one.IsZero({}));
    627 
    628   auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
    629   EXPECT_FALSE(array.IsZero({0, 1}));
    630   EXPECT_TRUE(array.IsZero({0, 2}));
    631   EXPECT_TRUE(array.IsZero({1, 1}));
    632   EXPECT_FALSE(array.IsZero({1, 2}));
    633 
    634   auto complex_zero = LiteralUtil::CreateR0<complex64>(0.0f);
    635   auto complex_nonzero = LiteralUtil::CreateR0<complex64>(0.5f);
    636   EXPECT_TRUE(complex_zero.IsZero({}));
    637   EXPECT_FALSE(complex_nonzero.IsZero({}));
    638 }
    639 
    640 template <typename T>
    641 class LiteralUtilTestTemplated : public ::testing::Test {};
    642 
    643 using TestedTypes = ::testing::Types<float, int32, uint32, complex64>;
    644 TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes);
    645 
    646 TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
    647   // Make a non-integer for floating point types.
    648   TypeParam half = TypeParam(1) / TypeParam(2);
    649   auto data = LiteralUtil::CreateR2<TypeParam>({{half, 2}, {3, 4}});
    650   const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
    651   const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
    652 
    653   auto data01 = data.Relayout(layout01);
    654   EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01));
    655   EXPECT_EQ(data, data01);
    656 
    657   auto data10 = data.Relayout(layout10);
    658   EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10));
    659   EXPECT_EQ(data, data10);
    660 }
    661 
    662 TEST_F(LiteralUtilTest, ReshapeR0) {
    663   auto original = LiteralUtil::CreateR0<float>(1.7f);
    664   auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
    665   EXPECT_EQ(original, reshape);
    666 }
    667 
    668 TEST_F(LiteralUtilTest, ReshapeR4) {
    669   // clang-format off
    670   // F32[1x3x2x4]
    671   auto original = LiteralUtil::CreateR4WithLayout<float>({{
    672      {{10, 11, 12, 13}, {14, 15, 16, 17}},
    673      {{18, 19, 20, 21}, {22, 23, 24, 25}},
    674      {{26, 27, 28, 29}, {30, 31, 32, 33}},
    675   }}, layout_r4_dim0major_);
    676   // F32[1x3x4x2]
    677   auto expected = LiteralUtil::CreateR3WithLayout<float>({
    678     {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    679     {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    680     {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
    681   }, layout_r3_dim0major_);
    682   // clang-format on
    683   auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
    684 
    685   EXPECT_EQ(expected, reshape);
    686 }
    687 
    688 TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
    689   // clang-format off
    690   // F32[1x3x2x4]
    691   auto original = LiteralUtil::CreateR4WithLayout<float>({{
    692      {{10, 11, 12, 13}, {14, 15, 16, 17}},
    693      {{18, 19, 20, 21}, {22, 23, 24, 25}},
    694      {{26, 27, 28, 29}, {30, 31, 32, 33}},
    695   }}, layout_r4_dim0minor_);
    696   // F32[1x3x4x2]
    697   auto expected = LiteralUtil::CreateR3WithLayout<float>({
    698     {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    699     {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    700     {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
    701   }, layout_r3_dim0major_);
    702   // clang-format on
    703   auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
    704 
    705   EXPECT_EQ(expected, reshape);
    706 }
    707 
    708 TEST_F(LiteralUtilTest, TransposeR0) {
    709   auto original = LiteralUtil::CreateR0<float>(1.7f);
    710   auto reshape = original.Transpose(/*permutation=*/{});
    711   EXPECT_EQ(original, reshape);
    712 }
    713 
    714 TEST_F(LiteralUtilTest, TransposeR4) {
    715   // clang-format off
    716   // F32[1x3x2x4]
    717   auto original = LiteralUtil::CreateR4<float>({{
    718      {{10, 11, 12, 13}, {14, 15, 16, 17}},
    719      {{18, 19, 20, 21}, {22, 23, 24, 25}},
    720      {{26, 27, 28, 29}, {30, 31, 32, 33}},
    721   }});
    722   // clang-format on
    723   auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1});
    724 
    725   reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
    726     EXPECT_EQ(value, original.Get<float>(
    727                          {indices[2], indices[3], indices[0], indices[1]}));
    728   });
    729 }
    730 
    731 TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
    732   // Tests that using Relayout on an array is equivalent to creating it in the
    733   // target layout in the first place.
    734   auto dim0minor_relaid_to_dim0major =
    735       literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_);
    736   EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major);
    737 
    738   auto dim0major_relaid_to_dim0minor =
    739       literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_);
    740   EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor);
    741 }
    742 
    743 TEST_F(LiteralUtilTest, TestR2LinearLayout) {
    744   // Test expected memory layout of R2 dim0-minor (column-major) literal.
    745   auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32>(
    746       {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
    747   EXPECT_EQ(mat_dim0minor.element_count(), 6);
    748   EXPECT_THAT(mat_dim0minor.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
    749 
    750   // Test expected memory layout when using Relayout to row major.
    751   auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_);
    752   EXPECT_THAT(relaid_mat_to_dim0major.data<int32>(),
    753               ElementsAre(1, 2, 3, 4, 5, 6));
    754 
    755   // Test expected memory layout of R2 created with dim0-major (row-major).
    756   auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32>(
    757       {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
    758   EXPECT_EQ(mat_dim0major.element_count(), 6);
    759   EXPECT_THAT(mat_dim0major.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
    760 
    761   // Test expected memory layout when using Relayout to column major.
    762   auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_);
    763   EXPECT_THAT(relaid_mat_to_dim0minor.data<int32>(),
    764               ElementsAre(1, 4, 2, 5, 3, 6));
    765 }
    766 
    767 TEST_F(LiteralUtilTest, TestR3LinearLayout) {
    768   // Test expected memory layout of R3 dim0-minor (column-major) literal.
    769   Array3D<int> arr3d(
    770       // clang-format off
    771         {
    772           {
    773             {1, 2, 3},
    774             {4, 5, 6},
    775           },
    776           {
    777             {7, 8, 9},
    778             {10, 11, 12},
    779           },
    780       });  // clang-format on
    781   auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
    782       arr3d, layout_r3_dim0minor_);
    783 
    784   EXPECT_EQ(lit_dim0minor.element_count(), 12);
    785   std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
    786   EXPECT_THAT(lit_dim0minor.data<int32>(),
    787               testing::ElementsAreArray(expected_dim0minor));
    788 
    789   // Test expected memory layout when using Relayout to row major.
    790   auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_);
    791   std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    792   EXPECT_THAT(relaid_lit_to_dim0major.data<int32>(),
    793               testing::ElementsAreArray(expected_dim0major));
    794 
    795   // Test expected memory layout of R3 created with dim0-major (row-major).
    796   auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
    797       arr3d, layout_r3_dim0major_);
    798   EXPECT_EQ(lit_dim0major.element_count(), 12);
    799   EXPECT_THAT(lit_dim0major.data<int32>(),
    800               testing::ElementsAreArray(expected_dim0major));
    801 
    802   // Test expected memory layout when using Relayout to column major.
    803   auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_);
    804   EXPECT_THAT(relaid_lit_to_dim0minor.data<int32>(),
    805               testing::ElementsAreArray(expected_dim0minor));
    806 }
    807 
    808 TEST_F(LiteralUtilTest, SliceR0S32) {
    809   auto input = LiteralUtil::CreateR0<int32>(1);
    810   auto result = input.Slice({}, {});
    811   EXPECT_EQ(input, result);
    812 }
    813 
    814 TEST_F(LiteralUtilTest, SliceR1F32) {
    815   auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
    816   auto result = input.Slice({3}, {4});
    817   auto expected = LiteralUtil::CreateR1<float>({4.0});
    818   EXPECT_EQ(expected, result);
    819 }
    820 
    821 TEST_F(LiteralUtilTest, SliceR2U32) {
    822   auto input_3x4 = LiteralUtil::CreateR2<uint32>(
    823       {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
    824   auto result = input_3x4.Slice({0, 2}, {2, 4});
    825   auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
    826   EXPECT_EQ(expected, result);
    827 }
    828 
    829 TEST_F(LiteralUtilTest, SliceR3U32Full) {
    830   auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
    831       {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
    832   auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2});
    833   EXPECT_EQ(input_2x3x2, result);
    834 }
    835 
    836 TEST_F(LiteralUtilTest, PopulateR1S64) {
    837   Literal output(ShapeUtil::MakeShape(S64, {1}));
    838   output.PopulateR1<int64>({77});
    839   auto expected = LiteralUtil::CreateR1<int64>({77});
    840   EXPECT_EQ(output, expected);
    841 }
    842 
    843 TEST_F(LiteralUtilTest, PopulateR1U64) {
    844   Literal output(ShapeUtil::MakeShape(U64, {2}));
    845   output.PopulateR1<uint64>({{77, 88}});
    846   auto expected = LiteralUtil::CreateR1<uint64>({{77, 88}});
    847   EXPECT_EQ(output, expected);
    848 }
    849 
    850 TEST_F(LiteralUtilTest, PopulateR1C64) {
    851   Literal output(ShapeUtil::MakeShape(C64, {1}));
    852   output.PopulateR1<complex64>({{77, 88}});
    853   auto expected = LiteralUtil::CreateR1<complex64>({{77, 88}});
    854   EXPECT_EQ(output, expected);
    855 }
    856 
    857 TEST_F(LiteralUtilTest, PopulateR1C128) {
    858   Literal output(ShapeUtil::MakeShape(C128, {1}));
    859   output.PopulateR1<complex128>({{77, 88}});
    860   auto expected = LiteralUtil::CreateR1<complex128>({{77, 88}});
    861   EXPECT_EQ(output, expected);
    862 }
    863 
    864 TEST_F(LiteralUtilTest, PopulateR2C64) {
    865   Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
    866   output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
    867   auto expected =
    868       LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
    869   EXPECT_EQ(output, expected);
    870 }
    871 
    872 TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
    873   Literal output(ShapeUtil::MakeShape(BF16, {}));
    874   bfloat16 h(0.25f);
    875   output.PopulateWithValue<bfloat16>(h);
    876   auto expected = LiteralUtil::CreateR0<bfloat16>(h);
    877   EXPECT_EQ(output, expected);
    878 }
    879 
    880 TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
    881   Literal output(ShapeUtil::MakeShape(BF16, {3}));
    882   bfloat16 h(0.5f);
    883   output.PopulateWithValue<bfloat16>(h);
    884   auto expected = LiteralUtil::CreateR1<bfloat16>({h, h, h});
    885   EXPECT_EQ(output, expected);
    886 }
    887 
    888 TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
    889   Literal output(ShapeUtil::MakeShape(BF16, {2, 2}));
    890   bfloat16 h(2.0f);
    891   output.PopulateWithValue<bfloat16>(h);
    892   auto expected = LiteralUtil::CreateR2<bfloat16>({{h, h}, {h, h}});
    893   EXPECT_EQ(output, expected);
    894 }
    895 
    896 TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
    897   Literal output(ShapeUtil::MakeShape(F32, {}));
    898   output.PopulateWithValue<float>(2.5f);
    899   auto expected = LiteralUtil::CreateR0<float>(2.5f);
    900   EXPECT_EQ(output, expected);
    901 }
    902 
    903 TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
    904   Literal output(ShapeUtil::MakeShape(S64, {3}));
    905   output.PopulateWithValue<int64>(-7);
    906   auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
    907   EXPECT_EQ(output, expected);
    908 }
    909 
    910 TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
    911   Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
    912   output.PopulateWithValue<uint64>(42);
    913   auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
    914   EXPECT_EQ(output, expected);
    915 }
    916 
    917 TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
    918   Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
    919   output.PopulateWithValue<complex64>({4, 2});
    920   auto expected =
    921       LiteralUtil::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
    922   EXPECT_EQ(output, expected);
    923 }
    924 
    925 TEST_F(LiteralUtilTest, PopulateWithValueR2C128) {
    926   Literal output(ShapeUtil::MakeShape(C128, {2, 2}));
    927   output.PopulateWithValue<complex128>({4, 2});
    928   auto expected =
    929       LiteralUtil::CreateR2<complex128>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
    930   EXPECT_EQ(output, expected);
    931 }
    932 
    933 TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
    934   Literal output(ShapeUtil::MakeShape(F16, {}));
    935   half h(0.25f);
    936   output.PopulateWithValue<half>(h);
    937   auto expected = LiteralUtil::CreateR0<half>(h);
    938   EXPECT_EQ(output, expected);
    939 }
    940 
    941 TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
    942   Literal output(ShapeUtil::MakeShape(F16, {3}));
    943   half h(0.5f);
    944   output.PopulateWithValue<half>(h);
    945   auto expected = LiteralUtil::CreateR1<half>({h, h, h});
    946   EXPECT_EQ(output, expected);
    947 }
    948 
    949 TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
    950   Literal output(ShapeUtil::MakeShape(F16, {2, 2}));
    951   half h(2.0f);
    952   output.PopulateWithValue<half>(h);
    953   auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}});
    954   EXPECT_EQ(output, expected);
    955 }
    956 
    957 TEST_F(LiteralUtilTest, ReplicateR2U32) {
    958   auto input = LiteralUtil::CreateR2<uint32>(
    959       {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
    960   auto output = input.Replicate<uint32>(3);
    961   auto expected = LiteralUtil::CreateR3<uint32>(
    962       {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
    963        {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
    964        {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
    965   EXPECT_EQ(output, expected);
    966 }
    967 
    968 TEST_F(LiteralUtilTest, CopySliceFrom) {
    969   const int64 dimensions[] = {17, 15, 34, 21};
    970   const int64 layouts[][4] = {
    971       {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}};
    972   for (const auto& layout : layouts) {
    973     Shape shape = ShapeUtil::MakeShapeWithLayout(
    974         primitive_util::NativeToPrimitiveType<uint32>(), dimensions, layout);
    975 
    976     auto source = Literal::CreateFromShape(shape);
    977     const int64 zero_base[] = {0, 0, 0, 0};
    978     const int64 step[] = {1, 1, 1, 1};
    979     uint32 seqnr = 0;
    980     auto init_proc = [&](absl::Span<const int64> indexes) {
    981       source.Set(indexes, ++seqnr);
    982       return true;
    983     };
    984     ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
    985                             init_proc);
    986 
    987     auto blank = Literal::CreateFromShape(shape);
    988     const int64 src_base[] = {3, 1, 5, 7};
    989     const int64 dest_base[] = {6, 4, 12, 2};
    990     const int64 copy_size[] = {7, 8, 11, 9};
    991     TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size));
    992 
    993     std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
    994     std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
    995     bool matched = true;
    996     auto check_proc = [&](absl::Span<const int64> indexes) {
    997       std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
    998       std::transform(source_indexes.begin(), source_indexes.end(), src_base,
    999                      source_indexes.begin(), std::plus<int64>());
   1000       std::copy(indexes.begin(), indexes.end(), blank_indexes.begin());
   1001       std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base,
   1002                      blank_indexes.begin(), std::plus<int64>());
   1003       auto bval = blank.Get<uint32>(blank_indexes);
   1004       matched = (bval != 0 && bval == source.Get<uint32>(source_indexes));
   1005       return matched;
   1006     };
   1007 
   1008     ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step,
   1009                             check_proc);
   1010     EXPECT_TRUE(matched);
   1011   }
   1012 }
   1013 
   1014 TEST_F(LiteralUtilTest, CopyFromScalars) {
   1015   auto zero = LiteralUtil::CreateR0<uint32>(0);
   1016   auto nine = LiteralUtil::CreateR0<uint32>(9);
   1017   TF_EXPECT_OK(zero.CopyFrom(nine));
   1018   EXPECT_EQ(zero, nine);
   1019 
   1020   auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
   1021   TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {}));
   1022   EXPECT_EQ(zero.Get<uint32>({}), 17);
   1023   TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {}));
   1024   EXPECT_EQ(vect.Get<uint32>({4}), 17);
   1025 }
   1026 
   1027 TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
   1028   const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
   1029   const auto const_nine = LiteralUtil::CreateR1<float>({9});
   1030   const auto const_empty = Literal::CreateFromShape(empty_r1_shape);
   1031 
   1032   {
   1033     // Source contains dimension with zero elements.
   1034     const auto empty = Literal::CreateFromShape(empty_r1_shape);
   1035     auto nine = LiteralUtil::CreateR1<float>({9});
   1036 
   1037     TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0}));
   1038     EXPECT_EQ(nine, const_nine);
   1039   }
   1040 
   1041   {
   1042     // Copy 0 element to destination with zero elements.
   1043     auto empty = Literal::CreateFromShape(empty_r1_shape);
   1044     auto nine = LiteralUtil::CreateR1<float>({9});
   1045 
   1046     TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0}));
   1047     EXPECT_EQ(empty, const_empty);
   1048   }
   1049 }
   1050 
   1051 TEST_F(LiteralUtilTest, CopyFromNilShape) {
   1052   Literal nil_literal0(ShapeUtil::MakeNil());
   1053   Literal nil_literal1(ShapeUtil::MakeNil());
   1054   // This doesn't actually do any copying, but it should succeed.
   1055   TF_ASSERT_OK(nil_literal0.CopyFrom(nil_literal1));
   1056 }
   1057 
   1058 TEST_F(LiteralUtilTest, CopyFromArrays) {
   1059   auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
   1060   auto scalar_123 = LiteralUtil::CreateR0<float>(123.0);
   1061   EXPECT_NE(scalar_42, scalar_123);
   1062   TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{},
   1063                                   /*src_shape_index=*/{}));
   1064   EXPECT_EQ(scalar_42, scalar_123);
   1065   EXPECT_EQ(scalar_42.Get<float>({}), 123.0f);
   1066 
   1067   auto matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1068   auto matrix_5678 = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
   1069   EXPECT_NE(matrix_1234, matrix_5678);
   1070   EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 1.0f);
   1071   TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{},
   1072                                     /*src_shape_index=*/{}));
   1073   EXPECT_EQ(matrix_1234, matrix_5678);
   1074   EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 5.0f);
   1075 }
   1076 
   1077 TEST_F(LiteralUtilTest, CopyFromTuples) {
   1078   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1079   Literal nil_literal(ShapeUtil::MakeNil());
   1080   Literal inner_elements[] = {LiteralUtil::CreateR0<int32>(42),
   1081                               LiteralUtil::CreateR1<double>({23.0, 44.0})};
   1082   Literal inner_tuple = LiteralUtil::MakeTuple(
   1083       {&inner_elements[0], &inner_elements[1], &nil_literal});
   1084   Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});
   1085   // Create a tuple the same shape as the inner tuple of nested_tuple but with
   1086   // different values..
   1087   Literal int32_minus5 = LiteralUtil::CreateR0<int32>(-5);
   1088   Literal double_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
   1089   Literal tuple =
   1090       LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal});
   1091 
   1092   EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
   1093   EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), 42);
   1094   EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
   1095   EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);
   1096 
   1097   // Overwrite the inner tuple element of nested_tuple with the contents of
   1098   // 'tuple'.
   1099   TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
   1100                                      /*src_shape_index=*/{}));
   1101 
   1102   // The matrix element should be unchanged.
   1103   EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
   1104 
   1105   // The tuple element should have been copied from 'tuple'.
   1106   EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), -5);
   1107   EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
   1108   EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
   1109 }
   1110 TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
   1111   Literal elements[] = {LiteralUtil::CreateR0<int32>(-2),
   1112                         LiteralUtil::CreateR0<int32>(4)};
   1113   Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
   1114 
   1115   EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
   1116   EXPECT_EQ(tuple.Get<int32>({}, {1}), 4);
   1117 
   1118   // Copy from one element to the other.
   1119   TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
   1120                               /*src_shape_index=*/{0}));
   1121 
   1122   EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
   1123   EXPECT_EQ(tuple.Get<int32>({}, {1}), -2);
   1124 }
   1125 
   1126 TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
   1127   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1128   auto vector = LiteralUtil::CreateR1<float>({5.0, 7.0});
   1129   Status status = matrix.CopyFrom(vector);
   1130   ASSERT_FALSE(status.ok());
   1131   EXPECT_THAT(status.error_message(),
   1132               HasSubstr("Destination subshape incompatible"));
   1133 }
   1134 
   1135 TEST_F(LiteralUtilTest, F16) {
   1136   // Verify that the internal data views are consistent and that they
   1137   // are in little endian format
   1138   // TODO - modify if we make the data format machine endianess dependent
   1139   Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
   1140   const char* d1 = reinterpret_cast<const char*>(m1.data<half>().data());
   1141   EXPECT_EQ(d1[0], 0);
   1142   EXPECT_EQ(d1[1], 0);
   1143   EXPECT_EQ(d1[2], 0);
   1144   EXPECT_EQ(d1[3], 0);
   1145   EXPECT_EQ(d1[4], 0);
   1146   EXPECT_EQ(d1[5], 0);
   1147   EXPECT_EQ(d1[6], 0);
   1148   EXPECT_EQ(d1[7], 0);
   1149 
   1150   half h1(1.0f);
   1151   half h2(2.0f);
   1152   auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
   1153   const char* d2 = reinterpret_cast<const char*>(m2.data<half>().data());
   1154   EXPECT_EQ(d2[0], 0);
   1155   EXPECT_EQ(d2[1], 0x3C);
   1156   EXPECT_EQ(d2[2], 0);
   1157   EXPECT_EQ(d2[3], 0x40);
   1158   EXPECT_EQ(d2[4], 0);
   1159   EXPECT_EQ(d2[5], 0x40);
   1160   EXPECT_EQ(d2[6], 0);
   1161   EXPECT_EQ(d2[7], 0x3C);
   1162 }
   1163 
   1164 TEST_F(LiteralUtilTest, Populate) {
   1165   struct PopulateData {
   1166     std::vector<int64> dimensions;
   1167     std::vector<int64> layout;
   1168   } populate_data[] = {
   1169       {{}, {}},
   1170       {{0}, {0}},
   1171       {{16}, {0}},
   1172       {{2, 0}, {1, 0}},
   1173       {{4, 16}, {1, 0}},
   1174       {{21, 12}, {0, 1}},
   1175       {{6, 11, 17}, {2, 0, 1}},
   1176       {{6, 11, 5, 17}, {3, 2, 0, 1}},
   1177   };
   1178   for (const auto& data : populate_data) {
   1179     Shape shape = ShapeUtil::MakeShapeWithLayout(
   1180         primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
   1181         data.layout);
   1182     Literal literal(shape);
   1183     auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
   1184       // Offsets from linear index just to avoid R0 literals to be initialized
   1185       // with zero.
   1186       return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
   1187                                                            indexes) +
   1188              17;
   1189     };
   1190     TF_EXPECT_OK(literal.Populate<uint32>(generator));
   1191 
   1192     std::vector<int64> zero_base(data.dimensions.size(), 0);
   1193     std::vector<int64> step(data.dimensions.size(), 1);
   1194     bool matched = true;
   1195     auto check_function = [&](absl::Span<const int64> indexes) {
   1196       auto value = literal.Get<uint32>(indexes);
   1197       matched = matched && (value == generator(indexes));
   1198       return matched;
   1199     };
   1200     ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
   1201                             check_function);
   1202     EXPECT_TRUE(matched);
   1203   }
   1204 }
   1205 
   1206 TEST_F(LiteralUtilTest, PopulateParallel) {
   1207   struct PopulateData {
   1208     std::vector<int64> dimensions;
   1209     std::vector<int64> layout;
   1210   } populate_data[] = {
   1211       {{}, {}},
   1212       {{0}, {0}},
   1213       {{16}, {0}},
   1214       {{2, 0}, {1, 0}},
   1215       {{4, 16}, {1, 0}},
   1216       {{21, 12}, {0, 1}},
   1217       {{6, 11, 17}, {2, 0, 1}},
   1218       {{6, 11, 5, 17}, {3, 2, 0, 1}},
   1219   };
   1220   for (const auto& data : populate_data) {
   1221     Shape shape = ShapeUtil::MakeShapeWithLayout(
   1222         primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
   1223         data.layout);
   1224     Literal literal(shape);
   1225     auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
   1226       // Offsets from linear index just to avoid R0 literals to be initialized
   1227       // with zero.
   1228       return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
   1229                                                            indexes) +
   1230              17;
   1231     };
   1232     TF_EXPECT_OK(literal.PopulateParallel<uint32>(generator));
   1233 
   1234     std::vector<int64> zero_base(data.dimensions.size(), 0);
   1235     std::vector<int64> step(data.dimensions.size(), 1);
   1236     bool matched = true;
   1237     auto check_function = [&](absl::Span<const int64> indexes) {
   1238       auto value = literal.Get<uint32>(indexes);
   1239       matched = matched && (value == generator(indexes));
   1240       return matched;
   1241     };
   1242     ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
   1243                             check_function);
   1244     EXPECT_TRUE(matched);
   1245   }
   1246 }
   1247 
   1248 TEST_F(LiteralUtilTest, ConvertR4) {
   1249   // clang-format off
   1250   auto original = LiteralUtil::CreateR4WithLayout<int8>({{
   1251      {{10, 11, 12, 13}, {14, 15, 16, 17}},
   1252      {{18, 19, 20, 21}, {22, 23, 24, 25}},
   1253      {{26, 27, 28, 29}, {30, 31, 32, 33}},
   1254   }}, layout_r4_dim0major_);
   1255   auto expected = LiteralUtil::CreateR4WithLayout<uint32>({{
   1256      {{10, 11, 12, 13}, {14, 15, 16, 17}},
   1257      {{18, 19, 20, 21}, {22, 23, 24, 25}},
   1258      {{26, 27, 28, 29}, {30, 31, 32, 33}},
   1259   }}, layout_r4_dim0major_);
   1260   // clang-format on
   1261   TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32));
   1262 
   1263   EXPECT_EQ(expected, converted);
   1264 }
   1265 
   1266 TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
   1267   // clang-format off
   1268   auto s8 = LiteralUtil::CreateR4WithLayout<int8>({{
   1269     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1270     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1271     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1272   }}, layout_r4_dim0major_);
   1273   auto s16 = LiteralUtil::CreateR4WithLayout<int16>({{
   1274     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1275     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1276     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1277   }}, layout_r4_dim0major_);
   1278   auto s32 = LiteralUtil::CreateR4WithLayout<int32>({{
   1279     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1280     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1281     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1282   }}, layout_r4_dim0major_);
   1283   auto u16 = LiteralUtil::CreateR4WithLayout<uint16>({{
   1284     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1285     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1286     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1287   }}, layout_r4_dim0major_);
   1288   auto u32 = LiteralUtil::CreateR4WithLayout<uint32>({{
   1289     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1290     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1291     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1292   }}, layout_r4_dim0major_);
   1293   auto s64 = LiteralUtil::CreateR4WithLayout<int64>({{
   1294     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1295     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1296     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1297   }}, layout_r4_dim0major_);
   1298   auto u64 = LiteralUtil::CreateR4WithLayout<uint64>({{
   1299     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1300     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1301     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1302   }}, layout_r4_dim0major_);
   1303   auto pred = LiteralUtil::CreateR4WithLayout<bool>({{
   1304     {{true, false, true, false}, {false, true, false, true}},
   1305     {{false, true, false, true}, {true, false, true, false}},
   1306     {{true, false, true, false}, {false, true, false, true}},
   1307   }}, layout_r4_dim0major_);
   1308   auto int32_pred = LiteralUtil::CreateR4WithLayout<int32>({{
   1309     {{1, 0, 1, 0}, {0, 1, 0, 1}},
   1310     {{0, 1, 0, 1}, {1, 0, 1, 0}},
   1311     {{1, 0, 1, 0}, {0, 1, 0, 1}},
   1312   }}, layout_r4_dim0major_);
   1313   auto f16 = LiteralUtil::CreateR4WithLayout<half>({{
   1314     {{half(10.0), half(0.0), half(12.0), half(0.0)},
   1315      {half(0.0), half(15.0), half(0.0), half(17.0)}},
   1316     {{half(0.0), half(19.0), half(0.0), half(21.0)},
   1317      {half(22.0), half(0.0), half(24.0), half(0.0)}},
   1318     {{half(26.0), half(0.0), half(28.0), half(0.0)},
   1319      {half(0.0), half(31.0), half(0.0), half(33.0)}},
   1320   }}, layout_r4_dim0major_);
   1321   auto bf16 = LiteralUtil::CreateR4WithLayout<bfloat16>({{
   1322     {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
   1323      {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
   1324     {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
   1325      {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
   1326     {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
   1327      {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
   1328   }}, layout_r4_dim0major_);
   1329   auto f32 = LiteralUtil::CreateR4WithLayout<float>({{
   1330     {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
   1331     {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
   1332     {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
   1333   }}, layout_r4_dim0major_);
   1334   auto f64 = LiteralUtil::CreateR4WithLayout<double>({{
   1335     {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
   1336     {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
   1337     {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
   1338   }}, layout_r4_dim0major_);
   1339   auto c64 = LiteralUtil::CreateR4WithLayout<complex64>({{
   1340     {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
   1341     {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
   1342     {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
   1343   }}, layout_r4_dim0major_);
   1344   auto c128 = LiteralUtil::CreateR4WithLayout<complex128>({{
   1345     {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
   1346     {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
   1347     {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
   1348   }}, layout_r4_dim0major_);  // clang-format on
   1349   Literal conv;
   1350 
   1351   conv = s8.Convert(U16).ConsumeValueOrDie();
   1352   EXPECT_EQ(conv, u16);
   1353 
   1354   conv = s8.Convert(S16).ConsumeValueOrDie();
   1355   EXPECT_EQ(conv, s16);
   1356 
   1357   conv = s8.Convert(U32).ConsumeValueOrDie();
   1358   EXPECT_EQ(conv, u32);
   1359 
   1360   conv = s8.Convert(S32).ConsumeValueOrDie();
   1361   EXPECT_EQ(conv, s32);
   1362 
   1363   conv = s8.Convert(U64).ConsumeValueOrDie();
   1364   EXPECT_EQ(conv, u64);
   1365 
   1366   conv = s8.Convert(S64).ConsumeValueOrDie();
   1367   EXPECT_EQ(conv, s64);
   1368 
   1369   conv = s8.Convert(PRED).ConsumeValueOrDie();
   1370   EXPECT_EQ(conv, pred);
   1371 
   1372   conv = bf16.Convert(S32).ConsumeValueOrDie();
   1373   EXPECT_EQ(conv, s32);
   1374 
   1375   conv = bf16.Convert(F32).ConsumeValueOrDie();
   1376   EXPECT_EQ(conv, f32);
   1377 
   1378   conv = pred.Convert(S32).ConsumeValueOrDie();
   1379   EXPECT_EQ(conv, int32_pred);
   1380 
   1381   conv = f32.Convert(S32).ConsumeValueOrDie();
   1382   EXPECT_EQ(conv, s32);
   1383 
   1384   conv = f64.Convert(S32).ConsumeValueOrDie();
   1385   EXPECT_EQ(conv, s32);
   1386 
   1387   conv = s32.Convert(F32).ConsumeValueOrDie();
   1388   EXPECT_EQ(conv, f32);
   1389 
   1390   conv = f32.Convert(F16).ConsumeValueOrDie();
   1391   EXPECT_EQ(conv, f16);
   1392 
   1393   conv = f64.Convert(F16).ConsumeValueOrDie();
   1394   EXPECT_EQ(conv, f16);
   1395 
   1396   conv = s32.Convert(F16).ConsumeValueOrDie();
   1397   EXPECT_EQ(conv, f16);
   1398 
   1399   conv = u32.Convert(F16).ConsumeValueOrDie();
   1400   EXPECT_EQ(conv, f16);
   1401 
   1402   conv = s32.Convert(C64).ConsumeValueOrDie();
   1403   EXPECT_EQ(conv, c64);
   1404 
   1405   conv = f16.Convert(C64).ConsumeValueOrDie();
   1406   EXPECT_EQ(conv, c64);
   1407 
   1408   conv = s32.Convert(S16).ConsumeValueOrDie();
   1409   EXPECT_EQ(conv, s16);
   1410 
   1411   conv = s32.Convert(U16).ConsumeValueOrDie();
   1412   EXPECT_EQ(conv, u16);
   1413 
   1414   conv = s32.Convert(C128).ConsumeValueOrDie();
   1415   EXPECT_EQ(conv, c128);
   1416 
   1417   conv = f16.Convert(C128).ConsumeValueOrDie();
   1418   EXPECT_EQ(conv, c128);
   1419 
   1420   EXPECT_EQ(s32.Convert(TUPLE).status().code(),
   1421             tensorflow::error::UNIMPLEMENTED);
   1422   EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
   1423   EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
   1424   EXPECT_EQ(c128.Convert(F32).status().code(),
   1425             tensorflow::error::UNIMPLEMENTED);
   1426   EXPECT_EQ(c128.Convert(S32).status().code(),
   1427             tensorflow::error::UNIMPLEMENTED);
   1428 }
   1429 
   1430 TEST_F(LiteralUtilTest, BitcastConvert) {
   1431   auto original = LiteralUtil::CreateR1<uint32>(
   1432       {absl::bit_cast<uint32>(2.5f), absl::bit_cast<uint32>(-42.25f),
   1433        absl::bit_cast<uint32>(100.f), 0xbeef});
   1434   auto expected = LiteralUtil::CreateR1<float>(
   1435       {2.5f, -42.25f, 100.0f, absl::bit_cast<float>(0xbeef)});
   1436   TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
   1437 }
   1438 
   1439 TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
   1440   auto literal = LiteralUtil::CreateR0<uint32>(1234);
   1441   Status status = literal.BitcastConvert(F64).status();
   1442   EXPECT_NE(Status::OK(), status);
   1443   EXPECT_TRUE(
   1444       absl::StrContains(status.error_message(), "bit widths are different"));
   1445 }
   1446 
   1447 // Sets the layout of the given ShapeProto to the default.
   1448 void SetDefaultLayoutOnProto(ShapeProto* shape_proto) {
   1449   CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type()));
   1450   shape_proto->mutable_layout()->set_format(DENSE);
   1451   auto* minor_to_major =
   1452       shape_proto->mutable_layout()->mutable_minor_to_major();
   1453   minor_to_major->Resize(shape_proto->dimensions_size(), 0);
   1454   const int64 size = minor_to_major->size();
   1455   for (int64 i = 0; i < size; ++i) {
   1456     minor_to_major->Set(i, size - 1 - i);
   1457   }
   1458 }
   1459 
   1460 TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
   1461   LiteralProto p;
   1462   p.mutable_shape()->set_element_type(PRED);
   1463   for (int len = 0; len < 25; ++len) {
   1464     p.mutable_shape()->clear_dimensions();
   1465     p.mutable_shape()->add_dimensions(len);
   1466     SetDefaultLayoutOnProto(p.mutable_shape());
   1467     p.clear_preds();
   1468     for (int i = 0; i < len; ++i) {
   1469       p.add_preds((i % 2) == (len % 2));
   1470     }
   1471 
   1472     TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
   1473     ASSERT_EQ(len, literal.data<bool>().size());
   1474     int i = 0;
   1475     for (bool value : literal.data<bool>()) {
   1476       EXPECT_EQ((i % 2) == (len % 2), value);
   1477       ++i;
   1478     }
   1479   }
   1480 }
   1481 
   1482 // Note that f16 is currently stored in a byte array in little endian byte order
   1483 TEST_F(LiteralUtilTest, ToProto_f16) {
   1484   half h1(1.0f);
   1485   half h2(2.0f);
   1486 
   1487   auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
   1488   EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
   1489   EXPECT_EQ(4, m.data<half>().size());
   1490 
   1491   LiteralProto p = m.ToProto();
   1492   EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape())));
   1493   EXPECT_EQ(8, p.f16s().size());
   1494   const char* d = p.f16s().data();
   1495   EXPECT_EQ(d[0], 0);
   1496   EXPECT_EQ(d[1], 0x3C);
   1497   EXPECT_EQ(d[2], 0);
   1498   EXPECT_EQ(d[3], 0x40);
   1499   EXPECT_EQ(d[4], 0);
   1500   EXPECT_EQ(d[5], 0x40);
   1501   EXPECT_EQ(d[6], 0);
   1502   EXPECT_EQ(d[7], 0x3C);
   1503 }
   1504 
   1505 // Note that f16 is currently stored in a byte array in little endian byte order
   1506 TEST_F(LiteralUtilTest, CopyFromProto_f16) {
   1507   half h1(1.0f);
   1508   half h2(2.0f);
   1509 
   1510   const char half_vals[8] = {0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C};
   1511   LiteralProto p;
   1512   p.mutable_shape()->set_element_type(F16);
   1513   p.mutable_shape()->clear_dimensions();
   1514   p.mutable_shape()->add_dimensions(4);
   1515   SetDefaultLayoutOnProto(p.mutable_shape());
   1516   p.clear_f16s();
   1517   p.set_f16s(half_vals, 8);
   1518   TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
   1519   auto r = literal.data<half>();
   1520   ASSERT_EQ(4, r.size());
   1521   EXPECT_EQ(h1, r[0]);
   1522   EXPECT_EQ(h2, r[1]);
   1523   EXPECT_EQ(h2, r[2]);
   1524   EXPECT_EQ(h1, r[3]);
   1525 }
   1526 
   1527 TEST_F(LiteralUtilTest, CopyFromProto_u16) {
   1528   uint16 u1(0xabcd);
   1529   uint16 u2(0x1234);
   1530 
   1531   const unsigned char uint16_vals[8] = {0xcd, 0xab, 0x34, 0x12,
   1532                                         0x34, 0x12, 0xcd, 0xab};
   1533   LiteralProto p;
   1534   p.mutable_shape()->set_element_type(U16);
   1535   p.mutable_shape()->clear_dimensions();
   1536   p.mutable_shape()->add_dimensions(4);
   1537   SetDefaultLayoutOnProto(p.mutable_shape());
   1538   p.clear_u16s();
   1539   p.set_u16s(uint16_vals, 8);
   1540   TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
   1541   auto r = literal.data<uint16>();
   1542   ASSERT_EQ(4, r.size());
   1543   EXPECT_EQ(u1, r[0]);
   1544   EXPECT_EQ(u2, r[1]);
   1545   EXPECT_EQ(u2, r[2]);
   1546   EXPECT_EQ(u1, r[3]);
   1547 }
   1548 
   1549 TEST_F(LiteralUtilTest, LiteralSliceTest) {
   1550   auto scalar = LiteralUtil::CreateR0<float>(1.0);
   1551   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1552   auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
   1553   auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
   1554   Literal nil(ShapeUtil::MakeNil());
   1555 
   1556   EXPECT_EQ(LiteralSlice(scalar, {}), scalar);
   1557   EXPECT_EQ(LiteralSlice(matrix, {}), matrix);
   1558   EXPECT_EQ(LiteralSlice(tuple, {}), tuple);
   1559   EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple);
   1560   EXPECT_EQ(LiteralSlice(nil, {}), nil);
   1561 
   1562   EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
   1563   EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);
   1564 
   1565   EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
   1566   EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
   1567   EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
   1568   EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
   1569 }
   1570 
   1571 TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
   1572   auto scalar = LiteralUtil::CreateR0<float>(1.0);
   1573   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1574   auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
   1575   auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
   1576   // Verify that changing the underlying data beneath the view changes the
   1577   // data of the view itself.
   1578   const auto nested_tuple_view = LiteralSlice(nested_tuple);
   1579   EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
   1580             1.0f);
   1581   EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
   1582                                          /*shape_index=*/{0, 0}),
   1583             1.0f);
   1584   nested_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
   1585   EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
   1586             555.0f);
   1587   EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
   1588                                          /*shape_index=*/{0, 0}),
   1589             555.0f);
   1590 }
   1591 
   1592 TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
   1593   auto scalar = LiteralUtil::CreateR0<float>(1.0);
   1594   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1595   auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
   1596   auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
   1597 
   1598   const auto nested_tuple_view = LiteralSlice(nested_tuple);
   1599   const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
   1600   const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
   1601   EXPECT_EQ(matrix_view,
   1602             LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
   1603 }
   1604 
   1605 TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
   1606   std::vector<int64> int64_values = {1, 2, 3};
   1607   const Shape literal_shape = ShapeUtil::MakeShape(S64, {3});
   1608 
   1609   BorrowingLiteral literal(reinterpret_cast<const char*>(int64_values.data()),
   1610                            literal_shape);
   1611 
   1612   EXPECT_EQ(literal.Get<int64>({0}), 1);
   1613   EXPECT_EQ(literal.Get<int64>({1}), 2);
   1614   EXPECT_EQ(literal.Get<int64>({2}), 3);
   1615 }
   1616 
   1617 TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
   1618   std::vector<int64> one_two_three = {1, 2, 3};
   1619   const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3});
   1620 
   1621   std::vector<int64> hundred = {100};
   1622   const Shape hundred_shape = ShapeUtil::MakeShape(S64, {1});
   1623 
   1624   std::vector<const char*> src_buf_ptrs;
   1625   src_buf_ptrs.emplace_back(
   1626       reinterpret_cast<const char*>(one_two_three.data()));
   1627   src_buf_ptrs.emplace_back(reinterpret_cast<const char*>(hundred.data()));
   1628   auto literal_tuple = BorrowingLiteral(
   1629       src_buf_ptrs,
   1630       ShapeUtil::MakeTupleShape({one_two_three_shape, hundred_shape}));
   1631 
   1632   EXPECT_EQ(literal_tuple.Get<int64>(/*multi_index=*/{0}, /*shape_index=*/{0}),
   1633             1);
   1634   EXPECT_EQ(literal_tuple.Get<int64>(/*multi_index=*/{0}, /*shape_index=*/{1}),
   1635             100);
   1636 
   1637   EXPECT_EQ(literal_tuple.Get<int64>(/*multi_index=*/{1}, /*shape_index=*/{0}),
   1638             2);
   1639 
   1640   EXPECT_EQ(literal_tuple.Get<int64>(/*multi_index=*/{2}, /*shape_index=*/{0}),
   1641             3);
   1642 }
   1643 
   1644 TEST_F(LiteralUtilTest, LiteralMove) {
   1645   Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1646   Literal literal(std::move(matrix));
   1647 
   1648   EXPECT_TRUE(
   1649       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
   1650   EXPECT_EQ(literal.Get<float>({0, 0}), 1.0);
   1651   EXPECT_EQ(literal.Get<float>({0, 1}), 2.0);
   1652   EXPECT_EQ(literal.Get<float>({1, 0}), 3.0);
   1653   EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
   1654 }
   1655 
   1656 TEST_F(LiteralUtilTest, DecomposeTuple) {
   1657   Literal nil_literal(ShapeUtil::MakeNil());
   1658   Literal inner_elements[] = {
   1659       LiteralUtil::CreateR0<int32>(42),
   1660       LiteralUtil::CreateR1<double>({23.0, 44.0}),
   1661   };
   1662   Literal tuple_elements[] = {
   1663       LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}),
   1664       LiteralUtil::MakeTuple(
   1665           {&inner_elements[0], &inner_elements[1], &nil_literal}),
   1666   };
   1667   Literal nested_tuple = LiteralUtil::MakeTuple(
   1668       {&tuple_elements[0], &tuple_elements[1], &nil_literal});
   1669 
   1670   EXPECT_FALSE(ShapeUtil::IsEmptyTuple(nested_tuple.shape()));
   1671   std::vector<Literal> elements = nested_tuple.DecomposeTuple();
   1672   EXPECT_TRUE(ShapeUtil::IsEmptyTuple(nested_tuple.shape()));
   1673 
   1674   ASSERT_EQ(elements.size(), 3);
   1675 
   1676   EXPECT_TRUE(ShapeUtil::Compatible(elements[0].shape(),
   1677                                     ShapeUtil::MakeShape(S32, {2, 2})));
   1678   EXPECT_EQ(elements[0].Get<int32>({0, 0}), 1);
   1679   EXPECT_EQ(elements[0].Get<int32>({0, 1}), 2);
   1680   EXPECT_EQ(elements[0].Get<int32>({1, 0}), 3);
   1681   EXPECT_EQ(elements[0].Get<int32>({1, 1}), 4);
   1682 
   1683   EXPECT_TRUE(ShapeUtil::Compatible(
   1684       elements[1].shape(),
   1685       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}),
   1686                                  ShapeUtil::MakeShape(F64, {2}),
   1687                                  ShapeUtil::MakeNil()})));
   1688   EXPECT_EQ(elements[1].Get<int32>({}, /*shape_index=*/{0}), 42);
   1689   EXPECT_EQ(elements[1].Get<double>({0}, /*shape_index=*/{1}), 23.0);
   1690   EXPECT_EQ(elements[1].Get<double>({1}, /*shape_index=*/{1}), 44.0);
   1691 
   1692   EXPECT_TRUE(ShapeUtil::Compatible(elements[2].shape(), ShapeUtil::MakeNil()));
   1693 }
   1694 
   1695 TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
   1696   Literal nil_literal(ShapeUtil::MakeNil());
   1697   std::vector<Literal> elements = nil_literal.DecomposeTuple();
   1698   EXPECT_EQ(elements.size(), 0);
   1699 }
   1700 
   1701 TEST_F(LiteralUtilTest, MoveIntoTuple) {
   1702   std::vector<Literal> elements;
   1703   elements.push_back(LiteralUtil::CreateR0<float>(1.0));
   1704   elements.push_back(LiteralUtil::CreateR1<int32>({4, 8}));
   1705   std::vector<Literal> inner_elements;
   1706   inner_elements.push_back(LiteralUtil::CreateR0<int32>(42));
   1707   inner_elements.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));
   1708   elements.push_back(
   1709       LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]}));
   1710 
   1711   Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
   1712   ASSERT_TRUE(literal.shape().IsTuple());
   1713   ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3);
   1714 
   1715   EXPECT_EQ(literal.Get<float>({}, /*shape_index=*/{0}), 1.0);
   1716   EXPECT_EQ(literal.Get<int32>({0}, /*shape_index=*/{1}), 4);
   1717   EXPECT_EQ(literal.Get<int32>({1}, /*shape_index=*/{1}), 8);
   1718   EXPECT_EQ(literal.Get<int32>({}, /*shape_index=*/{2, 0}), 42);
   1719   EXPECT_EQ(literal.Get<double>({0}, /*shape_index=*/{2, 1}), 23.0);
   1720   EXPECT_EQ(literal.Get<double>({1}, /*shape_index=*/{2, 1}), 44.0);
   1721 
   1722   for (const Literal& element : elements) {
   1723     EXPECT_TRUE(ShapeUtil::IsEmptyTuple(element.shape()));
   1724   }
   1725 }
   1726 
   1727 TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) {
   1728   Literal literal = Literal::MoveIntoTuple({});
   1729   ASSERT_TRUE(literal.shape().IsTuple());
   1730   EXPECT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0);
   1731 }
   1732 
   1733 TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
   1734   Literal literal;
   1735   EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
   1736 
   1737   Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1738   literal = std::move(matrix);
   1739 
   1740   EXPECT_TRUE(
   1741       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
   1742   EXPECT_EQ(literal.Get<float>({0, 0}), 1.0);
   1743   EXPECT_EQ(literal.Get<float>({0, 1}), 2.0);
   1744   EXPECT_EQ(literal.Get<float>({1, 0}), 3.0);
   1745   EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
   1746 }
   1747 
   1748 TEST_F(LiteralUtilTest, LiteralSliceCopy) {
   1749   Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1750   const auto matrix_view = LiteralSlice(matrix);
   1751   LiteralSlice matrix_view_copy(matrix_view);
   1752 
   1753   EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
   1754   EXPECT_EQ(matrix_view_copy.Get<float>({0, 1}), 2.0);
   1755   EXPECT_EQ(matrix_view_copy.Get<float>({1, 0}), 3.0);
   1756   EXPECT_EQ(matrix_view_copy.Get<float>({1, 1}), 4.0);
   1757 }
   1758 
   1759 TEST_F(LiteralUtilTest, GetSetTuple) {
   1760   Literal elements[] = {
   1761       LiteralUtil::CreateR0<float>(42.0),
   1762       LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
   1763   };
   1764   auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
   1765   EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
   1766   tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
   1767   EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
   1768 
   1769   EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0);
   1770   tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
   1771   EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
   1772             -4.0);
   1773 }
   1774 
   1775 TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
   1776   // Literals constructed using CreateFromShape should be zero initialized.
   1777   Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
   1778   EXPECT_EQ(scalar_f32.Get<float>({}), 0.0);
   1779   EXPECT_TRUE(scalar_f32.IsAll(0));
   1780 
   1781   Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
   1782   EXPECT_EQ(vector_s32.Get<int32>({0}), 0);
   1783   EXPECT_EQ(vector_s32.Get<int32>({1}), 0);
   1784   EXPECT_EQ(vector_s32.Get<int32>({2}), 0);
   1785   EXPECT_TRUE(vector_s32.IsAll(0));
   1786 
   1787   Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
   1788       {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
   1789        ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}),
   1790        ShapeUtil::MakeShape(C128, {})}));
   1791 
   1792   EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
   1793   EXPECT_EQ(tuple.Get<bool>({0}, {1}), false);
   1794   EXPECT_EQ(tuple.Get<bool>({1}, {1}), false);
   1795   EXPECT_EQ(tuple.Get<uint64>({0, 0}, {2}), 0);
   1796   EXPECT_EQ(tuple.Get<uint64>({1, 0}, {2}), 0);
   1797   EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
   1798   EXPECT_EQ(tuple.Get<complex128>({}, {4}), complex128(0.0, 0.0));
   1799 }
   1800 
   1801 TEST_F(LiteralUtilTest, ProtoRoundTrip) {
   1802   // Test serializing then deserializing a Literal through a proto.
   1803   auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
   1804   auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
   1805   auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
   1806   auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
   1807   auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
   1808   auto vector_c128 =
   1809       LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
   1810   auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
   1811       {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
   1812   auto vector_half =
   1813       LiteralUtil::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
   1814   auto matrix_pred =
   1815       LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
   1816   auto tuple = LiteralUtil::MakeTuple(
   1817       {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
   1818   Literal nil_literal(ShapeUtil::MakeNil());
   1819   auto nested_tuple =
   1820       LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});
   1821 
   1822   auto to_from_proto = [](const Literal& literal) -> Literal {
   1823     return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
   1824   };
   1825 
   1826   EXPECT_EQ(one_f32, to_from_proto(one_f32));
   1827   EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
   1828   EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
   1829   EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
   1830   EXPECT_EQ(vector_c128, to_from_proto(vector_c128));
   1831   EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
   1832   EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
   1833   EXPECT_EQ(tuple, to_from_proto(tuple));
   1834   EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
   1835   EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
   1836 
   1837   EXPECT_NE(one_f32, two_f32);
   1838   EXPECT_NE(one_f32, to_from_proto(two_f32));
   1839 }
   1840 
   1841 TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
   1842   // Proto contains a shape, but no values.
   1843   LiteralProto proto;
   1844   *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
   1845   Status status = Literal::CreateFromProto(proto).status();
   1846   ASSERT_FALSE(status.ok());
   1847   EXPECT_THAT(status.error_message(),
   1848               HasSubstr("Expected 3 elements in LiteralProto"));
   1849 }
   1850 
   1851 TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
   1852   // Proto contains values, but no shape.
   1853   LiteralProto proto;
   1854   proto.add_preds(false);
   1855   proto.add_preds(true);
   1856   proto.add_preds(false);
   1857   Status status = Literal::CreateFromProto(proto).status();
   1858   ASSERT_FALSE(status.ok());
   1859   EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape"));
   1860 }
   1861 
   1862 TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
   1863   // Proto contains values in wrong container.
   1864   LiteralProto proto;
   1865   *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
   1866   proto.add_preds(false);
   1867   proto.add_preds(true);
   1868   proto.add_preds(false);
   1869   Status status = Literal::CreateFromProto(proto).status();
   1870   ASSERT_FALSE(status.ok());
   1871   EXPECT_THAT(status.error_message(),
   1872               HasSubstr("Expected 3 elements in LiteralProto"));
   1873 }
   1874 
   1875 TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
   1876   // Proto contains too few values.
   1877   LiteralProto proto;
   1878   *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto();
   1879   proto.add_f32s(1.0);
   1880   proto.add_f32s(2.0);
   1881   proto.add_f32s(3.0);
   1882   Status status = Literal::CreateFromProto(proto).status();
   1883   ASSERT_FALSE(status.ok());
   1884   EXPECT_THAT(status.error_message(),
   1885               HasSubstr("Expected 84 elements in LiteralProto"));
   1886 }
   1887 
   1888 TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
   1889   // Proto contains too many values.
   1890   LiteralProto proto;
   1891   *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto();
   1892   proto.add_s32s(42);
   1893   proto.add_s32s(-10);
   1894   proto.add_s32s(100);
   1895   Status status = Literal::CreateFromProto(proto).status();
   1896   ASSERT_FALSE(status.ok());
   1897   EXPECT_THAT(status.error_message(),
   1898               HasSubstr("Expected 2 elements in LiteralProto"));
   1899 }
   1900 
   1901 TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
   1902   // Proto shape missing layout.
   1903   LiteralProto proto;
   1904   *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto();
   1905   proto.mutable_shape()->clear_layout();
   1906   proto.add_preds(true);
   1907   proto.add_preds(false);
   1908   proto.add_preds(true);
   1909   proto.add_preds(false);
   1910   Status status = Literal::CreateFromProto(proto).status();
   1911   ASSERT_FALSE(status.ok());
   1912   EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout"));
   1913 }
   1914 
   1915 TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
   1916   // Proto has the too few tuple elements.
   1917   LiteralProto proto;
   1918   *proto.mutable_shape() =
   1919       ShapeUtil::MakeTupleShape(
   1920           {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
   1921           .ToProto();
   1922   LiteralProto* element0 = proto.add_tuple_literals();
   1923   *element0->mutable_shape() =
   1924       ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto();
   1925   element0->add_preds(false);
   1926   element0->add_preds(true);
   1927 
   1928   Status status = Literal::CreateFromProto(proto).status();
   1929   ASSERT_FALSE(status.ok());
   1930   EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
   1931 }
   1932 
   1933 TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
   1934   // Proto has the too many tuple elements.
   1935   LiteralProto proto;
   1936   *proto.mutable_shape() =
   1937       ShapeUtil::MakeTupleShape(
   1938           {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
   1939           .ToProto();
   1940   LiteralProto* element0 = proto.add_tuple_literals();
   1941   *element0->mutable_shape() =
   1942       ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto();
   1943   element0->add_preds(false);
   1944   element0->add_preds(true);
   1945   LiteralProto* element1 = proto.add_tuple_literals();
   1946   *element1->mutable_shape() =
   1947       ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 1).ToProto();
   1948   element1->add_f32s(42.0);
   1949   LiteralProto* element2 = proto.add_tuple_literals();
   1950   *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto();
   1951   element2->add_f32s(123.0);
   1952 
   1953   Status status = Literal::CreateFromProto(proto).status();
   1954   ASSERT_FALSE(status.ok());
   1955   EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
   1956 }
   1957 
   1958 TEST_F(LiteralUtilTest, SortSparseElements) {
   1959   auto literal = LiteralUtil::CreateSparse<float>({10, 10, 10},
   1960                                                   SparseIndexArray(10, 3), {});
   1961   literal.AppendSparseElement<float>({2, 3, 4}, 2.0);
   1962   literal.AppendSparseElement<float>({3, 4, 5}, 3.0);
   1963   literal.AppendSparseElement<float>({1, 2, 3}, 1.0);
   1964   literal.SortSparseElements();
   1965   EXPECT_EQ(literal.ToString(),
   1966             "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
   1967 }
   1968 
   1969 TEST_F(LiteralUtilTest, GetSparseElementAsString) {
   1970   std::vector<int64> dimensions = {10, 10, 10};
   1971   SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}});
   1972 
   1973   EXPECT_EQ(
   1974       LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
   1975           .GetSparseElementAsString(1),
   1976       "false");
   1977   EXPECT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
   1978                 .GetSparseElementAsString(1),
   1979             absl::StrCat(int64{2}));
   1980   EXPECT_EQ(
   1981       LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
   1982           .GetSparseElementAsString(1),
   1983       absl::StrCat(double{2.0}));
   1984   EXPECT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
   1985                                             {half{1.0}, half{2.0}, half{3.0}})
   1986                 .GetSparseElementAsString(1),
   1987             absl::StrCat(static_cast<float>(half{2.0})));
   1988   EXPECT_EQ(LiteralUtil::CreateSparse<complex64>(
   1989                 dimensions, indices,
   1990                 std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
   1991                 .GetSparseElementAsString(1),
   1992             absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
   1993 }
   1994 
   1995 TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
   1996   Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
   1997   TF_ASSERT_OK_AND_ASSIGN(
   1998       Literal broadcasted_literal,
   1999       literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
   2000                         /*dimensions=*/{0}));
   2001   EXPECT_EQ(broadcasted_literal,
   2002             LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
   2003 }
   2004 
   2005 TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
   2006   Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
   2007   TF_ASSERT_OK_AND_ASSIGN(
   2008       Literal broadcasted_literal,
   2009       literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
   2010                         /*dimensions=*/{1}));
   2011   EXPECT_EQ(broadcasted_literal,
   2012             LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
   2013 }
   2014 
   2015 TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
   2016   Literal literal = LiteralUtil::CreateR0<int32>(9);
   2017   TF_ASSERT_OK_AND_ASSIGN(
   2018       Literal broadcasted_literal,
   2019       literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
   2020                         /*dimensions=*/{}));
   2021   EXPECT_EQ(broadcasted_literal,
   2022             LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
   2023 }
   2024 
   2025 }  // namespace
   2026 }  // namespace xla
   2027