Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/literal_util.h"
     17 
     18 #include <vector>
     19 
     20 #include "tensorflow/compiler/xla/array3d.h"
     21 #include "tensorflow/compiler/xla/array4d.h"
     22 #include "tensorflow/compiler/xla/layout_util.h"
     23 #include "tensorflow/compiler/xla/shape_util.h"
     24 #include "tensorflow/compiler/xla/test.h"
     25 #include "tensorflow/compiler/xla/types.h"
     26 #include "tensorflow/core/lib/core/status_test_util.h"
     27 #include "tensorflow/core/platform/macros.h"
     28 #include "tensorflow/core/platform/types.h"
     29 
     30 namespace xla {
     31 namespace {
     32 
     33 using ::testing::ElementsAre;
     34 using ::testing::HasSubstr;
     35 
     36 class LiteralUtilTest : public ::testing::Test {
     37  protected:
     38   LiteralUtilTest() {
     39     Array4D<float> arr4d({
     40         // clang-format off
     41       {  // i0=0
     42           {  // i1=0
     43               {1, 2, 3},  // i2=0
     44               {4, 5, 6},  // i2=1
     45               {7, 8, 9},  // i2=2
     46           },
     47           {  // i1=1
     48               {11, 12, 13},
     49               {14, 15, 16},
     50               {17, 18, 19},
     51           },
     52       },
     53       {  // i0=1
     54           {  // i1=0
     55               {101, 102, 103},
     56               {104, 105, 106},
     57               {107, 108, 109},
     58           },
     59           {  // i1=1
     60               {201, 202, 203},  // i2=0
     61               {204, 205, 206},  // i2=1
     62               {207, 208, 209},  // i2=2
     63           },
     64       },
     65         // clang-format on
     66     });
     67 
     68     layout_r2_dim0major_ = LayoutUtil::MakeLayout({1, 0});
     69     layout_r2_dim0minor_ = LayoutUtil::MakeLayout({0, 1});
     70     layout_r3_dim0major_ = LayoutUtil::MakeLayout({2, 1, 0});
     71     layout_r3_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2});
     72     layout_r4_dim0major_ = LayoutUtil::MakeLayout({3, 2, 1, 0});
     73     layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3});
     74 
     75     literal_r4_2x2x3x3_dim0major_ =
     76         Literal::CreateR4FromArray4DWithLayout<float>(arr4d,
     77                                                       layout_r4_dim0major_);
     78     literal_r4_2x2x3x3_dim0minor_ =
     79         Literal::CreateR4FromArray4DWithLayout<float>(arr4d,
     80                                                       layout_r4_dim0minor_);
     81   }
     82 
     83   Layout layout_r2_dim0major_;
     84   Layout layout_r2_dim0minor_;
     85   Layout layout_r3_dim0major_;
     86   Layout layout_r3_dim0minor_;
     87   Layout layout_r4_dim0major_;
     88   Layout layout_r4_dim0minor_;
     89   std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0major_;
     90   std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0minor_;
     91 };
     92 
     93 TEST_F(LiteralUtilTest, LiteralScalarToString) {
     94   auto true_lit = Literal::CreateR0<bool>(true);
     95   ASSERT_EQ("true", true_lit->ToString());
     96 
     97   auto false_lit = Literal::CreateR0<bool>(false);
     98   ASSERT_EQ("false", false_lit->ToString());
     99 
    100   auto u32_lit = Literal::CreateR0<uint32>(42);
    101   ASSERT_EQ("42", u32_lit->ToString());
    102 
    103   auto s32_lit = Literal::CreateR0<int32>(-999);
    104   ASSERT_EQ("-999", s32_lit->ToString());
    105 
    106   auto f32_lit = Literal::CreateR0<float>(3.14f);
    107   ASSERT_EQ("3.14", f32_lit->ToString());
    108 
    109   auto f16_lit = Literal::CreateR0<half>(static_cast<half>(0.5f));
    110   ASSERT_EQ("0.5", f16_lit->ToString());
    111 
    112   auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
    113   ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
    114 
    115   auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
    116   ASSERT_EQ("0.5", bf16_lit->ToString());
    117 
    118   // 3.14 will be truncated to 3.125 in bfloat16 format.
    119   auto bf16_lit_truncated =
    120       Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
    121   ASSERT_EQ("3.125", bf16_lit_truncated->ToString());
    122 
    123   auto bf16_lit_truncated2 =
    124       Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
    125   ASSERT_EQ("9", bf16_lit_truncated2->ToString());
    126 }
    127 
    128 TEST_F(LiteralUtilTest, LiteralVectorToString) {
    129   auto pred_vec = Literal::CreateR1<bool>({true, false, true});
    130   ASSERT_EQ("{101}", pred_vec->ToString());
    131 }
    132 
    133 TEST_F(LiteralUtilTest, R2ToString) {
    134   const auto literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}});
    135   const string expected = R"(s32[3,2] {
    136   { 1, 2 },
    137   { 3, 4 },
    138   { 5, 6 }
    139 })";
    140   ASSERT_EQ(expected, literal->ToString());
    141 }
    142 
    143 TEST_F(LiteralUtilTest, R3ToString) {
    144   const auto literal = Literal::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
    145   const string expected = R"(s32[3,2,1] {
    146 { { 1 },
    147   { 2 } },
    148 { { 3 },
    149   { 4 } },
    150 { { 5 },
    151   { 6 } }
    152 })";
    153   ASSERT_EQ(expected, literal->ToString());
    154 }
    155 
    156 TEST_F(LiteralUtilTest, TupleToString) {
    157   auto scalar = Literal::CreateR0<float>(1.0);
    158   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    159   auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
    160   const string expected = R"((f32[], f32[2,2]) (
    161 1,
    162 f32[2,2] {
    163   { 1, 2 },
    164   { 3, 4 }
    165 }
    166 ))";
    167   ASSERT_EQ(expected, tuple->ToString());
    168 }
    169 
    170 TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
    171   // clang-format off
    172   Array3D<float> array_3d({
    173     {{1.0f, 2.0f},
    174      {3.0f, 4.0f},
    175      {5.0f, 6.0f}},
    176     {{7.0f, 8.0f},
    177      {9.0f, 10.0f},
    178      {11.0f, 12.0f}},
    179   });
    180   // clang-format on
    181 
    182   auto literal = Literal::CreateR3FromArray3D(array_3d);
    183   EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2));
    184   string result = literal->ToString();
    185   const string expected = R"(f32[2,3,2] {
    186 { { 1, 2 },
    187   { 3, 4 },
    188   { 5, 6 } },
    189 { { 7, 8 },
    190   { 9, 10 },
    191   { 11, 12 } }
    192 })";
    193   ASSERT_EQ(expected, result);
    194 }
    195 
    196 TEST_F(LiteralUtilTest, CreateSparse) {
    197   std::vector<int64> dimensions = {8, 8, 8};
    198   Array2D<int64> indices = {
    199       {3, 4, 5},
    200       {1, 2, 3},
    201       {2, 3, 4},
    202       {3, 5, 6},
    203   };
    204   std::vector<int64> values = {7, 8, 9, 10};
    205   auto literal = Literal::CreateSparse<int64>(
    206       dimensions, SparseIndexArray(indices.n1() + 3, indices), values);
    207 
    208   Array2D<int64> expected_indices = {
    209       {1, 2, 3},
    210       {2, 3, 4},
    211       {3, 4, 5},
    212       {3, 5, 6},
    213   };
    214   std::vector<int64> expected_values = {8, 9, 7, 10};
    215 
    216   EXPECT_EQ(literal->sparse_indices()->data(),
    217             tensorflow::gtl::ArraySlice<int64>(
    218                 expected_indices.data(), expected_indices.num_elements()));
    219   EXPECT_EQ(tensorflow::gtl::ArraySlice<int64>(literal->data<int64>().data(),
    220                                                expected_values.size()),
    221             tensorflow::gtl::ArraySlice<int64>(expected_values));
    222 }
    223 
    224 TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
    225   // clang-format off
    226   auto literal = Literal::CreateR4Projected<float>({
    227     {1, 2},
    228     {1001, 1002},
    229     {2001, 2002},
    230   }, /*projection_p=*/1, /*projection_z=*/2);
    231   // clang-format on
    232   EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2));
    233   string result = literal->ToString();
    234   const string expected = R"(f32[1,2,3,2] {
    235   {  /*i0=0*/
    236     {  /*i1=0*/
    237       {1, 2},
    238       {1001, 1002},
    239       {2001, 2002}
    240     },
    241     {  /*i1=1*/
    242       {1, 2},
    243       {1001, 1002},
    244       {2001, 2002}
    245     }
    246   }
    247 })";
    248   ASSERT_EQ(expected, result);
    249 }
    250 
    251 TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
    252   EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(),
    253               ElementsAre(2, 2, 3, 3));
    254   string result = literal_r4_2x2x3x3_dim0major_->ToString();
    255   const string expected = R"(f32[2,2,3,3] {
    256   {  /*i0=0*/
    257     {  /*i1=0*/
    258       {1, 2, 3},
    259       {4, 5, 6},
    260       {7, 8, 9}
    261     },
    262     {  /*i1=1*/
    263       {11, 12, 13},
    264       {14, 15, 16},
    265       {17, 18, 19}
    266     }
    267   },
    268   {  /*i0=1*/
    269     {  /*i1=0*/
    270       {101, 102, 103},
    271       {104, 105, 106},
    272       {107, 108, 109}
    273     },
    274     {  /*i1=1*/
    275       {201, 202, 203},
    276       {204, 205, 206},
    277       {207, 208, 209}
    278     }
    279   }
    280 })";
    281   ASSERT_EQ(expected, result);
    282 }
    283 
    284 TEST_F(LiteralUtilTest, EachCellR2F32) {
    285   // clang-format off
    286   auto literal = Literal::CreateR2<float>({
    287     {3.1f, 4.2f},
    288     {9.3f, 12.4f},
    289   });
    290   // clang-format on
    291   std::vector<std::tuple<int64, int64, string>> seen;
    292   literal->EachCellAsString(
    293       [&seen](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
    294         seen.emplace_back(indices[0], indices[1], value);
    295       });
    296 
    297   using Elem = std::tuple<int64, int64, string>;
    298   std::vector<Elem> expected = {Elem(0, 0, "3.1"), Elem(0, 1, "4.2"),
    299                                 Elem(1, 0, "9.3"), Elem(1, 1, "12.4")};
    300   EXPECT_EQ(expected, seen);
    301 }
    302 
    303 TEST_F(LiteralUtilTest, ScalarEquality) {
    304   // Test equality with scalars.
    305   auto f32_42 = Literal::CreateR0<float>(42.0);
    306   auto f32_42_clone = Literal::CreateR0<float>(42.0);
    307 
    308   EXPECT_EQ(*f32_42, *f32_42);
    309   EXPECT_EQ(*f32_42, *f32_42_clone);
    310 
    311   auto f32_123 = Literal::CreateR0<float>(123.0);
    312   EXPECT_NE(*f32_42, *f32_123);
    313 
    314   auto f64_42 = Literal::CreateR0<double>(42.0);
    315   EXPECT_NE(*f32_42, *f64_42);
    316 }
    317 
    318 TEST_F(LiteralUtilTest, NonScalarEquality) {
    319   // Test equality with nonscalars.
    320   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    321   auto matrix_clone = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    322   auto matrix_different = Literal::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
    323   auto vector_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
    324   auto scalar = Literal::CreateR0<float>(1.0);
    325   Literal nil(ShapeUtil::MakeNil());
    326 
    327   EXPECT_EQ(*matrix, *matrix);
    328   EXPECT_EQ(*matrix, *matrix_clone);
    329   EXPECT_NE(*matrix, *matrix_different);
    330   EXPECT_NE(*matrix, *vector_literal);
    331   EXPECT_NE(*matrix, *scalar);
    332   EXPECT_NE(*matrix, nil);
    333   EXPECT_EQ(nil, nil);
    334 }
    335 
    336 TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
    337   // Test equality with literals which have different layouts.
    338   auto colmajor =
    339       MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
    340   colmajor->Set<float>({0, 0}, 1.0);
    341   colmajor->Set<float>({0, 1}, 2.0);
    342   colmajor->Set<float>({1, 0}, 3.0);
    343   colmajor->Set<float>({1, 1}, 4.0);
    344 
    345   auto rowmajor =
    346       MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
    347   rowmajor->Set<float>({0, 0}, 1.0);
    348   rowmajor->Set<float>({0, 1}, 2.0);
    349   rowmajor->Set<float>({1, 0}, 3.0);
    350   rowmajor->Set<float>({1, 1}, 4.0);
    351 
    352   EXPECT_EQ(*rowmajor, *colmajor);
    353 }
    354 
    355 TEST_F(LiteralUtilTest, TupleEquality) {
    356   // Test equality with tuples.
    357   auto scalar = Literal::CreateR0<float>(1.0);
    358   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    359   auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()});
    360 
    361   // Tuple with the same elements. One element is shared with the original
    362   // tuple, the other is a clone of the element in the original tuple.
    363   auto scalar_clone = Literal::CreateR0<float>(1.0);
    364   auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()});
    365   EXPECT_EQ(*tuple1, *tuple2);
    366 
    367   // Tuple with elements reversed.
    368   auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()});
    369   EXPECT_NE(*tuple1, *reversed_tuple);
    370 
    371   // Tuple with different value.
    372   auto scalar_42 = Literal::CreateR0<float>(42.0);
    373   auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()});
    374   EXPECT_NE(*tuple1, *different_tuple);
    375 }
    376 
    377 TEST_F(LiteralUtilTest, C64Equality) {
    378   // Test equality with tuples.
    379   auto vector = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
    380 
    381   // Tuple with the same elements. One element is shared with the original
    382   // tuple, the other is a clone of the element in the original tuple.
    383   auto vector_clone = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
    384   EXPECT_EQ(*vector, *vector_clone);
    385 
    386   auto vector_reversed = Literal::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
    387   EXPECT_NE(*vector, *vector_reversed);
    388 }
    389 
    390 TEST_F(LiteralUtilTest, IsAllTuple) {
    391   auto element1 = Literal::CreateR0<float>(0.0);
    392   auto element2 = Literal::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
    393   auto tuple = Literal::MakeTuple({element1.get(), element1.get()});
    394 
    395   // Tuples should always return false for IsAll.
    396   EXPECT_FALSE(tuple->IsAll(0));
    397   EXPECT_FALSE(tuple->IsAll(1));
    398 }
    399 
    400 // Verifies that CreateFromShape works for tuples.
    401 TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
    402   auto scalar = Literal::CreateR0<float>(0.0);
    403   auto matrix = Literal::CreateR2<int32>({{0, 0}, {0, 0}});
    404   auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
    405 
    406   auto x = Literal::CreateFromShape(tuple->shape());
    407   EXPECT_EQ(*tuple, *x);
    408 }
    409 
    410 TEST_F(LiteralUtilTest, IsAll) {
    411   EXPECT_TRUE(Literal::CreateR0<bool>(false)->IsAll(0));
    412   EXPECT_TRUE(Literal::CreateR0<bool>(true)->IsAll(1));
    413   EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAll(1));
    414   EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAll(2));
    415   EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(0));
    416   EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(2));
    417   EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(-1));
    418 
    419   // We shouldn't reinterpret int8_min as an unsigned type and then decide that
    420   // it is equal to 255.
    421   auto int8_min = std::numeric_limits<int8>::min();
    422   EXPECT_FALSE(Literal::CreateR0<uint8>(255)->IsAll(int8_min));
    423 
    424   EXPECT_TRUE(Literal::CreateR0<float>(42.0)->IsAll(42));
    425   EXPECT_FALSE(Literal::CreateR0<float>(42.0001)->IsAll(42));
    426 
    427   EXPECT_TRUE(Literal::CreateR1<int>({100, 100, 100})->IsAll(100));
    428   EXPECT_FALSE(Literal::CreateR1<double>({100, 100, 100.001})->IsAll(100));
    429 
    430   EXPECT_TRUE(Literal::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
    431   EXPECT_FALSE(Literal::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
    432   EXPECT_FALSE(Literal::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
    433 
    434   half h8(8.0f);
    435   half h9(9.0f);
    436   EXPECT_TRUE(Literal::CreateR2<half>({{h8}, {h8}})->IsAll(8));
    437   EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
    438   EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
    439 
    440   bfloat16 b8(8.0f);
    441   bfloat16 b9(9.0f);
    442 
    443   EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
    444   EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
    445   EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
    446 
    447   // 9.001 will be truncated to 9.0
    448   bfloat16 b91(9.001f);
    449   bfloat16 b90(9.00f);
    450   EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
    451 
    452   complex64 c8_9 = {8, 9};
    453   EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
    454 
    455   auto uint64_max = std::numeric_limits<uint64>::max();
    456   EXPECT_FALSE(Literal::CreateR2<uint64>(
    457                    {{uint64_max, uint64_max}, {uint64_max, uint64_max}})
    458                    ->IsAll(-1));
    459 }
    460 
    461 TEST_F(LiteralUtilTest, IsAllFloat) {
    462   // IsAllFloat always returns false when the literal is not floating-point.
    463   EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllFloat(0));
    464   EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllFloat(0));
    465   EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllFloat(0));
    466   EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllFloat(0));
    467 
    468   EXPECT_TRUE(Literal::CreateR0<float>(0)->IsAllFloat(0));
    469   EXPECT_TRUE(Literal::CreateR0<float>(.5)->IsAllFloat(.5));
    470   EXPECT_TRUE(Literal::CreateR0<float>(-.5)->IsAllFloat(-.5));
    471   EXPECT_FALSE(Literal::CreateR0<float>(-.5)->IsAllFloat(-.49));
    472   EXPECT_FALSE(
    473       Literal::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
    474   EXPECT_TRUE(
    475       Literal::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})->IsAllFloat(.5));
    476 
    477   EXPECT_TRUE(Literal::CreateR0<double>(0)->IsAllFloat(0));
    478   EXPECT_TRUE(Literal::CreateR0<double>(.5)->IsAllFloat(.5));
    479   EXPECT_TRUE(Literal::CreateR0<double>(-.5)->IsAllFloat(-.5));
    480   EXPECT_FALSE(Literal::CreateR0<double>(-.5)->IsAllFloat(-.49));
    481   EXPECT_FALSE(
    482       Literal::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
    483 }
    484 
    485 TEST_F(LiteralUtilTest, IsAllComplex) {
    486   // IsAllComplex always returns false when the literal is not complex.
    487   EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllComplex(0));
    488   EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllComplex(0));
    489   EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllComplex(0));
    490   EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllComplex(0));
    491   EXPECT_FALSE(Literal::CreateR0<float>(0)->IsAllComplex(0));
    492   EXPECT_FALSE(Literal::CreateR0<double>(0)->IsAllComplex(0));
    493 
    494   complex64 c8_9 = {8, 9};
    495   complex64 c7_9 = {7, 9};
    496   EXPECT_TRUE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})
    497                   ->IsAllComplex({8.0f, 9.0f}));
    498   EXPECT_FALSE(Literal::CreateR2<complex64>({{c7_9}, {c8_9}})
    499                    ->IsAllComplex({8.0f, 9.0f}));
    500   EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c7_9}})
    501                    ->IsAllComplex({8.0f, 9.0f}));
    502 }
    503 
    504 TEST_F(LiteralUtilTest, IsZero) {
    505   auto scalar_zero = Literal::CreateR0<float>(0.0f);
    506   auto scalar_one = Literal::CreateR0<float>(1.0f);
    507   EXPECT_TRUE(scalar_zero->IsZero({}));
    508   EXPECT_FALSE(scalar_one->IsZero({}));
    509 
    510   auto array = Literal::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
    511   EXPECT_FALSE(array->IsZero({0, 1}));
    512   EXPECT_TRUE(array->IsZero({0, 2}));
    513   EXPECT_TRUE(array->IsZero({1, 1}));
    514   EXPECT_FALSE(array->IsZero({1, 2}));
    515 
    516   auto complex_zero = Literal::CreateR0<complex64>(0.0f);
    517   auto complex_nonzero = Literal::CreateR0<complex64>(0.5f);
    518   EXPECT_TRUE(complex_zero->IsZero({}));
    519   EXPECT_FALSE(complex_nonzero->IsZero({}));
    520 }
    521 
    522 template <typename T>
    523 class LiteralUtilTestTemplated : public ::testing::Test {};
    524 
    525 using TestedTypes = ::testing::Types<float, int32, uint32, complex64>;
    526 TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes);
    527 
    528 TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
    529   // Make a non-integer for floating point types.
    530   TypeParam half = TypeParam(1) / TypeParam(2);
    531   auto data = Literal::CreateR2<TypeParam>({{half, 2}, {3, 4}});
    532   const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
    533   const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
    534 
    535   auto data01 = data->Relayout(layout01);
    536   EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01));
    537   EXPECT_EQ(*data, *data01);
    538 
    539   auto data10 = data->Relayout(layout10);
    540   EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10));
    541   EXPECT_EQ(*data, *data10);
    542 }
    543 
    544 TEST_F(LiteralUtilTest, ReshapeR0) {
    545   auto original = Literal::CreateR0<float>(1.7f);
    546   auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
    547   EXPECT_EQ(*original, *reshape);
    548 }
    549 
    550 TEST_F(LiteralUtilTest, ReshapeR4) {
    551   // clang-format off
    552   // F32[1x3x2x4]
    553   auto original = Literal::CreateR4WithLayout<float>({{
    554      {{10, 11, 12, 13}, {14, 15, 16, 17}},
    555      {{18, 19, 20, 21}, {22, 23, 24, 25}},
    556      {{26, 27, 28, 29}, {30, 31, 32, 33}},
    557   }}, layout_r4_dim0major_);
    558   // F32[1x3x4x2]
    559   auto expected = Literal::CreateR3WithLayout<float>({
    560     {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    561     {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    562     {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
    563   }, layout_r3_dim0major_);
    564   // clang-format on
    565   auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
    566 
    567   EXPECT_EQ(*expected, *reshape);
    568 }
    569 
    570 TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
    571   // clang-format off
    572   // F32[1x3x2x4]
    573   auto original = Literal::CreateR4WithLayout<float>({{
    574      {{10, 11, 12, 13}, {14, 15, 16, 17}},
    575      {{18, 19, 20, 21}, {22, 23, 24, 25}},
    576      {{26, 27, 28, 29}, {30, 31, 32, 33}},
    577   }}, layout_r4_dim0minor_);
    578   // F32[1x3x4x2]
    579   auto expected = Literal::CreateR3WithLayout<float>({
    580     {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    581     {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    582     {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
    583   }, layout_r3_dim0major_);
    584   // clang-format on
    585   auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
    586 
    587   EXPECT_EQ(*expected, *reshape);
    588 }
    589 
    590 TEST_F(LiteralUtilTest, TransposeR0) {
    591   auto original = Literal::CreateR0<float>(1.7f);
    592   auto reshape = original->Transpose(/*permutation=*/{});
    593   EXPECT_EQ(*original, *reshape);
    594 }
    595 
    596 TEST_F(LiteralUtilTest, TransposeR4) {
    597   // clang-format off
    598   // F32[1x3x2x4]
    599   auto original = Literal::CreateR4<float>({{
    600      {{10, 11, 12, 13}, {14, 15, 16, 17}},
    601      {{18, 19, 20, 21}, {22, 23, 24, 25}},
    602      {{26, 27, 28, 29}, {30, 31, 32, 33}},
    603   }});
    604   // clang-format on
    605   auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
    606 
    607   reshape->EachCell<float>(
    608       [&](tensorflow::gtl::ArraySlice<int64> indices, float value) {
    609         EXPECT_EQ(value, original->Get<float>(
    610                              {indices[2], indices[3], indices[0], indices[1]}));
    611       });
    612 }
    613 
    614 TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
    615   // Tests that using Relayout on an array is equivalent to creating it in the
    616   // target layout in the first place.
    617   auto dim0minor_relaid_to_dim0major =
    618       literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_);
    619   EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major);
    620 
    621   auto dim0major_relaid_to_dim0minor =
    622       literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_);
    623   EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor);
    624 }
    625 
    626 TEST_F(LiteralUtilTest, TestR2LinearLayout) {
    627   // Test expected memory layout of R2 dim0-minor (column-major) literal.
    628   auto mat_dim0minor = Literal::CreateR2WithLayout<int32>(
    629       {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
    630   EXPECT_EQ(mat_dim0minor->element_count(), 6);
    631   EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
    632 
    633   // Test expected memory layout when using Relayout to row major.
    634   auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_);
    635   EXPECT_THAT(relaid_mat_to_dim0major->data<int32>(),
    636               ElementsAre(1, 2, 3, 4, 5, 6));
    637 
    638   // Test expected memory layout of R2 created with dim0-major (row-major).
    639   auto mat_dim0major = Literal::CreateR2WithLayout<int32>(
    640       {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
    641   EXPECT_EQ(mat_dim0major->element_count(), 6);
    642   EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
    643 
    644   // Test expected memory layout when using Relayout to column major.
    645   auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_);
    646   EXPECT_THAT(relaid_mat_to_dim0minor->data<int32>(),
    647               ElementsAre(1, 4, 2, 5, 3, 6));
    648 }
    649 
    650 TEST_F(LiteralUtilTest, TestR3LinearLayout) {
    651   // Test expected memory layout of R3 dim0-minor (column-major) literal.
    652   Array3D<int> arr3d(
    653       // clang-format off
    654         {
    655           {
    656             {1, 2, 3},
    657             {4, 5, 6},
    658           },
    659           {
    660             {7, 8, 9},
    661             {10, 11, 12},
    662           },
    663       });  // clang-format on
    664   auto lit_dim0minor =
    665       Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0minor_);
    666 
    667   EXPECT_EQ(lit_dim0minor->element_count(), 12);
    668   std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
    669   EXPECT_THAT(lit_dim0minor->data<int32>(),
    670               testing::ElementsAreArray(expected_dim0minor));
    671 
    672   // Test expected memory layout when using Relayout to row major.
    673   auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_);
    674   std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    675   EXPECT_THAT(relaid_lit_to_dim0major->data<int32>(),
    676               testing::ElementsAreArray(expected_dim0major));
    677 
    678   // Test expected memory layout of R3 created with dim0-major (row-major).
    679   auto lit_dim0major =
    680       Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0major_);
    681   EXPECT_EQ(lit_dim0major->element_count(), 12);
    682   EXPECT_THAT(lit_dim0major->data<int32>(),
    683               testing::ElementsAreArray(expected_dim0major));
    684 
    685   // Test expected memory layout when using Relayout to column major.
    686   auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_);
    687   EXPECT_THAT(relaid_lit_to_dim0minor->data<int32>(),
    688               testing::ElementsAreArray(expected_dim0minor));
    689 }
    690 
    691 TEST_F(LiteralUtilTest, SliceR0S32) {
    692   auto input = Literal::CreateR0<int32>(1);
    693   auto result = input->Slice({}, {});
    694   EXPECT_EQ(*input, *result);
    695 }
    696 
    697 TEST_F(LiteralUtilTest, SliceR1F32) {
    698   auto input = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
    699   auto result = input->Slice({3}, {4});
    700   auto expected = Literal::CreateR1<float>({4.0});
    701   EXPECT_EQ(*expected, *result);
    702 }
    703 
    704 TEST_F(LiteralUtilTest, SliceR2U32) {
    705   auto input_3x4 =
    706       Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
    707   auto result = input_3x4->Slice({0, 2}, {2, 4});
    708   auto expected = Literal::CreateR2<uint32>({{3, 4}, {7, 8}});
    709   EXPECT_EQ(*expected, *result);
    710 }
    711 
    712 TEST_F(LiteralUtilTest, SliceR3U32Full) {
    713   auto input_2x3x2 = Literal::CreateR3<uint32>(
    714       {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
    715   auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
    716   EXPECT_EQ(*input_2x3x2, *result);
    717 }
    718 
    719 TEST_F(LiteralUtilTest, PopulateR1S64) {
    720   Literal output(ShapeUtil::MakeShape(S64, {1}));
    721   output.PopulateR1<int64>({77});
    722   auto expected = Literal::CreateR1<int64>({77});
    723   EXPECT_EQ(output, *expected);
    724 }
    725 
    726 TEST_F(LiteralUtilTest, PopulateR1U64) {
    727   Literal output(ShapeUtil::MakeShape(U64, {2}));
    728   output.PopulateR1<uint64>({{77, 88}});
    729   auto expected = Literal::CreateR1<uint64>({{77, 88}});
    730   EXPECT_EQ(output, *expected);
    731 }
    732 
    733 TEST_F(LiteralUtilTest, PopulateR1C64) {
    734   Literal output(ShapeUtil::MakeShape(C64, {1}));
    735   output.PopulateR1<complex64>({{77, 88}});
    736   auto expected = Literal::CreateR1<complex64>({{77, 88}});
    737   EXPECT_EQ(output, *expected);
    738 }
    739 
    740 TEST_F(LiteralUtilTest, PopulateR2C64) {
    741   Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
    742   output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
    743   auto expected =
    744       Literal::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
    745   EXPECT_EQ(output, *expected);
    746 }
    747 
    748 TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
    749   Literal output(ShapeUtil::MakeShape(BF16, {}));
    750   bfloat16 h(0.25f);
    751   output.PopulateWithValue<bfloat16>(h);
    752   auto expected = Literal::CreateR0<bfloat16>(h);
    753   EXPECT_EQ(output, *expected);
    754 }
    755 
    756 TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
    757   Literal output(ShapeUtil::MakeShape(BF16, {3}));
    758   bfloat16 h(0.5f);
    759   output.PopulateWithValue<bfloat16>(h);
    760   auto expected = Literal::CreateR1<bfloat16>({h, h, h});
    761   EXPECT_EQ(output, *expected);
    762 }
    763 
    764 TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
    765   Literal output(ShapeUtil::MakeShape(BF16, {2, 2}));
    766   bfloat16 h(2.0f);
    767   output.PopulateWithValue<bfloat16>(h);
    768   auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
    769   EXPECT_EQ(output, *expected);
    770 }
    771 
    772 TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
    773   Literal output(ShapeUtil::MakeShape(F32, {}));
    774   output.PopulateWithValue<float>(2.5f);
    775   auto expected = Literal::CreateR0<float>(2.5f);
    776   EXPECT_EQ(output, *expected);
    777 }
    778 
    779 TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
    780   Literal output(ShapeUtil::MakeShape(S64, {3}));
    781   output.PopulateWithValue<int64>(-7);
    782   auto expected = Literal::CreateR1<int64>({-7, -7, -7});
    783   EXPECT_EQ(output, *expected);
    784 }
    785 
    786 TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
    787   Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
    788   output.PopulateWithValue<uint64>(42);
    789   auto expected = Literal::CreateR2<uint64>({{42, 42}, {42, 42}});
    790   EXPECT_EQ(output, *expected);
    791 }
    792 
    793 TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
    794   Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
    795   output.PopulateWithValue<complex64>({4, 2});
    796   auto expected =
    797       Literal::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
    798   EXPECT_EQ(output, *expected);
    799 }
    800 
    801 TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
    802   Literal output(ShapeUtil::MakeShape(F16, {}));
    803   half h(0.25f);
    804   output.PopulateWithValue<half>(h);
    805   auto expected = Literal::CreateR0<half>(h);
    806   EXPECT_EQ(output, *expected);
    807 }
    808 
    809 TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
    810   Literal output(ShapeUtil::MakeShape(F16, {3}));
    811   half h(0.5f);
    812   output.PopulateWithValue<half>(h);
    813   auto expected = Literal::CreateR1<half>({h, h, h});
    814   EXPECT_EQ(output, *expected);
    815 }
    816 
    817 TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
    818   Literal output(ShapeUtil::MakeShape(F16, {2, 2}));
    819   half h(2.0f);
    820   output.PopulateWithValue<half>(h);
    821   auto expected = Literal::CreateR2<half>({{h, h}, {h, h}});
    822   EXPECT_EQ(output, *expected);
    823 }
    824 
    825 TEST_F(LiteralUtilTest, ReplicateR2U32) {
    826   auto input =
    827       Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
    828   auto output = input->Replicate<uint32>(3);
    829   auto expected = Literal::CreateR3<uint32>(
    830       {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
    831        {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
    832        {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
    833   EXPECT_EQ(*output, *expected);
    834 }
    835 
    836 TEST_F(LiteralUtilTest, CopySliceFrom) {
    837   const int64 dimensions[] = {17, 15, 34, 21};
    838   const int64 layouts[][4] = {
    839       {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}};
    840   for (const auto& layout : layouts) {
    841     Shape shape = ShapeUtil::MakeShapeWithLayout(
    842         primitive_util::NativeToPrimitiveType<uint32>(), dimensions, layout);
    843 
    844     auto source = Literal::CreateFromShape(shape);
    845     const int64 zero_base[] = {0, 0, 0, 0};
    846     const int64 step[] = {1, 1, 1, 1};
    847     uint32 seqnr = 0;
    848     auto init_proc = [&](const std::vector<int64>& indexes) {
    849       source->Set(indexes, ++seqnr);
    850       return true;
    851     };
    852     ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step,
    853                             init_proc);
    854 
    855     auto blank = Literal::CreateFromShape(shape);
    856     const int64 src_base[] = {3, 1, 5, 7};
    857     const int64 dest_base[] = {6, 4, 12, 2};
    858     const int64 copy_size[] = {7, 8, 11, 9};
    859     TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size));
    860 
    861     std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
    862     std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
    863     bool matched = true;
    864     auto check_proc = [&](const std::vector<int64>& indexes) {
    865       std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
    866       std::transform(source_indexes.begin(), source_indexes.end(), src_base,
    867                      source_indexes.begin(), std::plus<int64>());
    868       std::copy(indexes.begin(), indexes.end(), blank_indexes.begin());
    869       std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base,
    870                      blank_indexes.begin(), std::plus<int64>());
    871       auto bval = blank->Get<uint32>(blank_indexes);
    872       matched = (bval != 0 && bval == source->Get<uint32>(source_indexes));
    873       return matched;
    874     };
    875 
    876     ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step,
    877                             check_proc);
    878     EXPECT_TRUE(matched);
    879   }
    880 }
    881 
    882 TEST_F(LiteralUtilTest, CopyFromScalars) {
    883   auto zero = Literal::CreateR0<uint32>(0);
    884   auto nine = Literal::CreateR0<uint32>(9);
    885   TF_EXPECT_OK(zero->CopyFrom(*nine));
    886   EXPECT_EQ(*zero, *nine);
    887 
    888   auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
    889   TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
    890   EXPECT_EQ(zero->Get<uint32>({}), 17);
    891   TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
    892   EXPECT_EQ(vect->Get<uint32>({4}), 17);
    893 }
    894 
    895 TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
    896   const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
    897   const auto const_nine = Literal::CreateR1<float>({9});
    898   const auto const_empty = Literal::CreateFromShape(empty_r1_shape);
    899 
    900   {
    901     // Source contains dimension with zero elements.
    902     const auto empty = Literal::CreateFromShape(empty_r1_shape);
    903     auto nine = Literal::CreateR1<float>({9});
    904 
    905     TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
    906     EXPECT_EQ(*nine, *const_nine);
    907   }
    908 
    909   {
    910     // Copy 0 element to destination with zero elements.
    911     const auto empty = Literal::CreateFromShape(empty_r1_shape);
    912     auto nine = Literal::CreateR1<float>({9});
    913 
    914     TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
    915     EXPECT_EQ(*empty, *const_empty);
    916   }
    917 }
    918 
    919 TEST_F(LiteralUtilTest, CopyFromNilShape) {
    920   Literal nil_literal0(ShapeUtil::MakeNil());
    921   Literal nil_literal1(ShapeUtil::MakeNil());
    922   // This doesn't actually do any copying, but it should succeed.
    923   TF_ASSERT_OK(nil_literal0.CopyFrom(nil_literal1));
    924 }
    925 
    926 TEST_F(LiteralUtilTest, CopyFromArrays) {
    927   auto scalar_42 = Literal::CreateR0<float>(42.0);
    928   auto scalar_123 = Literal::CreateR0<float>(123.0);
    929   EXPECT_NE(*scalar_42, *scalar_123);
    930   TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
    931                                    /*src_shape_index=*/{}));
    932   EXPECT_EQ(*scalar_42, *scalar_123);
    933   EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
    934 
    935   auto matrix_1234 = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    936   auto matrix_5678 = Literal::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
    937   EXPECT_NE(*matrix_1234, *matrix_5678);
    938   EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
    939   TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
    940                                      /*src_shape_index=*/{}));
    941   EXPECT_EQ(*matrix_1234, *matrix_5678);
    942   EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 5.0f);
    943 }
    944 
    945 TEST_F(LiteralUtilTest, CopyFromTuples) {
    946   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    947   Literal nil_literal(ShapeUtil::MakeNil());
    948   auto nested_tuple = Literal::MakeTuple(
    949       {matrix.get(),
    950        Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
    951                            Literal::CreateR1<double>({23.0, 44.0}).get(),
    952                            &nil_literal})
    953            .get()});
    954   // Create a tuple the same shape as the inner tuple of nested_tuple but with
    955   // different values..
    956   auto tuple = Literal::MakeTuple({Literal::CreateR0<int32>(-5).get(),
    957                                    Literal::CreateR1<double>({2.0, 4.0}).get(),
    958                                    &nil_literal});
    959 
    960   EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
    961   EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
    962   EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
    963   EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
    964 
    965   // Overwrite the inner tuple element of nested_tuple with the contents of
    966   // 'tuple'.
    967   TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
    968                                       /*src_shape_index=*/{}));
    969 
    970   // The matrix element should be unchanged.
    971   EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
    972 
    973   // The tuple element should have been copied from 'tuple'.
    974   EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
    975   EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 2.0);
    976   EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
    977 }
    978 TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
    979   auto tuple = Literal::MakeTuple(
    980       {Literal::CreateR0<int32>(-2).get(), Literal::CreateR0<int32>(4).get()});
    981 
    982   EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
    983   EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
    984 
    985   // Copy from one element to the other.
    986   TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
    987                                /*src_shape_index=*/{0}));
    988 
    989   EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
    990   EXPECT_EQ(tuple->Get<int32>({}, {1}), -2);
    991 }
    992 
    993 TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
    994   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    995   auto vector = Literal::CreateR1<float>({5.0, 7.0});
    996   Status status = matrix->CopyFrom(*vector);
    997   ASSERT_FALSE(status.ok());
    998   ASSERT_THAT(status.error_message(),
    999               HasSubstr("Destination subshape incompatible"));
   1000 }
   1001 
   1002 TEST_F(LiteralUtilTest, F16) {
   1003   // Verify that the internal data views are consistent and that they
   1004   // are in little endian format
   1005   // TODO - modify if we make the data format machine endianess dependent
   1006   auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
   1007   Literal* l1 = m1.get();
   1008   const char* d1 = reinterpret_cast<const char*>(l1->data<half>().data());
   1009   EXPECT_EQ(d1[0], 0);
   1010   EXPECT_EQ(d1[1], 0);
   1011   EXPECT_EQ(d1[2], 0);
   1012   EXPECT_EQ(d1[3], 0);
   1013   EXPECT_EQ(d1[4], 0);
   1014   EXPECT_EQ(d1[5], 0);
   1015   EXPECT_EQ(d1[6], 0);
   1016   EXPECT_EQ(d1[7], 0);
   1017 
   1018   half h1(1.0f);
   1019   half h2(2.0f);
   1020   auto m2 = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
   1021   Literal* l2 = m2.get();
   1022   const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
   1023   EXPECT_EQ(d2[0], 0);
   1024   EXPECT_EQ(d2[1], 0x3C);
   1025   EXPECT_EQ(d2[2], 0);
   1026   EXPECT_EQ(d2[3], 0x40);
   1027   EXPECT_EQ(d2[4], 0);
   1028   EXPECT_EQ(d2[5], 0x40);
   1029   EXPECT_EQ(d2[6], 0);
   1030   EXPECT_EQ(d2[7], 0x3C);
   1031 }
   1032 
   1033 TEST_F(LiteralUtilTest, Populate) {
   1034   struct PopulateData {
   1035     std::vector<int64> dimensions;
   1036     std::vector<int64> layout;
   1037   } populate_data[] = {
   1038       {{}, {}},
   1039       {{0}, {0}},
   1040       {{16}, {0}},
   1041       {{2, 0}, {1, 0}},
   1042       {{4, 16}, {1, 0}},
   1043       {{21, 12}, {0, 1}},
   1044       {{6, 11, 17}, {2, 0, 1}},
   1045       {{6, 11, 5, 17}, {3, 2, 0, 1}},
   1046   };
   1047   for (const auto& data : populate_data) {
   1048     Shape shape = ShapeUtil::MakeShapeWithLayout(
   1049         primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
   1050         data.layout);
   1051     auto literal = Literal::CreateFromShape(shape);
   1052     auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
   1053       // Offsets from linear index just to avoid R0 literals to be initialized
   1054       // with zero.
   1055       return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
   1056                                                            indexes) +
   1057              17;
   1058     };
   1059     TF_EXPECT_OK(literal->Populate<uint32>(generator));
   1060 
   1061     std::vector<int64> zero_base(data.dimensions.size(), 0);
   1062     std::vector<int64> step(data.dimensions.size(), 1);
   1063     bool matched = true;
   1064     auto check_function = [&](const std::vector<int64>& indexes) {
   1065       auto value = literal->Get<uint32>(indexes);
   1066       matched = matched && (value == generator(indexes));
   1067       return matched;
   1068     };
   1069     ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
   1070                             check_function);
   1071     EXPECT_TRUE(matched);
   1072   }
   1073 }
   1074 
   1075 TEST_F(LiteralUtilTest, ConvertR4) {
   1076   // clang-format off
   1077   auto original = Literal::CreateR4WithLayout<int8>({{
   1078      {{10, 11, 12, 13}, {14, 15, 16, 17}},
   1079      {{18, 19, 20, 21}, {22, 23, 24, 25}},
   1080      {{26, 27, 28, 29}, {30, 31, 32, 33}},
   1081   }}, layout_r4_dim0major_);
   1082   auto expected = Literal::CreateR4WithLayout<uint32>({{
   1083      {{10, 11, 12, 13}, {14, 15, 16, 17}},
   1084      {{18, 19, 20, 21}, {22, 23, 24, 25}},
   1085      {{26, 27, 28, 29}, {30, 31, 32, 33}},
   1086   }}, layout_r4_dim0major_);
   1087   // clang-format on
   1088   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
   1089                           original->Convert(U32));
   1090 
   1091   EXPECT_EQ(*expected, *converted);
   1092 }
   1093 
   1094 TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
   1095   // clang-format off
   1096   auto s8 = Literal::CreateR4WithLayout<int8>({{
   1097     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1098     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1099     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1100   }}, layout_r4_dim0major_);
   1101   auto s32 = Literal::CreateR4WithLayout<int32>({{
   1102     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1103     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1104     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1105   }}, layout_r4_dim0major_);
   1106   auto u32 = Literal::CreateR4WithLayout<uint32>({{
   1107     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1108     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1109     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1110   }}, layout_r4_dim0major_);
   1111   auto s64 = Literal::CreateR4WithLayout<int64>({{
   1112     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1113     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1114     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1115   }}, layout_r4_dim0major_);
   1116   auto u64 = Literal::CreateR4WithLayout<uint64>({{
   1117     {{10, 0, 12, 0}, {0, 15, 0, 17}},
   1118     {{0, 19, 0, 21}, {22, 0, 24, 0}},
   1119     {{26, 0, 28, 0}, {0, 31, 0, 33}},
   1120   }}, layout_r4_dim0major_);
   1121   auto pred = Literal::CreateR4WithLayout<bool>({{
   1122     {{true, false, true, false}, {false, true, false, true}},
   1123     {{false, true, false, true}, {true, false, true, false}},
   1124     {{true, false, true, false}, {false, true, false, true}},
   1125   }}, layout_r4_dim0major_);
   1126   auto int32_pred = Literal::CreateR4WithLayout<int32>({{
   1127     {{1, 0, 1, 0}, {0, 1, 0, 1}},
   1128     {{0, 1, 0, 1}, {1, 0, 1, 0}},
   1129     {{1, 0, 1, 0}, {0, 1, 0, 1}},
   1130   }}, layout_r4_dim0major_);
   1131   auto f16 = Literal::CreateR4WithLayout<half>({{
   1132     {{half(10.0), half(0.0), half(12.0), half(0.0)},
   1133      {half(0.0), half(15.0), half(0.0), half(17.0)}},
   1134     {{half(0.0), half(19.0), half(0.0), half(21.0)},
   1135      {half(22.0), half(0.0), half(24.0), half(0.0)}},
   1136     {{half(26.0), half(0.0), half(28.0), half(0.0)},
   1137      {half(0.0), half(31.0), half(0.0), half(33.0)}},
   1138   }}, layout_r4_dim0major_);
   1139   auto bf16 = Literal::CreateR4WithLayout<bfloat16>({{
   1140     {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
   1141      {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
   1142     {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
   1143      {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
   1144     {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
   1145      {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
   1146   }}, layout_r4_dim0major_);
   1147   auto f32 = Literal::CreateR4WithLayout<float>({{
   1148     {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
   1149     {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
   1150     {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
   1151   }}, layout_r4_dim0major_);
   1152   auto f64 = Literal::CreateR4WithLayout<double>({{
   1153     {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
   1154     {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
   1155     {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
   1156   }}, layout_r4_dim0major_);
   1157   auto c64 = Literal::CreateR4WithLayout<complex64>({{
   1158     {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
   1159     {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
   1160     {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
   1161   }}, layout_r4_dim0major_);
   1162   // clang-format on
   1163   std::unique_ptr<Literal> conv;
   1164 
   1165   conv = s8->Convert(U32).ConsumeValueOrDie();
   1166   EXPECT_EQ(*conv, *u32);
   1167 
   1168   conv = s8->Convert(S32).ConsumeValueOrDie();
   1169   EXPECT_EQ(*conv, *s32);
   1170 
   1171   conv = s8->Convert(U64).ConsumeValueOrDie();
   1172   EXPECT_EQ(*conv, *u64);
   1173 
   1174   conv = s8->Convert(S64).ConsumeValueOrDie();
   1175   EXPECT_EQ(*conv, *s64);
   1176 
   1177   conv = s8->Convert(PRED).ConsumeValueOrDie();
   1178   EXPECT_EQ(*conv, *pred);
   1179 
   1180   conv = bf16->Convert(S32).ConsumeValueOrDie();
   1181   EXPECT_EQ(*conv, *s32);
   1182 
   1183   conv = bf16->Convert(F32).ConsumeValueOrDie();
   1184   EXPECT_EQ(*conv, *f32);
   1185 
   1186   conv = pred->Convert(S32).ConsumeValueOrDie();
   1187   EXPECT_EQ(*conv, *int32_pred);
   1188 
   1189   conv = f32->Convert(S32).ConsumeValueOrDie();
   1190   EXPECT_EQ(*conv, *s32);
   1191 
   1192   conv = f64->Convert(S32).ConsumeValueOrDie();
   1193   EXPECT_EQ(*conv, *s32);
   1194 
   1195   conv = s32->Convert(F32).ConsumeValueOrDie();
   1196   EXPECT_EQ(*conv, *f32);
   1197 
   1198   conv = f32->Convert(F16).ConsumeValueOrDie();
   1199   EXPECT_EQ(*conv, *f16);
   1200 
   1201   conv = f64->Convert(F16).ConsumeValueOrDie();
   1202   EXPECT_EQ(*conv, *f16);
   1203 
   1204   conv = s32->Convert(F16).ConsumeValueOrDie();
   1205   EXPECT_EQ(*conv, *f16);
   1206 
   1207   conv = u32->Convert(F16).ConsumeValueOrDie();
   1208   EXPECT_EQ(*conv, *f16);
   1209 
   1210   conv = s32->Convert(C64).ConsumeValueOrDie();
   1211   EXPECT_EQ(*conv, *c64);
   1212 
   1213   conv = f16->Convert(C64).ConsumeValueOrDie();
   1214   EXPECT_EQ(*conv, *c64);
   1215 
   1216   EXPECT_EQ(s32->Convert(TUPLE).status().code(),
   1217             tensorflow::error::INVALID_ARGUMENT);
   1218   EXPECT_EQ(s32->Convert(S16).status().code(),
   1219             tensorflow::error::INVALID_ARGUMENT);
   1220   EXPECT_EQ(s32->Convert(U16).status().code(),
   1221             tensorflow::error::INVALID_ARGUMENT);
   1222   EXPECT_EQ(c64->Convert(F32).status().code(),
   1223             tensorflow::error::INVALID_ARGUMENT);
   1224   EXPECT_EQ(c64->Convert(S32).status().code(),
   1225             tensorflow::error::INVALID_ARGUMENT);
   1226 }
   1227 
   1228 TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
   1229   LiteralProto p;
   1230   p.mutable_shape()->set_element_type(PRED);
   1231   for (int len = 0; len < 25; ++len) {
   1232     p.mutable_shape()->clear_dimensions();
   1233     p.mutable_shape()->add_dimensions(len);
   1234     LayoutUtil::SetToDefaultLayout(p.mutable_shape());
   1235     p.clear_preds();
   1236     for (int i = 0; i < len; ++i) {
   1237       p.add_preds((i % 2) == (len % 2));
   1238     }
   1239 
   1240     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
   1241                             Literal::CreateFromProto(p));
   1242     ASSERT_EQ(len, literal->data<bool>().size());
   1243     int i = 0;
   1244     for (bool value : literal->data<bool>()) {
   1245       EXPECT_EQ((i % 2) == (len % 2), value);
   1246       ++i;
   1247     }
   1248   }
   1249 }
   1250 
   1251 // Note that f16 is currently stored in a byte array in little endian byte order
   1252 TEST_F(LiteralUtilTest, ToProto_f16) {
   1253   half h1(1.0f);
   1254   half h2(2.0f);
   1255 
   1256   auto m = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
   1257   Literal* l = m.get();
   1258   EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
   1259   EXPECT_EQ(4, l->data<half>().size());
   1260 
   1261   LiteralProto p = l->ToProto();
   1262   EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
   1263   EXPECT_EQ(8, p.f16s().size());
   1264   const char* d = p.f16s().data();
   1265   EXPECT_EQ(d[0], 0);
   1266   EXPECT_EQ(d[1], 0x3C);
   1267   EXPECT_EQ(d[2], 0);
   1268   EXPECT_EQ(d[3], 0x40);
   1269   EXPECT_EQ(d[4], 0);
   1270   EXPECT_EQ(d[5], 0x40);
   1271   EXPECT_EQ(d[6], 0);
   1272   EXPECT_EQ(d[7], 0x3C);
   1273 }
   1274 
   1275 // Note that f16 is currently stored in a byte array in little endian byte order
   1276 TEST_F(LiteralUtilTest, CopyFromProto_f16) {
   1277   half h1(1.0f);
   1278   half h2(2.0f);
   1279 
   1280   const char half_vals[8] = {0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C};
   1281   LiteralProto p;
   1282   p.mutable_shape()->set_element_type(F16);
   1283   p.mutable_shape()->clear_dimensions();
   1284   p.mutable_shape()->add_dimensions(4);
   1285   LayoutUtil::SetToDefaultLayout(p.mutable_shape());
   1286   p.clear_f16s();
   1287   p.set_f16s(half_vals, 8);
   1288   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
   1289                           Literal::CreateFromProto(p));
   1290   auto r = literal->data<half>();
   1291   ASSERT_EQ(4, r.size());
   1292   ASSERT_EQ(h1, r[0]);
   1293   ASSERT_EQ(h2, r[1]);
   1294   ASSERT_EQ(h2, r[2]);
   1295   ASSERT_EQ(h1, r[3]);
   1296 }
   1297 
   1298 TEST_F(LiteralUtilTest, LiteralViewTest) {
   1299   auto scalar = Literal::CreateR0<float>(1.0);
   1300   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1301   auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
   1302   auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
   1303   Literal nil(ShapeUtil::MakeNil());
   1304 
   1305   EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar);
   1306   EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix);
   1307   EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple);
   1308   EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple);
   1309   EXPECT_EQ(LiteralView::Create(nil, {}), nil);
   1310 
   1311   EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar);
   1312   EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix);
   1313 
   1314   EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple);
   1315   EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar);
   1316   EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix);
   1317   EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar);
   1318 }
   1319 
   1320 TEST_F(LiteralUtilTest, MutatingLiteralView) {
   1321   auto scalar = Literal::CreateR0<float>(1.0);
   1322   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1323   auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
   1324   auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
   1325   // Verify that changing the underlying data beneath the view changes the
   1326   // data of the view itself.
   1327   const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
   1328   EXPECT_EQ(
   1329       nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
   1330       1.0f);
   1331   EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
   1332                                          /*shape_index=*/{0, 0}),
   1333             1.0f);
   1334   nested_tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
   1335   EXPECT_EQ(
   1336       nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
   1337       555.0f);
   1338   EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
   1339                                          /*shape_index=*/{0, 0}),
   1340             555.0f);
   1341 }
   1342 
   1343 TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) {
   1344   auto scalar = Literal::CreateR0<float>(1.0);
   1345   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1346   auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
   1347   auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
   1348 
   1349   const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
   1350   const auto tuple_view =
   1351       LiteralView::Create(nested_tuple_view, /*view_root=*/{0});
   1352   const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1});
   1353   EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
   1354 }
   1355 
   1356 TEST_F(LiteralUtilTest, LiteralMove) {
   1357   std::unique_ptr<Literal> matrix =
   1358       Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1359   Literal literal(std::move(*matrix));
   1360 
   1361   EXPECT_TRUE(
   1362       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
   1363   EXPECT_EQ(literal.Get<float>({0, 0}), 1.0);
   1364   EXPECT_EQ(literal.Get<float>({0, 1}), 2.0);
   1365   EXPECT_EQ(literal.Get<float>({1, 0}), 3.0);
   1366   EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
   1367 }
   1368 
   1369 TEST_F(LiteralUtilTest, DecomposeTuple) {
   1370   Literal nil_literal(ShapeUtil::MakeNil());
   1371   auto nested_tuple = Literal::MakeTuple(
   1372       {Literal::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
   1373        Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
   1374                            Literal::CreateR1<double>({23.0, 44.0}).get(),
   1375                            &nil_literal})
   1376            .get(),
   1377        &nil_literal});
   1378 
   1379   EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape()));
   1380   std::vector<Literal> elements = nested_tuple->DecomposeTuple();
   1381   EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape()));
   1382 
   1383   ASSERT_EQ(elements.size(), 3);
   1384 
   1385   EXPECT_TRUE(ShapeUtil::Compatible(elements[0].shape(),
   1386                                     ShapeUtil::MakeShape(S32, {2, 2})));
   1387   EXPECT_EQ(elements[0].Get<int32>({0, 0}), 1);
   1388   EXPECT_EQ(elements[0].Get<int32>({0, 1}), 2);
   1389   EXPECT_EQ(elements[0].Get<int32>({1, 0}), 3);
   1390   EXPECT_EQ(elements[0].Get<int32>({1, 1}), 4);
   1391 
   1392   EXPECT_TRUE(ShapeUtil::Compatible(
   1393       elements[1].shape(),
   1394       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}),
   1395                                  ShapeUtil::MakeShape(F64, {2}),
   1396                                  ShapeUtil::MakeNil()})));
   1397   EXPECT_EQ(elements[1].Get<int32>({}, /*shape_index=*/{0}), 42);
   1398   EXPECT_EQ(elements[1].Get<double>({0}, /*shape_index=*/{1}), 23.0);
   1399   EXPECT_EQ(elements[1].Get<double>({1}, /*shape_index=*/{1}), 44.0);
   1400 
   1401   EXPECT_TRUE(ShapeUtil::Compatible(elements[2].shape(), ShapeUtil::MakeNil()));
   1402 }
   1403 
   1404 TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
   1405   Literal nil_literal(ShapeUtil::MakeNil());
   1406   std::vector<Literal> elements = nil_literal.DecomposeTuple();
   1407   EXPECT_EQ(elements.size(), 0);
   1408 }
   1409 
   1410 TEST_F(LiteralUtilTest, MoveIntoTuple) {
   1411   std::vector<Literal> elements;
   1412   elements.push_back(std::move(*Literal::CreateR0<float>(1.0)));
   1413   elements.push_back(std::move(*Literal::CreateR1<int32>({4, 8})));
   1414   elements.push_back(std::move(
   1415       *Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
   1416                            Literal::CreateR1<double>({23.0, 44.0}).get()})
   1417 
   1418           ));
   1419 
   1420   Literal literal = Literal::MoveIntoTuple(&elements);
   1421   ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
   1422   ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3);
   1423 
   1424   EXPECT_EQ(literal.Get<float>({}, /*shape_index=*/{0}), 1.0);
   1425   EXPECT_EQ(literal.Get<int32>({0}, /*shape_index=*/{1}), 4);
   1426   EXPECT_EQ(literal.Get<int32>({1}, /*shape_index=*/{1}), 8);
   1427   EXPECT_EQ(literal.Get<int32>({}, /*shape_index=*/{2, 0}), 42);
   1428   EXPECT_EQ(literal.Get<double>({0}, /*shape_index=*/{2, 1}), 23.0);
   1429   EXPECT_EQ(literal.Get<double>({1}, /*shape_index=*/{2, 1}), 44.0);
   1430 
   1431   for (const Literal& element : elements) {
   1432     EXPECT_TRUE(ShapeUtil::IsNil(element.shape()));
   1433   }
   1434 }
   1435 
   1436 TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) {
   1437   Literal literal = Literal::MoveIntoTuple({});
   1438   ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
   1439   ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0);
   1440 }
   1441 
   1442 TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
   1443   Literal literal;
   1444   EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
   1445 
   1446   std::unique_ptr<Literal> matrix =
   1447       Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1448   literal = std::move(*matrix);
   1449 
   1450   EXPECT_TRUE(
   1451       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
   1452   EXPECT_EQ(literal.Get<float>({0, 0}), 1.0);
   1453   EXPECT_EQ(literal.Get<float>({0, 1}), 2.0);
   1454   EXPECT_EQ(literal.Get<float>({1, 0}), 3.0);
   1455   EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
   1456 }
   1457 
   1458 TEST_F(LiteralUtilTest, LiteralViewCopy) {
   1459   std::unique_ptr<Literal> matrix =
   1460       Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   1461   const auto matrix_view = LiteralView::Create(*matrix);
   1462   LiteralView matrix_view_copy(matrix_view);
   1463 
   1464   EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
   1465   EXPECT_EQ(matrix_view_copy.Get<float>({0, 1}), 2.0);
   1466   EXPECT_EQ(matrix_view_copy.Get<float>({1, 0}), 3.0);
   1467   EXPECT_EQ(matrix_view_copy.Get<float>({1, 1}), 4.0);
   1468 }
   1469 
   1470 TEST_F(LiteralUtilTest, GetSetTuple) {
   1471   auto tuple = Literal::MakeTuple(
   1472       {Literal::CreateR0<float>(42.0).get(),
   1473        Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
   1474   EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
   1475   tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
   1476   EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
   1477 
   1478   EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
   1479             3.0);
   1480   tuple->Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
   1481   EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
   1482             -4.0);
   1483 }
   1484 
   1485 TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
   1486   // Literals constructed using CreateFromShape should be zero initialized.
   1487   std::unique_ptr<Literal> scalar_f32 =
   1488       Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
   1489   EXPECT_EQ(scalar_f32->Get<float>({}), 0.0);
   1490   EXPECT_TRUE(scalar_f32->IsAll(0));
   1491 
   1492   std::unique_ptr<Literal> vector_s32 =
   1493       Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
   1494   EXPECT_EQ(vector_s32->Get<int32>({0}), 0);
   1495   EXPECT_EQ(vector_s32->Get<int32>({1}), 0);
   1496   EXPECT_EQ(vector_s32->Get<int32>({2}), 0);
   1497   EXPECT_TRUE(vector_s32->IsAll(0));
   1498 
   1499   std::unique_ptr<Literal> tuple =
   1500       Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
   1501           {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
   1502            ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
   1503 
   1504   EXPECT_EQ(tuple->Get<double>({}, {0}), 0.0);
   1505   EXPECT_EQ(tuple->Get<bool>({0}, {1}), false);
   1506   EXPECT_EQ(tuple->Get<bool>({1}, {1}), false);
   1507   EXPECT_EQ(tuple->Get<uint64>({0, 0}, {2}), 0);
   1508   EXPECT_EQ(tuple->Get<uint64>({1, 0}, {2}), 0);
   1509   EXPECT_EQ(tuple->Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
   1510 }
   1511 
   1512 TEST_F(LiteralUtilTest, ProtoRoundTrip) {
   1513   // Test serializing then deserializing a Literal through a proto.
   1514   auto one_f32 = Literal::CreateR0<float>(1.0);
   1515   auto two_f32 = Literal::CreateR0<float>(2.0);
   1516   auto vector_int8 = Literal::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
   1517   auto vector_c64 = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
   1518   auto vector_bfloat16 = Literal::CreateR1<bfloat16>(
   1519       {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
   1520   auto vector_half =
   1521       Literal::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
   1522   auto matrix_pred =
   1523       Literal::CreateR2<bool>({{true, false, true}, {false, false, true}});
   1524   auto tuple = Literal::MakeTuple(
   1525       {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
   1526   Literal nil_literal(ShapeUtil::MakeNil());
   1527   auto nested_tuple = Literal::MakeTuple(
   1528       {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
   1529 
   1530   auto to_from_proto = [](const Literal& literal) -> Literal {
   1531     return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie());
   1532   };
   1533 
   1534   EXPECT_EQ(*one_f32, to_from_proto(*one_f32));
   1535   EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64));
   1536   EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16));
   1537   EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred));
   1538   EXPECT_EQ(*tuple, to_from_proto(*tuple));
   1539   EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple));
   1540   EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
   1541 
   1542   EXPECT_NE(*one_f32, *two_f32);
   1543   EXPECT_NE(*one_f32, to_from_proto(*two_f32));
   1544 }
   1545 
   1546 TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
   1547   // Proto contains a shape, but no values.
   1548   LiteralProto proto;
   1549   *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
   1550   Status status = Literal::CreateFromProto(proto).status();
   1551   ASSERT_FALSE(status.ok());
   1552   ASSERT_THAT(status.error_message(),
   1553               HasSubstr("Expected 3 elements in LiteralProto"));
   1554 }
   1555 
   1556 TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
   1557   // Proto contains values, but no shape.
   1558   LiteralProto proto;
   1559   proto.add_preds(false);
   1560   proto.add_preds(true);
   1561   proto.add_preds(false);
   1562   Status status = Literal::CreateFromProto(proto).status();
   1563   ASSERT_FALSE(status.ok());
   1564   ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape"));
   1565 }
   1566 
   1567 TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
   1568   // Proto contains values in wrong container.
   1569   LiteralProto proto;
   1570   *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
   1571   proto.add_preds(false);
   1572   proto.add_preds(true);
   1573   proto.add_preds(false);
   1574   Status status = Literal::CreateFromProto(proto).status();
   1575   ASSERT_FALSE(status.ok());
   1576   ASSERT_THAT(status.error_message(),
   1577               HasSubstr("Expected 3 elements in LiteralProto"));
   1578 }
   1579 
   1580 TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
   1581   // Proto contains too few values.
   1582   LiteralProto proto;
   1583   *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2});
   1584   proto.add_f32s(1.0);
   1585   proto.add_f32s(2.0);
   1586   proto.add_f32s(3.0);
   1587   Status status = Literal::CreateFromProto(proto).status();
   1588   ASSERT_FALSE(status.ok());
   1589   ASSERT_THAT(status.error_message(),
   1590               HasSubstr("Expected 84 elements in LiteralProto"));
   1591 }
   1592 
   1593 TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
   1594   // Proto contains too many values.
   1595   LiteralProto proto;
   1596   *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2});
   1597   proto.add_s32s(42);
   1598   proto.add_s32s(-10);
   1599   proto.add_s32s(100);
   1600   Status status = Literal::CreateFromProto(proto).status();
   1601   ASSERT_FALSE(status.ok());
   1602   ASSERT_THAT(status.error_message(),
   1603               HasSubstr("Expected 2 elements in LiteralProto"));
   1604 }
   1605 
   1606 TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
   1607   // Proto shape missing layout.
   1608   LiteralProto proto;
   1609   *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2});
   1610   LayoutUtil::ClearLayout(proto.mutable_shape());
   1611   proto.add_preds(true);
   1612   proto.add_preds(false);
   1613   proto.add_preds(true);
   1614   proto.add_preds(false);
   1615   Status status = Literal::CreateFromProto(proto).status();
   1616   ASSERT_FALSE(status.ok());
   1617   ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout"));
   1618 }
   1619 
   1620 TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
   1621   // Proto has the too few tuple elements.
   1622   LiteralProto proto;
   1623   *proto.mutable_shape() = ShapeUtil::MakeTupleShape(
   1624       {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
   1625   LiteralProto* element0 = proto.add_tuple_literals();
   1626   *element0->mutable_shape() =
   1627       ShapeUtil::GetTupleElementShape(proto.shape(), 0);
   1628   element0->add_preds(false);
   1629   element0->add_preds(true);
   1630 
   1631   Status status = Literal::CreateFromProto(proto).status();
   1632   ASSERT_FALSE(status.ok());
   1633   ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
   1634 }
   1635 
   1636 TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
   1637   // Proto has the too many tuple elements.
   1638   LiteralProto proto;
   1639   *proto.mutable_shape() = ShapeUtil::MakeTupleShape(
   1640       {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
   1641   LiteralProto* element0 = proto.add_tuple_literals();
   1642   *element0->mutable_shape() =
   1643       ShapeUtil::GetTupleElementShape(proto.shape(), 0);
   1644   element0->add_preds(false);
   1645   element0->add_preds(true);
   1646   LiteralProto* element1 = proto.add_tuple_literals();
   1647   *element1->mutable_shape() =
   1648       ShapeUtil::GetTupleElementShape(proto.shape(), 1);
   1649   element1->add_f32s(42.0);
   1650   LiteralProto* element2 = proto.add_tuple_literals();
   1651   *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {});
   1652   element2->add_f32s(123.0);
   1653 
   1654   Status status = Literal::CreateFromProto(proto).status();
   1655   ASSERT_FALSE(status.ok());
   1656   ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
   1657 }
   1658 
   1659 TEST_F(LiteralUtilTest, SortSparseElements) {
   1660   auto literal =
   1661       Literal::CreateSparse<float>({10, 10, 10}, SparseIndexArray(10, 3), {});
   1662   literal->AppendSparseElement<float>({2, 3, 4}, 2.0);
   1663   literal->AppendSparseElement<float>({3, 4, 5}, 3.0);
   1664   literal->AppendSparseElement<float>({1, 2, 3}, 1.0);
   1665   literal->SortSparseElements();
   1666   ASSERT_EQ(literal->ToString(false),
   1667             "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
   1668 }
   1669 
   1670 TEST_F(LiteralUtilTest, GetSparseElementAsString) {
   1671   std::vector<int64> dimensions = {10, 10, 10};
   1672   SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}});
   1673 
   1674   ASSERT_EQ(
   1675       Literal::CreateSparse<bool>(dimensions, indices, {true, false, true})
   1676           ->GetSparseElementAsString(1),
   1677       "false");
   1678   ASSERT_EQ(Literal::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
   1679                 ->GetSparseElementAsString(1),
   1680             tensorflow::strings::StrCat(int64{2}));
   1681   ASSERT_EQ(Literal::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
   1682                 ->GetSparseElementAsString(1),
   1683             tensorflow::strings::StrCat(double{2.0}));
   1684   ASSERT_EQ(Literal::CreateSparse<half>(dimensions, indices,
   1685                                         {half{1.0}, half{2.0}, half{3.0}})
   1686                 ->GetSparseElementAsString(1),
   1687             tensorflow::strings::StrCat(half{2.0}));
   1688   ASSERT_EQ(
   1689       Literal::CreateSparse<complex64>(
   1690           dimensions, indices,
   1691           std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
   1692           ->GetSparseElementAsString(1),
   1693       tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
   1694 }
   1695 
   1696 }  // namespace
   1697 }  // namespace xla
   1698