1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/literal_util.h" 17 18 #include <vector> 19 20 #include "tensorflow/compiler/xla/array3d.h" 21 #include "tensorflow/compiler/xla/array4d.h" 22 #include "tensorflow/compiler/xla/layout_util.h" 23 #include "tensorflow/compiler/xla/shape_util.h" 24 #include "tensorflow/compiler/xla/test.h" 25 #include "tensorflow/compiler/xla/types.h" 26 #include "tensorflow/core/lib/core/status_test_util.h" 27 #include "tensorflow/core/platform/macros.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace xla { 31 namespace { 32 33 using ::testing::ElementsAre; 34 using ::testing::HasSubstr; 35 36 class LiteralUtilTest : public ::testing::Test { 37 protected: 38 LiteralUtilTest() { 39 Array4D<float> arr4d({ 40 // clang-format off 41 { // i0=0 42 { // i1=0 43 {1, 2, 3}, // i2=0 44 {4, 5, 6}, // i2=1 45 {7, 8, 9}, // i2=2 46 }, 47 { // i1=1 48 {11, 12, 13}, 49 {14, 15, 16}, 50 {17, 18, 19}, 51 }, 52 }, 53 { // i0=1 54 { // i1=0 55 {101, 102, 103}, 56 {104, 105, 106}, 57 {107, 108, 109}, 58 }, 59 { // i1=1 60 {201, 202, 203}, // i2=0 61 {204, 205, 206}, // i2=1 62 {207, 208, 209}, // i2=2 63 }, 64 }, 65 // clang-format on 66 }); 67 68 layout_r2_dim0major_ = LayoutUtil::MakeLayout({1, 0}); 69 layout_r2_dim0minor_ = LayoutUtil::MakeLayout({0, 1}); 70 layout_r3_dim0major_ = LayoutUtil::MakeLayout({2, 1, 0}); 71 layout_r3_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2}); 72 layout_r4_dim0major_ = LayoutUtil::MakeLayout({3, 2, 1, 0}); 73 layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3}); 74 75 literal_r4_2x2x3x3_dim0major_ = 76 Literal::CreateR4FromArray4DWithLayout<float>(arr4d, 77 layout_r4_dim0major_); 78 literal_r4_2x2x3x3_dim0minor_ = 79 Literal::CreateR4FromArray4DWithLayout<float>(arr4d, 80 layout_r4_dim0minor_); 81 } 82 83 Layout layout_r2_dim0major_; 84 Layout layout_r2_dim0minor_; 85 Layout layout_r3_dim0major_; 86 Layout layout_r3_dim0minor_; 87 Layout layout_r4_dim0major_; 88 Layout layout_r4_dim0minor_; 89 std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0major_; 90 std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0minor_; 91 }; 92 93 TEST_F(LiteralUtilTest, LiteralScalarToString) { 94 auto true_lit = Literal::CreateR0<bool>(true); 95 ASSERT_EQ("true", true_lit->ToString()); 96 97 auto false_lit = Literal::CreateR0<bool>(false); 98 ASSERT_EQ("false", false_lit->ToString()); 99 100 auto u32_lit = Literal::CreateR0<uint32>(42); 101 ASSERT_EQ("42", u32_lit->ToString()); 102 103 auto s32_lit = Literal::CreateR0<int32>(-999); 104 ASSERT_EQ("-999", s32_lit->ToString()); 105 106 auto f32_lit = Literal::CreateR0<float>(3.14f); 107 ASSERT_EQ("3.14", f32_lit->ToString()); 108 109 auto f16_lit = Literal::CreateR0<half>(static_cast<half>(0.5f)); 110 ASSERT_EQ("0.5", f16_lit->ToString()); 111 112 auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f}); 113 ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); 114 115 auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f)); 116 ASSERT_EQ("0.5", bf16_lit->ToString()); 117 118 // 3.14 will be truncated to 3.125 in bfloat16 format. 119 auto bf16_lit_truncated = 120 Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f)); 121 ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); 122 123 auto bf16_lit_truncated2 = 124 Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f)); 125 ASSERT_EQ("9", bf16_lit_truncated2->ToString()); 126 } 127 128 TEST_F(LiteralUtilTest, LiteralVectorToString) { 129 auto pred_vec = Literal::CreateR1<bool>({true, false, true}); 130 ASSERT_EQ("{101}", pred_vec->ToString()); 131 } 132 133 TEST_F(LiteralUtilTest, R2ToString) { 134 const auto literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); 135 const string expected = R"(s32[3,2] { 136 { 1, 2 }, 137 { 3, 4 }, 138 { 5, 6 } 139 })"; 140 ASSERT_EQ(expected, literal->ToString()); 141 } 142 143 TEST_F(LiteralUtilTest, R3ToString) { 144 const auto literal = Literal::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); 145 const string expected = R"(s32[3,2,1] { 146 { { 1 }, 147 { 2 } }, 148 { { 3 }, 149 { 4 } }, 150 { { 5 }, 151 { 6 } } 152 })"; 153 ASSERT_EQ(expected, literal->ToString()); 154 } 155 156 TEST_F(LiteralUtilTest, TupleToString) { 157 auto scalar = Literal::CreateR0<float>(1.0); 158 auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 159 auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); 160 const string expected = R"((f32[], f32[2,2]) ( 161 1, 162 f32[2,2] { 163 { 1, 2 }, 164 { 3, 4 } 165 } 166 ))"; 167 ASSERT_EQ(expected, tuple->ToString()); 168 } 169 170 TEST_F(LiteralUtilTest, CreateR3FromArray3d) { 171 // clang-format off 172 Array3D<float> array_3d({ 173 {{1.0f, 2.0f}, 174 {3.0f, 4.0f}, 175 {5.0f, 6.0f}}, 176 {{7.0f, 8.0f}, 177 {9.0f, 10.0f}, 178 {11.0f, 12.0f}}, 179 }); 180 // clang-format on 181 182 auto literal = Literal::CreateR3FromArray3D(array_3d); 183 EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); 184 string result = literal->ToString(); 185 const string expected = R"(f32[2,3,2] { 186 { { 1, 2 }, 187 { 3, 4 }, 188 { 5, 6 } }, 189 { { 7, 8 }, 190 { 9, 10 }, 191 { 11, 12 } } 192 })"; 193 ASSERT_EQ(expected, result); 194 } 195 196 TEST_F(LiteralUtilTest, CreateSparse) { 197 std::vector<int64> dimensions = {8, 8, 8}; 198 Array2D<int64> indices = { 199 {3, 4, 5}, 200 {1, 2, 3}, 201 {2, 3, 4}, 202 {3, 5, 6}, 203 }; 204 std::vector<int64> values = {7, 8, 9, 10}; 205 auto literal = Literal::CreateSparse<int64>( 206 dimensions, SparseIndexArray(indices.n1() + 3, indices), values); 207 208 Array2D<int64> expected_indices = { 209 {1, 2, 3}, 210 {2, 3, 4}, 211 {3, 4, 5}, 212 {3, 5, 6}, 213 }; 214 std::vector<int64> expected_values = {8, 9, 7, 10}; 215 216 EXPECT_EQ(literal->sparse_indices()->data(), 217 tensorflow::gtl::ArraySlice<int64>( 218 expected_indices.data(), expected_indices.num_elements())); 219 EXPECT_EQ(tensorflow::gtl::ArraySlice<int64>(literal->data<int64>().data(), 220 expected_values.size()), 221 tensorflow::gtl::ArraySlice<int64>(expected_values)); 222 } 223 224 TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { 225 // clang-format off 226 auto literal = Literal::CreateR4Projected<float>({ 227 {1, 2}, 228 {1001, 1002}, 229 {2001, 2002}, 230 }, /*projection_p=*/1, /*projection_z=*/2); 231 // clang-format on 232 EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); 233 string result = literal->ToString(); 234 const string expected = R"(f32[1,2,3,2] { 235 { /*i0=0*/ 236 { /*i1=0*/ 237 {1, 2}, 238 {1001, 1002}, 239 {2001, 2002} 240 }, 241 { /*i1=1*/ 242 {1, 2}, 243 {1001, 1002}, 244 {2001, 2002} 245 } 246 } 247 })"; 248 ASSERT_EQ(expected, result); 249 } 250 251 TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { 252 EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), 253 ElementsAre(2, 2, 3, 3)); 254 string result = literal_r4_2x2x3x3_dim0major_->ToString(); 255 const string expected = R"(f32[2,2,3,3] { 256 { /*i0=0*/ 257 { /*i1=0*/ 258 {1, 2, 3}, 259 {4, 5, 6}, 260 {7, 8, 9} 261 }, 262 { /*i1=1*/ 263 {11, 12, 13}, 264 {14, 15, 16}, 265 {17, 18, 19} 266 } 267 }, 268 { /*i0=1*/ 269 { /*i1=0*/ 270 {101, 102, 103}, 271 {104, 105, 106}, 272 {107, 108, 109} 273 }, 274 { /*i1=1*/ 275 {201, 202, 203}, 276 {204, 205, 206}, 277 {207, 208, 209} 278 } 279 } 280 })"; 281 ASSERT_EQ(expected, result); 282 } 283 284 TEST_F(LiteralUtilTest, EachCellR2F32) { 285 // clang-format off 286 auto literal = Literal::CreateR2<float>({ 287 {3.1f, 4.2f}, 288 {9.3f, 12.4f}, 289 }); 290 // clang-format on 291 std::vector<std::tuple<int64, int64, string>> seen; 292 literal->EachCellAsString( 293 [&seen](tensorflow::gtl::ArraySlice<int64> indices, const string& value) { 294 seen.emplace_back(indices[0], indices[1], value); 295 }); 296 297 using Elem = std::tuple<int64, int64, string>; 298 std::vector<Elem> expected = {Elem(0, 0, "3.1"), Elem(0, 1, "4.2"), 299 Elem(1, 0, "9.3"), Elem(1, 1, "12.4")}; 300 EXPECT_EQ(expected, seen); 301 } 302 303 TEST_F(LiteralUtilTest, ScalarEquality) { 304 // Test equality with scalars. 305 auto f32_42 = Literal::CreateR0<float>(42.0); 306 auto f32_42_clone = Literal::CreateR0<float>(42.0); 307 308 EXPECT_EQ(*f32_42, *f32_42); 309 EXPECT_EQ(*f32_42, *f32_42_clone); 310 311 auto f32_123 = Literal::CreateR0<float>(123.0); 312 EXPECT_NE(*f32_42, *f32_123); 313 314 auto f64_42 = Literal::CreateR0<double>(42.0); 315 EXPECT_NE(*f32_42, *f64_42); 316 } 317 318 TEST_F(LiteralUtilTest, NonScalarEquality) { 319 // Test equality with nonscalars. 320 auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 321 auto matrix_clone = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 322 auto matrix_different = Literal::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}}); 323 auto vector_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0}); 324 auto scalar = Literal::CreateR0<float>(1.0); 325 Literal nil(ShapeUtil::MakeNil()); 326 327 EXPECT_EQ(*matrix, *matrix); 328 EXPECT_EQ(*matrix, *matrix_clone); 329 EXPECT_NE(*matrix, *matrix_different); 330 EXPECT_NE(*matrix, *vector_literal); 331 EXPECT_NE(*matrix, *scalar); 332 EXPECT_NE(*matrix, nil); 333 EXPECT_EQ(nil, nil); 334 } 335 336 TEST_F(LiteralUtilTest, DifferentLayoutEquality) { 337 // Test equality with literals which have different layouts. 338 auto colmajor = 339 MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); 340 colmajor->Set<float>({0, 0}, 1.0); 341 colmajor->Set<float>({0, 1}, 2.0); 342 colmajor->Set<float>({1, 0}, 3.0); 343 colmajor->Set<float>({1, 1}, 4.0); 344 345 auto rowmajor = 346 MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); 347 rowmajor->Set<float>({0, 0}, 1.0); 348 rowmajor->Set<float>({0, 1}, 2.0); 349 rowmajor->Set<float>({1, 0}, 3.0); 350 rowmajor->Set<float>({1, 1}, 4.0); 351 352 EXPECT_EQ(*rowmajor, *colmajor); 353 } 354 355 TEST_F(LiteralUtilTest, TupleEquality) { 356 // Test equality with tuples. 357 auto scalar = Literal::CreateR0<float>(1.0); 358 auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 359 auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()}); 360 361 // Tuple with the same elements. One element is shared with the original 362 // tuple, the other is a clone of the element in the original tuple. 363 auto scalar_clone = Literal::CreateR0<float>(1.0); 364 auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()}); 365 EXPECT_EQ(*tuple1, *tuple2); 366 367 // Tuple with elements reversed. 368 auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()}); 369 EXPECT_NE(*tuple1, *reversed_tuple); 370 371 // Tuple with different value. 372 auto scalar_42 = Literal::CreateR0<float>(42.0); 373 auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()}); 374 EXPECT_NE(*tuple1, *different_tuple); 375 } 376 377 TEST_F(LiteralUtilTest, C64Equality) { 378 // Test equality with tuples. 379 auto vector = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}}); 380 381 // Tuple with the same elements. One element is shared with the original 382 // tuple, the other is a clone of the element in the original tuple. 383 auto vector_clone = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}}); 384 EXPECT_EQ(*vector, *vector_clone); 385 386 auto vector_reversed = Literal::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}}); 387 EXPECT_NE(*vector, *vector_reversed); 388 } 389 390 TEST_F(LiteralUtilTest, IsAllTuple) { 391 auto element1 = Literal::CreateR0<float>(0.0); 392 auto element2 = Literal::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}}); 393 auto tuple = Literal::MakeTuple({element1.get(), element1.get()}); 394 395 // Tuples should always return false for IsAll. 396 EXPECT_FALSE(tuple->IsAll(0)); 397 EXPECT_FALSE(tuple->IsAll(1)); 398 } 399 400 // Verifies that CreateFromShape works for tuples. 401 TEST_F(LiteralUtilTest, CreateFromShapeTuple) { 402 auto scalar = Literal::CreateR0<float>(0.0); 403 auto matrix = Literal::CreateR2<int32>({{0, 0}, {0, 0}}); 404 auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); 405 406 auto x = Literal::CreateFromShape(tuple->shape()); 407 EXPECT_EQ(*tuple, *x); 408 } 409 410 TEST_F(LiteralUtilTest, IsAll) { 411 EXPECT_TRUE(Literal::CreateR0<bool>(false)->IsAll(0)); 412 EXPECT_TRUE(Literal::CreateR0<bool>(true)->IsAll(1)); 413 EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAll(1)); 414 EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAll(2)); 415 EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(0)); 416 EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(2)); 417 EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(-1)); 418 419 // We shouldn't reinterpret int8_min as an unsigned type and then decide that 420 // it is equal to 255. 421 auto int8_min = std::numeric_limits<int8>::min(); 422 EXPECT_FALSE(Literal::CreateR0<uint8>(255)->IsAll(int8_min)); 423 424 EXPECT_TRUE(Literal::CreateR0<float>(42.0)->IsAll(42)); 425 EXPECT_FALSE(Literal::CreateR0<float>(42.0001)->IsAll(42)); 426 427 EXPECT_TRUE(Literal::CreateR1<int>({100, 100, 100})->IsAll(100)); 428 EXPECT_FALSE(Literal::CreateR1<double>({100, 100, 100.001})->IsAll(100)); 429 430 EXPECT_TRUE(Literal::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8)); 431 EXPECT_FALSE(Literal::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8)); 432 EXPECT_FALSE(Literal::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8)); 433 434 half h8(8.0f); 435 half h9(9.0f); 436 EXPECT_TRUE(Literal::CreateR2<half>({{h8}, {h8}})->IsAll(8)); 437 EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8)); 438 EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8)); 439 440 bfloat16 b8(8.0f); 441 bfloat16 b9(9.0f); 442 443 EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8)); 444 EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8)); 445 EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8)); 446 447 // 9.001 will be truncated to 9.0 448 bfloat16 b91(9.001f); 449 bfloat16 b90(9.00f); 450 EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0)); 451 452 complex64 c8_9 = {8, 9}; 453 EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8)); 454 455 auto uint64_max = std::numeric_limits<uint64>::max(); 456 EXPECT_FALSE(Literal::CreateR2<uint64>( 457 {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) 458 ->IsAll(-1)); 459 } 460 461 TEST_F(LiteralUtilTest, IsAllFloat) { 462 // IsAllFloat always returns false when the literal is not floating-point. 463 EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllFloat(0)); 464 EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllFloat(0)); 465 EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllFloat(0)); 466 EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllFloat(0)); 467 468 EXPECT_TRUE(Literal::CreateR0<float>(0)->IsAllFloat(0)); 469 EXPECT_TRUE(Literal::CreateR0<float>(.5)->IsAllFloat(.5)); 470 EXPECT_TRUE(Literal::CreateR0<float>(-.5)->IsAllFloat(-.5)); 471 EXPECT_FALSE(Literal::CreateR0<float>(-.5)->IsAllFloat(-.49)); 472 EXPECT_FALSE( 473 Literal::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); 474 EXPECT_TRUE( 475 Literal::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})->IsAllFloat(.5)); 476 477 EXPECT_TRUE(Literal::CreateR0<double>(0)->IsAllFloat(0)); 478 EXPECT_TRUE(Literal::CreateR0<double>(.5)->IsAllFloat(.5)); 479 EXPECT_TRUE(Literal::CreateR0<double>(-.5)->IsAllFloat(-.5)); 480 EXPECT_FALSE(Literal::CreateR0<double>(-.5)->IsAllFloat(-.49)); 481 EXPECT_FALSE( 482 Literal::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); 483 } 484 485 TEST_F(LiteralUtilTest, IsAllComplex) { 486 // IsAllComplex always returns false when the literal is not complex. 487 EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllComplex(0)); 488 EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllComplex(0)); 489 EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllComplex(0)); 490 EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllComplex(0)); 491 EXPECT_FALSE(Literal::CreateR0<float>(0)->IsAllComplex(0)); 492 EXPECT_FALSE(Literal::CreateR0<double>(0)->IsAllComplex(0)); 493 494 complex64 c8_9 = {8, 9}; 495 complex64 c7_9 = {7, 9}; 496 EXPECT_TRUE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}}) 497 ->IsAllComplex({8.0f, 9.0f})); 498 EXPECT_FALSE(Literal::CreateR2<complex64>({{c7_9}, {c8_9}}) 499 ->IsAllComplex({8.0f, 9.0f})); 500 EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c7_9}}) 501 ->IsAllComplex({8.0f, 9.0f})); 502 } 503 504 TEST_F(LiteralUtilTest, IsZero) { 505 auto scalar_zero = Literal::CreateR0<float>(0.0f); 506 auto scalar_one = Literal::CreateR0<float>(1.0f); 507 EXPECT_TRUE(scalar_zero->IsZero({})); 508 EXPECT_FALSE(scalar_one->IsZero({})); 509 510 auto array = Literal::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}}); 511 EXPECT_FALSE(array->IsZero({0, 1})); 512 EXPECT_TRUE(array->IsZero({0, 2})); 513 EXPECT_TRUE(array->IsZero({1, 1})); 514 EXPECT_FALSE(array->IsZero({1, 2})); 515 516 auto complex_zero = Literal::CreateR0<complex64>(0.0f); 517 auto complex_nonzero = Literal::CreateR0<complex64>(0.5f); 518 EXPECT_TRUE(complex_zero->IsZero({})); 519 EXPECT_FALSE(complex_nonzero->IsZero({})); 520 } 521 522 template <typename T> 523 class LiteralUtilTestTemplated : public ::testing::Test {}; 524 525 using TestedTypes = ::testing::Types<float, int32, uint32, complex64>; 526 TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); 527 528 TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { 529 // Make a non-integer for floating point types. 530 TypeParam half = TypeParam(1) / TypeParam(2); 531 auto data = Literal::CreateR2<TypeParam>({{half, 2}, {3, 4}}); 532 const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); 533 const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); 534 535 auto data01 = data->Relayout(layout01); 536 EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); 537 EXPECT_EQ(*data, *data01); 538 539 auto data10 = data->Relayout(layout10); 540 EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); 541 EXPECT_EQ(*data, *data10); 542 } 543 544 TEST_F(LiteralUtilTest, ReshapeR0) { 545 auto original = Literal::CreateR0<float>(1.7f); 546 auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); 547 EXPECT_EQ(*original, *reshape); 548 } 549 550 TEST_F(LiteralUtilTest, ReshapeR4) { 551 // clang-format off 552 // F32[1x3x2x4] 553 auto original = Literal::CreateR4WithLayout<float>({{ 554 {{10, 11, 12, 13}, {14, 15, 16, 17}}, 555 {{18, 19, 20, 21}, {22, 23, 24, 25}}, 556 {{26, 27, 28, 29}, {30, 31, 32, 33}}, 557 }}, layout_r4_dim0major_); 558 // F32[1x3x4x2] 559 auto expected = Literal::CreateR3WithLayout<float>({ 560 {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, 561 {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, 562 {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, 563 }, layout_r3_dim0major_); 564 // clang-format on 565 auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); 566 567 EXPECT_EQ(*expected, *reshape); 568 } 569 570 TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { 571 // clang-format off 572 // F32[1x3x2x4] 573 auto original = Literal::CreateR4WithLayout<float>({{ 574 {{10, 11, 12, 13}, {14, 15, 16, 17}}, 575 {{18, 19, 20, 21}, {22, 23, 24, 25}}, 576 {{26, 27, 28, 29}, {30, 31, 32, 33}}, 577 }}, layout_r4_dim0minor_); 578 // F32[1x3x4x2] 579 auto expected = Literal::CreateR3WithLayout<float>({ 580 {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, 581 {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, 582 {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, 583 }, layout_r3_dim0major_); 584 // clang-format on 585 auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); 586 587 EXPECT_EQ(*expected, *reshape); 588 } 589 590 TEST_F(LiteralUtilTest, TransposeR0) { 591 auto original = Literal::CreateR0<float>(1.7f); 592 auto reshape = original->Transpose(/*permutation=*/{}); 593 EXPECT_EQ(*original, *reshape); 594 } 595 596 TEST_F(LiteralUtilTest, TransposeR4) { 597 // clang-format off 598 // F32[1x3x2x4] 599 auto original = Literal::CreateR4<float>({{ 600 {{10, 11, 12, 13}, {14, 15, 16, 17}}, 601 {{18, 19, 20, 21}, {22, 23, 24, 25}}, 602 {{26, 27, 28, 29}, {30, 31, 32, 33}}, 603 }}); 604 // clang-format on 605 auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); 606 607 reshape->EachCell<float>( 608 [&](tensorflow::gtl::ArraySlice<int64> indices, float value) { 609 EXPECT_EQ(value, original->Get<float>( 610 {indices[2], indices[3], indices[0], indices[1]})); 611 }); 612 } 613 614 TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { 615 // Tests that using Relayout on an array is equivalent to creating it in the 616 // target layout in the first place. 617 auto dim0minor_relaid_to_dim0major = 618 literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); 619 EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major); 620 621 auto dim0major_relaid_to_dim0minor = 622 literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); 623 EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor); 624 } 625 626 TEST_F(LiteralUtilTest, TestR2LinearLayout) { 627 // Test expected memory layout of R2 dim0-minor (column-major) literal. 628 auto mat_dim0minor = Literal::CreateR2WithLayout<int32>( 629 {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); 630 EXPECT_EQ(mat_dim0minor->element_count(), 6); 631 EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6)); 632 633 // Test expected memory layout when using Relayout to row major. 634 auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); 635 EXPECT_THAT(relaid_mat_to_dim0major->data<int32>(), 636 ElementsAre(1, 2, 3, 4, 5, 6)); 637 638 // Test expected memory layout of R2 created with dim0-major (row-major). 639 auto mat_dim0major = Literal::CreateR2WithLayout<int32>( 640 {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); 641 EXPECT_EQ(mat_dim0major->element_count(), 6); 642 EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6)); 643 644 // Test expected memory layout when using Relayout to column major. 645 auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); 646 EXPECT_THAT(relaid_mat_to_dim0minor->data<int32>(), 647 ElementsAre(1, 4, 2, 5, 3, 6)); 648 } 649 650 TEST_F(LiteralUtilTest, TestR3LinearLayout) { 651 // Test expected memory layout of R3 dim0-minor (column-major) literal. 652 Array3D<int> arr3d( 653 // clang-format off 654 { 655 { 656 {1, 2, 3}, 657 {4, 5, 6}, 658 }, 659 { 660 {7, 8, 9}, 661 {10, 11, 12}, 662 }, 663 }); // clang-format on 664 auto lit_dim0minor = 665 Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0minor_); 666 667 EXPECT_EQ(lit_dim0minor->element_count(), 12); 668 std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; 669 EXPECT_THAT(lit_dim0minor->data<int32>(), 670 testing::ElementsAreArray(expected_dim0minor)); 671 672 // Test expected memory layout when using Relayout to row major. 673 auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); 674 std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; 675 EXPECT_THAT(relaid_lit_to_dim0major->data<int32>(), 676 testing::ElementsAreArray(expected_dim0major)); 677 678 // Test expected memory layout of R3 created with dim0-major (row-major). 679 auto lit_dim0major = 680 Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0major_); 681 EXPECT_EQ(lit_dim0major->element_count(), 12); 682 EXPECT_THAT(lit_dim0major->data<int32>(), 683 testing::ElementsAreArray(expected_dim0major)); 684 685 // Test expected memory layout when using Relayout to column major. 686 auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); 687 EXPECT_THAT(relaid_lit_to_dim0minor->data<int32>(), 688 testing::ElementsAreArray(expected_dim0minor)); 689 } 690 691 TEST_F(LiteralUtilTest, SliceR0S32) { 692 auto input = Literal::CreateR0<int32>(1); 693 auto result = input->Slice({}, {}); 694 EXPECT_EQ(*input, *result); 695 } 696 697 TEST_F(LiteralUtilTest, SliceR1F32) { 698 auto input = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0}); 699 auto result = input->Slice({3}, {4}); 700 auto expected = Literal::CreateR1<float>({4.0}); 701 EXPECT_EQ(*expected, *result); 702 } 703 704 TEST_F(LiteralUtilTest, SliceR2U32) { 705 auto input_3x4 = 706 Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); 707 auto result = input_3x4->Slice({0, 2}, {2, 4}); 708 auto expected = Literal::CreateR2<uint32>({{3, 4}, {7, 8}}); 709 EXPECT_EQ(*expected, *result); 710 } 711 712 TEST_F(LiteralUtilTest, SliceR3U32Full) { 713 auto input_2x3x2 = Literal::CreateR3<uint32>( 714 {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); 715 auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); 716 EXPECT_EQ(*input_2x3x2, *result); 717 } 718 719 TEST_F(LiteralUtilTest, PopulateR1S64) { 720 Literal output(ShapeUtil::MakeShape(S64, {1})); 721 output.PopulateR1<int64>({77}); 722 auto expected = Literal::CreateR1<int64>({77}); 723 EXPECT_EQ(output, *expected); 724 } 725 726 TEST_F(LiteralUtilTest, PopulateR1U64) { 727 Literal output(ShapeUtil::MakeShape(U64, {2})); 728 output.PopulateR1<uint64>({{77, 88}}); 729 auto expected = Literal::CreateR1<uint64>({{77, 88}}); 730 EXPECT_EQ(output, *expected); 731 } 732 733 TEST_F(LiteralUtilTest, PopulateR1C64) { 734 Literal output(ShapeUtil::MakeShape(C64, {1})); 735 output.PopulateR1<complex64>({{77, 88}}); 736 auto expected = Literal::CreateR1<complex64>({{77, 88}}); 737 EXPECT_EQ(output, *expected); 738 } 739 740 TEST_F(LiteralUtilTest, PopulateR2C64) { 741 Literal output(ShapeUtil::MakeShape(C64, {2, 2})); 742 output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); 743 auto expected = 744 Literal::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); 745 EXPECT_EQ(output, *expected); 746 } 747 748 TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { 749 Literal output(ShapeUtil::MakeShape(BF16, {})); 750 bfloat16 h(0.25f); 751 output.PopulateWithValue<bfloat16>(h); 752 auto expected = Literal::CreateR0<bfloat16>(h); 753 EXPECT_EQ(output, *expected); 754 } 755 756 TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { 757 Literal output(ShapeUtil::MakeShape(BF16, {3})); 758 bfloat16 h(0.5f); 759 output.PopulateWithValue<bfloat16>(h); 760 auto expected = Literal::CreateR1<bfloat16>({h, h, h}); 761 EXPECT_EQ(output, *expected); 762 } 763 764 TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { 765 Literal output(ShapeUtil::MakeShape(BF16, {2, 2})); 766 bfloat16 h(2.0f); 767 output.PopulateWithValue<bfloat16>(h); 768 auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}}); 769 EXPECT_EQ(output, *expected); 770 } 771 772 TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { 773 Literal output(ShapeUtil::MakeShape(F32, {})); 774 output.PopulateWithValue<float>(2.5f); 775 auto expected = Literal::CreateR0<float>(2.5f); 776 EXPECT_EQ(output, *expected); 777 } 778 779 TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { 780 Literal output(ShapeUtil::MakeShape(S64, {3})); 781 output.PopulateWithValue<int64>(-7); 782 auto expected = Literal::CreateR1<int64>({-7, -7, -7}); 783 EXPECT_EQ(output, *expected); 784 } 785 786 TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { 787 Literal output(ShapeUtil::MakeShape(U64, {2, 2})); 788 output.PopulateWithValue<uint64>(42); 789 auto expected = Literal::CreateR2<uint64>({{42, 42}, {42, 42}}); 790 EXPECT_EQ(output, *expected); 791 } 792 793 TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { 794 Literal output(ShapeUtil::MakeShape(C64, {2, 2})); 795 output.PopulateWithValue<complex64>({4, 2}); 796 auto expected = 797 Literal::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); 798 EXPECT_EQ(output, *expected); 799 } 800 801 TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { 802 Literal output(ShapeUtil::MakeShape(F16, {})); 803 half h(0.25f); 804 output.PopulateWithValue<half>(h); 805 auto expected = Literal::CreateR0<half>(h); 806 EXPECT_EQ(output, *expected); 807 } 808 809 TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { 810 Literal output(ShapeUtil::MakeShape(F16, {3})); 811 half h(0.5f); 812 output.PopulateWithValue<half>(h); 813 auto expected = Literal::CreateR1<half>({h, h, h}); 814 EXPECT_EQ(output, *expected); 815 } 816 817 TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { 818 Literal output(ShapeUtil::MakeShape(F16, {2, 2})); 819 half h(2.0f); 820 output.PopulateWithValue<half>(h); 821 auto expected = Literal::CreateR2<half>({{h, h}, {h, h}}); 822 EXPECT_EQ(output, *expected); 823 } 824 825 TEST_F(LiteralUtilTest, ReplicateR2U32) { 826 auto input = 827 Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); 828 auto output = input->Replicate<uint32>(3); 829 auto expected = Literal::CreateR3<uint32>( 830 {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, 831 {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, 832 {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); 833 EXPECT_EQ(*output, *expected); 834 } 835 836 TEST_F(LiteralUtilTest, CopySliceFrom) { 837 const int64 dimensions[] = {17, 15, 34, 21}; 838 const int64 layouts[][4] = { 839 {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}}; 840 for (const auto& layout : layouts) { 841 Shape shape = ShapeUtil::MakeShapeWithLayout( 842 primitive_util::NativeToPrimitiveType<uint32>(), dimensions, layout); 843 844 auto source = Literal::CreateFromShape(shape); 845 const int64 zero_base[] = {0, 0, 0, 0}; 846 const int64 step[] = {1, 1, 1, 1}; 847 uint32 seqnr = 0; 848 auto init_proc = [&](const std::vector<int64>& indexes) { 849 source->Set(indexes, ++seqnr); 850 return true; 851 }; 852 ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step, 853 init_proc); 854 855 auto blank = Literal::CreateFromShape(shape); 856 const int64 src_base[] = {3, 1, 5, 7}; 857 const int64 dest_base[] = {6, 4, 12, 2}; 858 const int64 copy_size[] = {7, 8, 11, 9}; 859 TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size)); 860 861 std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0); 862 std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0); 863 bool matched = true; 864 auto check_proc = [&](const std::vector<int64>& indexes) { 865 std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); 866 std::transform(source_indexes.begin(), source_indexes.end(), src_base, 867 source_indexes.begin(), std::plus<int64>()); 868 std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); 869 std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, 870 blank_indexes.begin(), std::plus<int64>()); 871 auto bval = blank->Get<uint32>(blank_indexes); 872 matched = (bval != 0 && bval == source->Get<uint32>(source_indexes)); 873 return matched; 874 }; 875 876 ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, 877 check_proc); 878 EXPECT_TRUE(matched); 879 } 880 } 881 882 TEST_F(LiteralUtilTest, CopyFromScalars) { 883 auto zero = Literal::CreateR0<uint32>(0); 884 auto nine = Literal::CreateR0<uint32>(9); 885 TF_EXPECT_OK(zero->CopyFrom(*nine)); 886 EXPECT_EQ(*zero, *nine); 887 888 auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21}); 889 TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); 890 EXPECT_EQ(zero->Get<uint32>({}), 17); 891 TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); 892 EXPECT_EQ(vect->Get<uint32>({4}), 17); 893 } 894 895 TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { 896 const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0}); 897 const auto const_nine = Literal::CreateR1<float>({9}); 898 const auto const_empty = Literal::CreateFromShape(empty_r1_shape); 899 900 { 901 // Source contains dimension with zero elements. 902 const auto empty = Literal::CreateFromShape(empty_r1_shape); 903 auto nine = Literal::CreateR1<float>({9}); 904 905 TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); 906 EXPECT_EQ(*nine, *const_nine); 907 } 908 909 { 910 // Copy 0 element to destination with zero elements. 911 const auto empty = Literal::CreateFromShape(empty_r1_shape); 912 auto nine = Literal::CreateR1<float>({9}); 913 914 TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); 915 EXPECT_EQ(*empty, *const_empty); 916 } 917 } 918 919 TEST_F(LiteralUtilTest, CopyFromNilShape) { 920 Literal nil_literal0(ShapeUtil::MakeNil()); 921 Literal nil_literal1(ShapeUtil::MakeNil()); 922 // This doesn't actually do any copying, but it should succeed. 923 TF_ASSERT_OK(nil_literal0.CopyFrom(nil_literal1)); 924 } 925 926 TEST_F(LiteralUtilTest, CopyFromArrays) { 927 auto scalar_42 = Literal::CreateR0<float>(42.0); 928 auto scalar_123 = Literal::CreateR0<float>(123.0); 929 EXPECT_NE(*scalar_42, *scalar_123); 930 TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, 931 /*src_shape_index=*/{})); 932 EXPECT_EQ(*scalar_42, *scalar_123); 933 EXPECT_EQ(scalar_42->Get<float>({}), 123.0f); 934 935 auto matrix_1234 = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 936 auto matrix_5678 = Literal::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}}); 937 EXPECT_NE(*matrix_1234, *matrix_5678); 938 EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f); 939 TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{}, 940 /*src_shape_index=*/{})); 941 EXPECT_EQ(*matrix_1234, *matrix_5678); 942 EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 5.0f); 943 } 944 945 TEST_F(LiteralUtilTest, CopyFromTuples) { 946 auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 947 Literal nil_literal(ShapeUtil::MakeNil()); 948 auto nested_tuple = Literal::MakeTuple( 949 {matrix.get(), 950 Literal::MakeTuple({Literal::CreateR0<int32>(42).get(), 951 Literal::CreateR1<double>({23.0, 44.0}).get(), 952 &nil_literal}) 953 .get()}); 954 // Create a tuple the same shape as the inner tuple of nested_tuple but with 955 // different values.. 956 auto tuple = Literal::MakeTuple({Literal::CreateR0<int32>(-5).get(), 957 Literal::CreateR1<double>({2.0, 4.0}).get(), 958 &nil_literal}); 959 960 EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); 961 EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42); 962 EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0); 963 EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0); 964 965 // Overwrite the inner tuple element of nested_tuple with the contents of 966 // 'tuple'. 967 TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, 968 /*src_shape_index=*/{})); 969 970 // The matrix element should be unchanged. 971 EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); 972 973 // The tuple element should have been copied from 'tuple'. 974 EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5); 975 EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 2.0); 976 EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0); 977 } 978 TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { 979 auto tuple = Literal::MakeTuple( 980 {Literal::CreateR0<int32>(-2).get(), Literal::CreateR0<int32>(4).get()}); 981 982 EXPECT_EQ(tuple->Get<int32>({}, {0}), -2); 983 EXPECT_EQ(tuple->Get<int32>({}, {1}), 4); 984 985 // Copy from one element to the other. 986 TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, 987 /*src_shape_index=*/{0})); 988 989 EXPECT_EQ(tuple->Get<int32>({}, {0}), -2); 990 EXPECT_EQ(tuple->Get<int32>({}, {1}), -2); 991 } 992 993 TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { 994 auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 995 auto vector = Literal::CreateR1<float>({5.0, 7.0}); 996 Status status = matrix->CopyFrom(*vector); 997 ASSERT_FALSE(status.ok()); 998 ASSERT_THAT(status.error_message(), 999 HasSubstr("Destination subshape incompatible")); 1000 } 1001 1002 TEST_F(LiteralUtilTest, F16) { 1003 // Verify that the internal data views are consistent and that they 1004 // are in little endian format 1005 // TODO - modify if we make the data format machine endianess dependent 1006 auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); 1007 Literal* l1 = m1.get(); 1008 const char* d1 = reinterpret_cast<const char*>(l1->data<half>().data()); 1009 EXPECT_EQ(d1[0], 0); 1010 EXPECT_EQ(d1[1], 0); 1011 EXPECT_EQ(d1[2], 0); 1012 EXPECT_EQ(d1[3], 0); 1013 EXPECT_EQ(d1[4], 0); 1014 EXPECT_EQ(d1[5], 0); 1015 EXPECT_EQ(d1[6], 0); 1016 EXPECT_EQ(d1[7], 0); 1017 1018 half h1(1.0f); 1019 half h2(2.0f); 1020 auto m2 = Literal::CreateR2<half>({{h1, h2}, {h2, h1}}); 1021 Literal* l2 = m2.get(); 1022 const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data()); 1023 EXPECT_EQ(d2[0], 0); 1024 EXPECT_EQ(d2[1], 0x3C); 1025 EXPECT_EQ(d2[2], 0); 1026 EXPECT_EQ(d2[3], 0x40); 1027 EXPECT_EQ(d2[4], 0); 1028 EXPECT_EQ(d2[5], 0x40); 1029 EXPECT_EQ(d2[6], 0); 1030 EXPECT_EQ(d2[7], 0x3C); 1031 } 1032 1033 TEST_F(LiteralUtilTest, Populate) { 1034 struct PopulateData { 1035 std::vector<int64> dimensions; 1036 std::vector<int64> layout; 1037 } populate_data[] = { 1038 {{}, {}}, 1039 {{0}, {0}}, 1040 {{16}, {0}}, 1041 {{2, 0}, {1, 0}}, 1042 {{4, 16}, {1, 0}}, 1043 {{21, 12}, {0, 1}}, 1044 {{6, 11, 17}, {2, 0, 1}}, 1045 {{6, 11, 5, 17}, {3, 2, 0, 1}}, 1046 }; 1047 for (const auto& data : populate_data) { 1048 Shape shape = ShapeUtil::MakeShapeWithLayout( 1049 primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions, 1050 data.layout); 1051 auto literal = Literal::CreateFromShape(shape); 1052 auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 { 1053 // Offsets from linear index just to avoid R0 literals to be initialized 1054 // with zero. 1055 return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), 1056 indexes) + 1057 17; 1058 }; 1059 TF_EXPECT_OK(literal->Populate<uint32>(generator)); 1060 1061 std::vector<int64> zero_base(data.dimensions.size(), 0); 1062 std::vector<int64> step(data.dimensions.size(), 1); 1063 bool matched = true; 1064 auto check_function = [&](const std::vector<int64>& indexes) { 1065 auto value = literal->Get<uint32>(indexes); 1066 matched = matched && (value == generator(indexes)); 1067 return matched; 1068 }; 1069 ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, 1070 check_function); 1071 EXPECT_TRUE(matched); 1072 } 1073 } 1074 1075 TEST_F(LiteralUtilTest, ConvertR4) { 1076 // clang-format off 1077 auto original = Literal::CreateR4WithLayout<int8>({{ 1078 {{10, 11, 12, 13}, {14, 15, 16, 17}}, 1079 {{18, 19, 20, 21}, {22, 23, 24, 25}}, 1080 {{26, 27, 28, 29}, {30, 31, 32, 33}}, 1081 }}, layout_r4_dim0major_); 1082 auto expected = Literal::CreateR4WithLayout<uint32>({{ 1083 {{10, 11, 12, 13}, {14, 15, 16, 17}}, 1084 {{18, 19, 20, 21}, {22, 23, 24, 25}}, 1085 {{26, 27, 28, 29}, {30, 31, 32, 33}}, 1086 }}, layout_r4_dim0major_); 1087 // clang-format on 1088 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted, 1089 original->Convert(U32)); 1090 1091 EXPECT_EQ(*expected, *converted); 1092 } 1093 1094 TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { 1095 // clang-format off 1096 auto s8 = Literal::CreateR4WithLayout<int8>({{ 1097 {{10, 0, 12, 0}, {0, 15, 0, 17}}, 1098 {{0, 19, 0, 21}, {22, 0, 24, 0}}, 1099 {{26, 0, 28, 0}, {0, 31, 0, 33}}, 1100 }}, layout_r4_dim0major_); 1101 auto s32 = Literal::CreateR4WithLayout<int32>({{ 1102 {{10, 0, 12, 0}, {0, 15, 0, 17}}, 1103 {{0, 19, 0, 21}, {22, 0, 24, 0}}, 1104 {{26, 0, 28, 0}, {0, 31, 0, 33}}, 1105 }}, layout_r4_dim0major_); 1106 auto u32 = Literal::CreateR4WithLayout<uint32>({{ 1107 {{10, 0, 12, 0}, {0, 15, 0, 17}}, 1108 {{0, 19, 0, 21}, {22, 0, 24, 0}}, 1109 {{26, 0, 28, 0}, {0, 31, 0, 33}}, 1110 }}, layout_r4_dim0major_); 1111 auto s64 = Literal::CreateR4WithLayout<int64>({{ 1112 {{10, 0, 12, 0}, {0, 15, 0, 17}}, 1113 {{0, 19, 0, 21}, {22, 0, 24, 0}}, 1114 {{26, 0, 28, 0}, {0, 31, 0, 33}}, 1115 }}, layout_r4_dim0major_); 1116 auto u64 = Literal::CreateR4WithLayout<uint64>({{ 1117 {{10, 0, 12, 0}, {0, 15, 0, 17}}, 1118 {{0, 19, 0, 21}, {22, 0, 24, 0}}, 1119 {{26, 0, 28, 0}, {0, 31, 0, 33}}, 1120 }}, layout_r4_dim0major_); 1121 auto pred = Literal::CreateR4WithLayout<bool>({{ 1122 {{true, false, true, false}, {false, true, false, true}}, 1123 {{false, true, false, true}, {true, false, true, false}}, 1124 {{true, false, true, false}, {false, true, false, true}}, 1125 }}, layout_r4_dim0major_); 1126 auto int32_pred = Literal::CreateR4WithLayout<int32>({{ 1127 {{1, 0, 1, 0}, {0, 1, 0, 1}}, 1128 {{0, 1, 0, 1}, {1, 0, 1, 0}}, 1129 {{1, 0, 1, 0}, {0, 1, 0, 1}}, 1130 }}, layout_r4_dim0major_); 1131 auto f16 = Literal::CreateR4WithLayout<half>({{ 1132 {{half(10.0), half(0.0), half(12.0), half(0.0)}, 1133 {half(0.0), half(15.0), half(0.0), half(17.0)}}, 1134 {{half(0.0), half(19.0), half(0.0), half(21.0)}, 1135 {half(22.0), half(0.0), half(24.0), half(0.0)}}, 1136 {{half(26.0), half(0.0), half(28.0), half(0.0)}, 1137 {half(0.0), half(31.0), half(0.0), half(33.0)}}, 1138 }}, layout_r4_dim0major_); 1139 auto bf16 = Literal::CreateR4WithLayout<bfloat16>({{ 1140 {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)}, 1141 {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}}, 1142 {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)}, 1143 {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}}, 1144 {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)}, 1145 {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}}, 1146 }}, layout_r4_dim0major_); 1147 auto f32 = Literal::CreateR4WithLayout<float>({{ 1148 {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, 1149 {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, 1150 {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, 1151 }}, layout_r4_dim0major_); 1152 auto f64 = Literal::CreateR4WithLayout<double>({{ 1153 {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, 1154 {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, 1155 {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, 1156 }}, layout_r4_dim0major_); 1157 auto c64 = Literal::CreateR4WithLayout<complex64>({{ 1158 {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, 1159 {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, 1160 {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, 1161 }}, layout_r4_dim0major_); 1162 // clang-format on 1163 std::unique_ptr<Literal> conv; 1164 1165 conv = s8->Convert(U32).ConsumeValueOrDie(); 1166 EXPECT_EQ(*conv, *u32); 1167 1168 conv = s8->Convert(S32).ConsumeValueOrDie(); 1169 EXPECT_EQ(*conv, *s32); 1170 1171 conv = s8->Convert(U64).ConsumeValueOrDie(); 1172 EXPECT_EQ(*conv, *u64); 1173 1174 conv = s8->Convert(S64).ConsumeValueOrDie(); 1175 EXPECT_EQ(*conv, *s64); 1176 1177 conv = s8->Convert(PRED).ConsumeValueOrDie(); 1178 EXPECT_EQ(*conv, *pred); 1179 1180 conv = bf16->Convert(S32).ConsumeValueOrDie(); 1181 EXPECT_EQ(*conv, *s32); 1182 1183 conv = bf16->Convert(F32).ConsumeValueOrDie(); 1184 EXPECT_EQ(*conv, *f32); 1185 1186 conv = pred->Convert(S32).ConsumeValueOrDie(); 1187 EXPECT_EQ(*conv, *int32_pred); 1188 1189 conv = f32->Convert(S32).ConsumeValueOrDie(); 1190 EXPECT_EQ(*conv, *s32); 1191 1192 conv = f64->Convert(S32).ConsumeValueOrDie(); 1193 EXPECT_EQ(*conv, *s32); 1194 1195 conv = s32->Convert(F32).ConsumeValueOrDie(); 1196 EXPECT_EQ(*conv, *f32); 1197 1198 conv = f32->Convert(F16).ConsumeValueOrDie(); 1199 EXPECT_EQ(*conv, *f16); 1200 1201 conv = f64->Convert(F16).ConsumeValueOrDie(); 1202 EXPECT_EQ(*conv, *f16); 1203 1204 conv = s32->Convert(F16).ConsumeValueOrDie(); 1205 EXPECT_EQ(*conv, *f16); 1206 1207 conv = u32->Convert(F16).ConsumeValueOrDie(); 1208 EXPECT_EQ(*conv, *f16); 1209 1210 conv = s32->Convert(C64).ConsumeValueOrDie(); 1211 EXPECT_EQ(*conv, *c64); 1212 1213 conv = f16->Convert(C64).ConsumeValueOrDie(); 1214 EXPECT_EQ(*conv, *c64); 1215 1216 EXPECT_EQ(s32->Convert(TUPLE).status().code(), 1217 tensorflow::error::INVALID_ARGUMENT); 1218 EXPECT_EQ(s32->Convert(S16).status().code(), 1219 tensorflow::error::INVALID_ARGUMENT); 1220 EXPECT_EQ(s32->Convert(U16).status().code(), 1221 tensorflow::error::INVALID_ARGUMENT); 1222 EXPECT_EQ(c64->Convert(F32).status().code(), 1223 tensorflow::error::INVALID_ARGUMENT); 1224 EXPECT_EQ(c64->Convert(S32).status().code(), 1225 tensorflow::error::INVALID_ARGUMENT); 1226 } 1227 1228 TEST_F(LiteralUtilTest, CopyFromProto_Bool) { 1229 LiteralProto p; 1230 p.mutable_shape()->set_element_type(PRED); 1231 for (int len = 0; len < 25; ++len) { 1232 p.mutable_shape()->clear_dimensions(); 1233 p.mutable_shape()->add_dimensions(len); 1234 LayoutUtil::SetToDefaultLayout(p.mutable_shape()); 1235 p.clear_preds(); 1236 for (int i = 0; i < len; ++i) { 1237 p.add_preds((i % 2) == (len % 2)); 1238 } 1239 1240 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal, 1241 Literal::CreateFromProto(p)); 1242 ASSERT_EQ(len, literal->data<bool>().size()); 1243 int i = 0; 1244 for (bool value : literal->data<bool>()) { 1245 EXPECT_EQ((i % 2) == (len % 2), value); 1246 ++i; 1247 } 1248 } 1249 } 1250 1251 // Note that f16 is currently stored in a byte array in little endian byte order 1252 TEST_F(LiteralUtilTest, ToProto_f16) { 1253 half h1(1.0f); 1254 half h2(2.0f); 1255 1256 auto m = Literal::CreateR2<half>({{h1, h2}, {h2, h1}}); 1257 Literal* l = m.get(); 1258 EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); 1259 EXPECT_EQ(4, l->data<half>().size()); 1260 1261 LiteralProto p = l->ToProto(); 1262 EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); 1263 EXPECT_EQ(8, p.f16s().size()); 1264 const char* d = p.f16s().data(); 1265 EXPECT_EQ(d[0], 0); 1266 EXPECT_EQ(d[1], 0x3C); 1267 EXPECT_EQ(d[2], 0); 1268 EXPECT_EQ(d[3], 0x40); 1269 EXPECT_EQ(d[4], 0); 1270 EXPECT_EQ(d[5], 0x40); 1271 EXPECT_EQ(d[6], 0); 1272 EXPECT_EQ(d[7], 0x3C); 1273 } 1274 1275 // Note that f16 is currently stored in a byte array in little endian byte order 1276 TEST_F(LiteralUtilTest, CopyFromProto_f16) { 1277 half h1(1.0f); 1278 half h2(2.0f); 1279 1280 const char half_vals[8] = {0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C}; 1281 LiteralProto p; 1282 p.mutable_shape()->set_element_type(F16); 1283 p.mutable_shape()->clear_dimensions(); 1284 p.mutable_shape()->add_dimensions(4); 1285 LayoutUtil::SetToDefaultLayout(p.mutable_shape()); 1286 p.clear_f16s(); 1287 p.set_f16s(half_vals, 8); 1288 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal, 1289 Literal::CreateFromProto(p)); 1290 auto r = literal->data<half>(); 1291 ASSERT_EQ(4, r.size()); 1292 ASSERT_EQ(h1, r[0]); 1293 ASSERT_EQ(h2, r[1]); 1294 ASSERT_EQ(h2, r[2]); 1295 ASSERT_EQ(h1, r[3]); 1296 } 1297 1298 TEST_F(LiteralUtilTest, LiteralViewTest) { 1299 auto scalar = Literal::CreateR0<float>(1.0); 1300 auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 1301 auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); 1302 auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); 1303 Literal nil(ShapeUtil::MakeNil()); 1304 1305 EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar); 1306 EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix); 1307 EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple); 1308 EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple); 1309 EXPECT_EQ(LiteralView::Create(nil, {}), nil); 1310 1311 EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar); 1312 EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix); 1313 1314 EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple); 1315 EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar); 1316 EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix); 1317 EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar); 1318 } 1319 1320 TEST_F(LiteralUtilTest, MutatingLiteralView) { 1321 auto scalar = Literal::CreateR0<float>(1.0); 1322 auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 1323 auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); 1324 auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); 1325 // Verify that changing the underlying data beneath the view changes the 1326 // data of the view itself. 1327 const auto nested_tuple_view = LiteralView::Create(*nested_tuple); 1328 EXPECT_EQ( 1329 nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1330 1.0f); 1331 EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{}, 1332 /*shape_index=*/{0, 0}), 1333 1.0f); 1334 nested_tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); 1335 EXPECT_EQ( 1336 nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1337 555.0f); 1338 EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{}, 1339 /*shape_index=*/{0, 0}), 1340 555.0f); 1341 } 1342 1343 TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) { 1344 auto scalar = Literal::CreateR0<float>(1.0); 1345 auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 1346 auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); 1347 auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); 1348 1349 const auto nested_tuple_view = LiteralView::Create(*nested_tuple); 1350 const auto tuple_view = 1351 LiteralView::Create(nested_tuple_view, /*view_root=*/{0}); 1352 const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1}); 1353 EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})); 1354 } 1355 1356 TEST_F(LiteralUtilTest, LiteralMove) { 1357 std::unique_ptr<Literal> matrix = 1358 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 1359 Literal literal(std::move(*matrix)); 1360 1361 EXPECT_TRUE( 1362 ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); 1363 EXPECT_EQ(literal.Get<float>({0, 0}), 1.0); 1364 EXPECT_EQ(literal.Get<float>({0, 1}), 2.0); 1365 EXPECT_EQ(literal.Get<float>({1, 0}), 3.0); 1366 EXPECT_EQ(literal.Get<float>({1, 1}), 4.0); 1367 } 1368 1369 TEST_F(LiteralUtilTest, DecomposeTuple) { 1370 Literal nil_literal(ShapeUtil::MakeNil()); 1371 auto nested_tuple = Literal::MakeTuple( 1372 {Literal::CreateR2<int32>({{1, 2}, {3, 4}}).get(), 1373 Literal::MakeTuple({Literal::CreateR0<int32>(42).get(), 1374 Literal::CreateR1<double>({23.0, 44.0}).get(), 1375 &nil_literal}) 1376 .get(), 1377 &nil_literal}); 1378 1379 EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape())); 1380 std::vector<Literal> elements = nested_tuple->DecomposeTuple(); 1381 EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape())); 1382 1383 ASSERT_EQ(elements.size(), 3); 1384 1385 EXPECT_TRUE(ShapeUtil::Compatible(elements[0].shape(), 1386 ShapeUtil::MakeShape(S32, {2, 2}))); 1387 EXPECT_EQ(elements[0].Get<int32>({0, 0}), 1); 1388 EXPECT_EQ(elements[0].Get<int32>({0, 1}), 2); 1389 EXPECT_EQ(elements[0].Get<int32>({1, 0}), 3); 1390 EXPECT_EQ(elements[0].Get<int32>({1, 1}), 4); 1391 1392 EXPECT_TRUE(ShapeUtil::Compatible( 1393 elements[1].shape(), 1394 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}), 1395 ShapeUtil::MakeShape(F64, {2}), 1396 ShapeUtil::MakeNil()}))); 1397 EXPECT_EQ(elements[1].Get<int32>({}, /*shape_index=*/{0}), 42); 1398 EXPECT_EQ(elements[1].Get<double>({0}, /*shape_index=*/{1}), 23.0); 1399 EXPECT_EQ(elements[1].Get<double>({1}, /*shape_index=*/{1}), 44.0); 1400 1401 EXPECT_TRUE(ShapeUtil::Compatible(elements[2].shape(), ShapeUtil::MakeNil())); 1402 } 1403 1404 TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { 1405 Literal nil_literal(ShapeUtil::MakeNil()); 1406 std::vector<Literal> elements = nil_literal.DecomposeTuple(); 1407 EXPECT_EQ(elements.size(), 0); 1408 } 1409 1410 TEST_F(LiteralUtilTest, MoveIntoTuple) { 1411 std::vector<Literal> elements; 1412 elements.push_back(std::move(*Literal::CreateR0<float>(1.0))); 1413 elements.push_back(std::move(*Literal::CreateR1<int32>({4, 8}))); 1414 elements.push_back(std::move( 1415 *Literal::MakeTuple({Literal::CreateR0<int32>(42).get(), 1416 Literal::CreateR1<double>({23.0, 44.0}).get()}) 1417 1418 )); 1419 1420 Literal literal = Literal::MoveIntoTuple(&elements); 1421 ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); 1422 ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3); 1423 1424 EXPECT_EQ(literal.Get<float>({}, /*shape_index=*/{0}), 1.0); 1425 EXPECT_EQ(literal.Get<int32>({0}, /*shape_index=*/{1}), 4); 1426 EXPECT_EQ(literal.Get<int32>({1}, /*shape_index=*/{1}), 8); 1427 EXPECT_EQ(literal.Get<int32>({}, /*shape_index=*/{2, 0}), 42); 1428 EXPECT_EQ(literal.Get<double>({0}, /*shape_index=*/{2, 1}), 23.0); 1429 EXPECT_EQ(literal.Get<double>({1}, /*shape_index=*/{2, 1}), 44.0); 1430 1431 for (const Literal& element : elements) { 1432 EXPECT_TRUE(ShapeUtil::IsNil(element.shape())); 1433 } 1434 } 1435 1436 TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) { 1437 Literal literal = Literal::MoveIntoTuple({}); 1438 ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); 1439 ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); 1440 } 1441 1442 TEST_F(LiteralUtilTest, LiteralMoveAssignment) { 1443 Literal literal; 1444 EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); 1445 1446 std::unique_ptr<Literal> matrix = 1447 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 1448 literal = std::move(*matrix); 1449 1450 EXPECT_TRUE( 1451 ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); 1452 EXPECT_EQ(literal.Get<float>({0, 0}), 1.0); 1453 EXPECT_EQ(literal.Get<float>({0, 1}), 2.0); 1454 EXPECT_EQ(literal.Get<float>({1, 0}), 3.0); 1455 EXPECT_EQ(literal.Get<float>({1, 1}), 4.0); 1456 } 1457 1458 TEST_F(LiteralUtilTest, LiteralViewCopy) { 1459 std::unique_ptr<Literal> matrix = 1460 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 1461 const auto matrix_view = LiteralView::Create(*matrix); 1462 LiteralView matrix_view_copy(matrix_view); 1463 1464 EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0); 1465 EXPECT_EQ(matrix_view_copy.Get<float>({0, 1}), 2.0); 1466 EXPECT_EQ(matrix_view_copy.Get<float>({1, 0}), 3.0); 1467 EXPECT_EQ(matrix_view_copy.Get<float>({1, 1}), 4.0); 1468 } 1469 1470 TEST_F(LiteralUtilTest, GetSetTuple) { 1471 auto tuple = Literal::MakeTuple( 1472 {Literal::CreateR0<float>(42.0).get(), 1473 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()}); 1474 EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); 1475 tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); 1476 EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); 1477 1478 EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 1479 3.0); 1480 tuple->Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); 1481 EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 1482 -4.0); 1483 } 1484 1485 TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { 1486 // Literals constructed using CreateFromShape should be zero initialized. 1487 std::unique_ptr<Literal> scalar_f32 = 1488 Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); 1489 EXPECT_EQ(scalar_f32->Get<float>({}), 0.0); 1490 EXPECT_TRUE(scalar_f32->IsAll(0)); 1491 1492 std::unique_ptr<Literal> vector_s32 = 1493 Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); 1494 EXPECT_EQ(vector_s32->Get<int32>({0}), 0); 1495 EXPECT_EQ(vector_s32->Get<int32>({1}), 0); 1496 EXPECT_EQ(vector_s32->Get<int32>({2}), 0); 1497 EXPECT_TRUE(vector_s32->IsAll(0)); 1498 1499 std::unique_ptr<Literal> tuple = 1500 Literal::CreateFromShape(ShapeUtil::MakeTupleShape( 1501 {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), 1502 ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); 1503 1504 EXPECT_EQ(tuple->Get<double>({}, {0}), 0.0); 1505 EXPECT_EQ(tuple->Get<bool>({0}, {1}), false); 1506 EXPECT_EQ(tuple->Get<bool>({1}, {1}), false); 1507 EXPECT_EQ(tuple->Get<uint64>({0, 0}, {2}), 0); 1508 EXPECT_EQ(tuple->Get<uint64>({1, 0}, {2}), 0); 1509 EXPECT_EQ(tuple->Get<complex64>({}, {3}), complex64(0.0f, 0.0f)); 1510 } 1511 1512 TEST_F(LiteralUtilTest, ProtoRoundTrip) { 1513 // Test serializing then deserializing a Literal through a proto. 1514 auto one_f32 = Literal::CreateR0<float>(1.0); 1515 auto two_f32 = Literal::CreateR0<float>(2.0); 1516 auto vector_int8 = Literal::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127}); 1517 auto vector_c64 = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}}); 1518 auto vector_bfloat16 = Literal::CreateR1<bfloat16>( 1519 {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); 1520 auto vector_half = 1521 Literal::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}}); 1522 auto matrix_pred = 1523 Literal::CreateR2<bool>({{true, false, true}, {false, false, true}}); 1524 auto tuple = Literal::MakeTuple( 1525 {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); 1526 Literal nil_literal(ShapeUtil::MakeNil()); 1527 auto nested_tuple = Literal::MakeTuple( 1528 {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); 1529 1530 auto to_from_proto = [](const Literal& literal) -> Literal { 1531 return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie()); 1532 }; 1533 1534 EXPECT_EQ(*one_f32, to_from_proto(*one_f32)); 1535 EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64)); 1536 EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16)); 1537 EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred)); 1538 EXPECT_EQ(*tuple, to_from_proto(*tuple)); 1539 EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple)); 1540 EXPECT_EQ(nil_literal, to_from_proto(nil_literal)); 1541 1542 EXPECT_NE(*one_f32, *two_f32); 1543 EXPECT_NE(*one_f32, to_from_proto(*two_f32)); 1544 } 1545 1546 TEST_F(LiteralUtilTest, InvalidProtoNoValues) { 1547 // Proto contains a shape, but no values. 1548 LiteralProto proto; 1549 *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); 1550 Status status = Literal::CreateFromProto(proto).status(); 1551 ASSERT_FALSE(status.ok()); 1552 ASSERT_THAT(status.error_message(), 1553 HasSubstr("Expected 3 elements in LiteralProto")); 1554 } 1555 1556 TEST_F(LiteralUtilTest, InvalidProtoNoShape) { 1557 // Proto contains values, but no shape. 1558 LiteralProto proto; 1559 proto.add_preds(false); 1560 proto.add_preds(true); 1561 proto.add_preds(false); 1562 Status status = Literal::CreateFromProto(proto).status(); 1563 ASSERT_FALSE(status.ok()); 1564 ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); 1565 } 1566 1567 TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { 1568 // Proto contains values in wrong container. 1569 LiteralProto proto; 1570 *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); 1571 proto.add_preds(false); 1572 proto.add_preds(true); 1573 proto.add_preds(false); 1574 Status status = Literal::CreateFromProto(proto).status(); 1575 ASSERT_FALSE(status.ok()); 1576 ASSERT_THAT(status.error_message(), 1577 HasSubstr("Expected 3 elements in LiteralProto")); 1578 } 1579 1580 TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { 1581 // Proto contains too few values. 1582 LiteralProto proto; 1583 *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}); 1584 proto.add_f32s(1.0); 1585 proto.add_f32s(2.0); 1586 proto.add_f32s(3.0); 1587 Status status = Literal::CreateFromProto(proto).status(); 1588 ASSERT_FALSE(status.ok()); 1589 ASSERT_THAT(status.error_message(), 1590 HasSubstr("Expected 84 elements in LiteralProto")); 1591 } 1592 1593 TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { 1594 // Proto contains too many values. 1595 LiteralProto proto; 1596 *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}); 1597 proto.add_s32s(42); 1598 proto.add_s32s(-10); 1599 proto.add_s32s(100); 1600 Status status = Literal::CreateFromProto(proto).status(); 1601 ASSERT_FALSE(status.ok()); 1602 ASSERT_THAT(status.error_message(), 1603 HasSubstr("Expected 2 elements in LiteralProto")); 1604 } 1605 1606 TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { 1607 // Proto shape missing layout. 1608 LiteralProto proto; 1609 *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}); 1610 LayoutUtil::ClearLayout(proto.mutable_shape()); 1611 proto.add_preds(true); 1612 proto.add_preds(false); 1613 proto.add_preds(true); 1614 proto.add_preds(false); 1615 Status status = Literal::CreateFromProto(proto).status(); 1616 ASSERT_FALSE(status.ok()); 1617 ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); 1618 } 1619 1620 TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { 1621 // Proto has the too few tuple elements. 1622 LiteralProto proto; 1623 *proto.mutable_shape() = ShapeUtil::MakeTupleShape( 1624 {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); 1625 LiteralProto* element0 = proto.add_tuple_literals(); 1626 *element0->mutable_shape() = 1627 ShapeUtil::GetTupleElementShape(proto.shape(), 0); 1628 element0->add_preds(false); 1629 element0->add_preds(true); 1630 1631 Status status = Literal::CreateFromProto(proto).status(); 1632 ASSERT_FALSE(status.ok()); 1633 ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); 1634 } 1635 1636 TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { 1637 // Proto has the too many tuple elements. 1638 LiteralProto proto; 1639 *proto.mutable_shape() = ShapeUtil::MakeTupleShape( 1640 {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); 1641 LiteralProto* element0 = proto.add_tuple_literals(); 1642 *element0->mutable_shape() = 1643 ShapeUtil::GetTupleElementShape(proto.shape(), 0); 1644 element0->add_preds(false); 1645 element0->add_preds(true); 1646 LiteralProto* element1 = proto.add_tuple_literals(); 1647 *element1->mutable_shape() = 1648 ShapeUtil::GetTupleElementShape(proto.shape(), 1); 1649 element1->add_f32s(42.0); 1650 LiteralProto* element2 = proto.add_tuple_literals(); 1651 *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}); 1652 element2->add_f32s(123.0); 1653 1654 Status status = Literal::CreateFromProto(proto).status(); 1655 ASSERT_FALSE(status.ok()); 1656 ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); 1657 } 1658 1659 TEST_F(LiteralUtilTest, SortSparseElements) { 1660 auto literal = 1661 Literal::CreateSparse<float>({10, 10, 10}, SparseIndexArray(10, 3), {}); 1662 literal->AppendSparseElement<float>({2, 3, 4}, 2.0); 1663 literal->AppendSparseElement<float>({3, 4, 5}, 3.0); 1664 literal->AppendSparseElement<float>({1, 2, 3}, 1.0); 1665 literal->SortSparseElements(); 1666 ASSERT_EQ(literal->ToString(false), 1667 "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); 1668 } 1669 1670 TEST_F(LiteralUtilTest, GetSparseElementAsString) { 1671 std::vector<int64> dimensions = {10, 10, 10}; 1672 SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); 1673 1674 ASSERT_EQ( 1675 Literal::CreateSparse<bool>(dimensions, indices, {true, false, true}) 1676 ->GetSparseElementAsString(1), 1677 "false"); 1678 ASSERT_EQ(Literal::CreateSparse<int64>(dimensions, indices, {1, 2, 3}) 1679 ->GetSparseElementAsString(1), 1680 tensorflow::strings::StrCat(int64{2})); 1681 ASSERT_EQ(Literal::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0}) 1682 ->GetSparseElementAsString(1), 1683 tensorflow::strings::StrCat(double{2.0})); 1684 ASSERT_EQ(Literal::CreateSparse<half>(dimensions, indices, 1685 {half{1.0}, half{2.0}, half{3.0}}) 1686 ->GetSparseElementAsString(1), 1687 tensorflow::strings::StrCat(half{2.0})); 1688 ASSERT_EQ( 1689 Literal::CreateSparse<complex64>( 1690 dimensions, indices, 1691 std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) 1692 ->GetSparseElementAsString(1), 1693 tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); 1694 } 1695 1696 } // namespace 1697 } // namespace xla 1698