1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" 17 18 #include <random> 19 #include <vector> 20 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/framework/types.pb.h" 23 #include "tensorflow/core/framework/variant.h" 24 #include "tensorflow/core/framework/variant_op_registry.h" 25 #include "tensorflow/core/framework/versions.pb.h" 26 #include "tensorflow/core/lib/core/status_test_util.h" 27 #include "tensorflow/core/lib/io/path.h" 28 #include "tensorflow/core/lib/io/table_builder.h" 29 #include "tensorflow/core/lib/strings/strcat.h" 30 #include "tensorflow/core/platform/test.h" 31 #include "tensorflow/core/platform/test_benchmark.h" 32 33 namespace tensorflow { 34 35 namespace { 36 37 string Prefix(const string& prefix) { 38 return strings::StrCat(testing::TmpDir(), "/", prefix); 39 } 40 41 template <typename T> 42 Tensor Constant(T v, TensorShape shape) { 43 Tensor ret(DataTypeToEnum<T>::value, shape); 44 ret.flat<T>().setConstant(v); 45 return ret; 46 } 47 48 template <typename T> 49 Tensor Constant_2x3(T v) { 50 return Constant(v, TensorShape({2, 3})); 51 } 52 53 template <typename T> 54 void Expect(BundleReader* reader, const string& key, 55 const Tensor& expected_val) { 56 // Tests for Contains(). 57 EXPECT_TRUE(reader->Contains(key)); 58 // Tests for LookupDtypeAndShape(). 59 DataType dtype; 60 TensorShape shape; 61 TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape)); 62 EXPECT_EQ(expected_val.dtype(), dtype); 63 EXPECT_EQ(expected_val.shape(), shape); 64 // Tests for Lookup(), checking tensor contents. 65 Tensor val(expected_val.dtype(), shape); 66 TF_ASSERT_OK(reader->Lookup(key, &val)); 67 test::ExpectTensorEqual<T>(val, expected_val); 68 } 69 70 template <class T> 71 void ExpectVariant(BundleReader* reader, const string& key, 72 const Tensor& expected_t) { 73 // Tests for Contains(). 74 EXPECT_TRUE(reader->Contains(key)); 75 // Tests for LookupDtypeAndShape(). 76 DataType dtype; 77 TensorShape shape; 78 TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape)); 79 // Tests for Lookup(), checking tensor contents. 80 EXPECT_EQ(expected_t.dtype(), dtype); 81 EXPECT_EQ(expected_t.shape(), shape); 82 Tensor actual_t(dtype, shape); 83 TF_ASSERT_OK(reader->Lookup(key, &actual_t)); 84 for (int i = 0; i < expected_t.NumElements(); i++) { 85 Variant actual_var = actual_t.flat<Variant>()(i); 86 Variant expected_var = expected_t.flat<Variant>()(i); 87 EXPECT_EQ(actual_var.TypeName(), expected_var.TypeName()); 88 auto* actual_val = actual_var.get<T>(); 89 auto* expected_val = expected_var.get<T>(); 90 EXPECT_EQ(*expected_val, *actual_val); 91 } 92 } 93 94 template <typename T> 95 void ExpectNext(BundleReader* reader, const Tensor& expected_val) { 96 EXPECT_TRUE(reader->Valid()); 97 reader->Next(); 98 TF_ASSERT_OK(reader->status()); 99 Tensor val; 100 TF_ASSERT_OK(reader->ReadCurrent(&val)); 101 test::ExpectTensorEqual<T>(val, expected_val); 102 } 103 104 std::vector<string> AllTensorKeys(BundleReader* reader) { 105 std::vector<string> ret; 106 reader->Seek(kHeaderEntryKey); 107 reader->Next(); 108 for (; reader->Valid(); reader->Next()) { 109 ret.push_back(reader->key().ToString()); 110 } 111 return ret; 112 } 113 114 // Writes out the metadata file of a bundle again, with the endianness marker 115 // bit flipped. 116 Status FlipEndiannessBit(const string& prefix) { 117 Env* env = Env::Default(); 118 const string metadata_tmp_path = Prefix("some_tmp_path"); 119 std::unique_ptr<WritableFile> file; 120 TF_RETURN_IF_ERROR(env->NewWritableFile(metadata_tmp_path, &file)); 121 table::TableBuilder builder(table::Options(), file.get()); 122 123 // Reads the existing metadata file, and fills the builder. 124 { 125 const string filename = MetaFilename(prefix); 126 uint64 file_size; 127 TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size)); 128 std::unique_ptr<RandomAccessFile> file; 129 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file)); 130 131 table::Table* table = nullptr; 132 TF_RETURN_IF_ERROR( 133 table::Table::Open(table::Options(), file.get(), file_size, &table)); 134 std::unique_ptr<table::Table> table_deleter(table); 135 std::unique_ptr<table::Iterator> iter(table->NewIterator()); 136 137 // Reads the header entry. 138 iter->Seek(kHeaderEntryKey); 139 CHECK(iter->Valid()); 140 BundleHeaderProto header; 141 CHECK(header.ParseFromArray(iter->value().data(), iter->value().size())); 142 // Flips the endianness. 143 if (header.endianness() == BundleHeaderProto::LITTLE) { 144 header.set_endianness(BundleHeaderProto::BIG); 145 } else { 146 header.set_endianness(BundleHeaderProto::LITTLE); 147 } 148 builder.Add(iter->key(), header.SerializeAsString()); 149 iter->Next(); 150 151 // Adds the non-header entries unmodified. 152 for (; iter->Valid(); iter->Next()) builder.Add(iter->key(), iter->value()); 153 } 154 TF_RETURN_IF_ERROR(builder.Finish()); 155 TF_RETURN_IF_ERROR(env->RenameFile(metadata_tmp_path, MetaFilename(prefix))); 156 return file->Close(); 157 } 158 159 template <typename T> 160 void TestBasic() { 161 { 162 BundleWriter writer(Env::Default(), Prefix("foo")); 163 TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3<T>(3))); 164 TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3<T>(0))); 165 TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3<T>(2))); 166 TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3<T>(1))); 167 TF_ASSERT_OK(writer.Finish()); 168 } 169 { 170 BundleReader reader(Env::Default(), Prefix("foo")); 171 TF_ASSERT_OK(reader.status()); 172 EXPECT_EQ( 173 AllTensorKeys(&reader), 174 std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"})); 175 Expect<T>(&reader, "foo_000", Constant_2x3<T>(0)); 176 Expect<T>(&reader, "foo_001", Constant_2x3<T>(1)); 177 Expect<T>(&reader, "foo_002", Constant_2x3<T>(2)); 178 Expect<T>(&reader, "foo_003", Constant_2x3<T>(3)); 179 } 180 { 181 BundleReader reader(Env::Default(), Prefix("foo")); 182 TF_ASSERT_OK(reader.status()); 183 ExpectNext<T>(&reader, Constant_2x3<T>(0)); 184 ExpectNext<T>(&reader, Constant_2x3<T>(1)); 185 ExpectNext<T>(&reader, Constant_2x3<T>(2)); 186 ExpectNext<T>(&reader, Constant_2x3<T>(3)); 187 EXPECT_TRUE(reader.Valid()); 188 reader.Next(); 189 EXPECT_FALSE(reader.Valid()); 190 } 191 { 192 BundleWriter writer(Env::Default(), Prefix("bar")); 193 TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3<T>(3))); 194 TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3<T>(0))); 195 TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3<T>(2))); 196 TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3<T>(1))); 197 TF_ASSERT_OK(writer.Finish()); 198 } 199 { 200 BundleReader reader(Env::Default(), Prefix("bar")); 201 TF_ASSERT_OK(reader.status()); 202 EXPECT_EQ( 203 AllTensorKeys(&reader), 204 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"})); 205 Expect<T>(&reader, "bar_003", Constant_2x3<T>(3)); 206 Expect<T>(&reader, "bar_002", Constant_2x3<T>(2)); 207 Expect<T>(&reader, "bar_001", Constant_2x3<T>(1)); 208 Expect<T>(&reader, "bar_000", Constant_2x3<T>(0)); 209 } 210 { 211 BundleReader reader(Env::Default(), Prefix("bar")); 212 TF_ASSERT_OK(reader.status()); 213 ExpectNext<T>(&reader, Constant_2x3<T>(0)); 214 ExpectNext<T>(&reader, Constant_2x3<T>(1)); 215 ExpectNext<T>(&reader, Constant_2x3<T>(2)); 216 ExpectNext<T>(&reader, Constant_2x3<T>(3)); 217 EXPECT_TRUE(reader.Valid()); 218 reader.Next(); 219 EXPECT_FALSE(reader.Valid()); 220 } 221 TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")}, 222 Prefix("merged"))); 223 { 224 BundleReader reader(Env::Default(), Prefix("merged")); 225 TF_ASSERT_OK(reader.status()); 226 EXPECT_EQ( 227 AllTensorKeys(&reader), 228 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003", 229 "foo_000", "foo_001", "foo_002", "foo_003"})); 230 Expect<T>(&reader, "bar_000", Constant_2x3<T>(0)); 231 Expect<T>(&reader, "bar_001", Constant_2x3<T>(1)); 232 Expect<T>(&reader, "bar_002", Constant_2x3<T>(2)); 233 Expect<T>(&reader, "bar_003", Constant_2x3<T>(3)); 234 Expect<T>(&reader, "foo_000", Constant_2x3<T>(0)); 235 Expect<T>(&reader, "foo_001", Constant_2x3<T>(1)); 236 Expect<T>(&reader, "foo_002", Constant_2x3<T>(2)); 237 Expect<T>(&reader, "foo_003", Constant_2x3<T>(3)); 238 } 239 { 240 BundleReader reader(Env::Default(), Prefix("merged")); 241 TF_ASSERT_OK(reader.status()); 242 ExpectNext<T>(&reader, Constant_2x3<T>(0)); 243 ExpectNext<T>(&reader, Constant_2x3<T>(1)); 244 ExpectNext<T>(&reader, Constant_2x3<T>(2)); 245 ExpectNext<T>(&reader, Constant_2x3<T>(3)); 246 ExpectNext<T>(&reader, Constant_2x3<T>(0)); 247 ExpectNext<T>(&reader, Constant_2x3<T>(1)); 248 ExpectNext<T>(&reader, Constant_2x3<T>(2)); 249 ExpectNext<T>(&reader, Constant_2x3<T>(3)); 250 EXPECT_TRUE(reader.Valid()); 251 reader.Next(); 252 EXPECT_FALSE(reader.Valid()); 253 } 254 } 255 256 template <typename T> 257 void TestNonStandardShapes() { 258 { 259 BundleWriter writer(Env::Default(), Prefix("nonstandard")); 260 TF_EXPECT_OK(writer.Add("scalar", Constant<T>(0, TensorShape()))); 261 TF_EXPECT_OK( 262 writer.Add("non_standard0", Constant<T>(0, TensorShape({0, 1618})))); 263 TF_EXPECT_OK( 264 writer.Add("non_standard1", Constant<T>(0, TensorShape({16, 0, 18})))); 265 TF_ASSERT_OK(writer.Finish()); 266 } 267 { 268 BundleReader reader(Env::Default(), Prefix("nonstandard")); 269 TF_ASSERT_OK(reader.status()); 270 Expect<T>(&reader, "scalar", Constant<T>(0, TensorShape())); 271 Expect<T>(&reader, "non_standard0", Constant<T>(0, TensorShape({0, 1618}))); 272 Expect<T>(&reader, "non_standard1", 273 Constant<T>(0, TensorShape({16, 0, 18}))); 274 } 275 } 276 277 // Writes a bundle to disk with a bad "version"; checks for "expected_error". 278 void VersionTest(const VersionDef& version, StringPiece expected_error) { 279 const string path = Prefix("version_test"); 280 { 281 // Prepare an empty bundle with the given version information. 282 BundleHeaderProto header; 283 *header.mutable_version() = version; 284 285 // Write the metadata file to disk. 286 std::unique_ptr<WritableFile> file; 287 TF_ASSERT_OK(Env::Default()->NewWritableFile(MetaFilename(path), &file)); 288 table::TableBuilder builder(table::Options(), file.get()); 289 builder.Add(kHeaderEntryKey, header.SerializeAsString()); 290 TF_ASSERT_OK(builder.Finish()); 291 } 292 // Read it back in and verify that we get the expected error. 293 BundleReader reader(Env::Default(), path); 294 EXPECT_TRUE(errors::IsInvalidArgument(reader.status())); 295 EXPECT_TRUE( 296 StringPiece(reader.status().error_message()).starts_with(expected_error)); 297 } 298 299 } // namespace 300 301 TEST(TensorBundleTest, Basic) { 302 TestBasic<float>(); 303 TestBasic<double>(); 304 TestBasic<int32>(); 305 TestBasic<uint8>(); 306 TestBasic<int16>(); 307 TestBasic<int8>(); 308 TestBasic<complex64>(); 309 TestBasic<complex128>(); 310 TestBasic<int64>(); 311 TestBasic<bool>(); 312 TestBasic<qint32>(); 313 TestBasic<quint8>(); 314 TestBasic<qint8>(); 315 } 316 317 TEST(TensorBundleTest, PartitionedVariables) { 318 const TensorShape kFullShape({5, 10}); 319 // Adds two slices. 320 // First slice: column 0, all zeros. 321 // Second slice: column 1 to rest, all ones. 322 TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1"); 323 TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9"); 324 { 325 BundleWriter writer(Env::Default(), Prefix("foo")); 326 327 TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1, 328 Constant<float>(0., TensorShape({5, 1})))); 329 TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2, 330 Constant<float>(1., TensorShape({5, 9})))); 331 TF_ASSERT_OK(writer.Finish()); 332 } 333 // Reads in full. 334 { 335 BundleReader reader(Env::Default(), Prefix("foo")); 336 TF_ASSERT_OK(reader.status()); 337 338 Tensor expected_val(DT_FLOAT, kFullShape); 339 test::FillFn<float>(&expected_val, [](int offset) -> float { 340 if (offset % 10 == 0) { 341 return 0; // First column zeros. 342 } 343 return 1; // Other columns ones. 344 }); 345 346 Tensor val(DT_FLOAT, kFullShape); 347 TF_ASSERT_OK(reader.Lookup("foo", &val)); 348 test::ExpectTensorEqual<float>(val, expected_val); 349 } 350 // Reads all slices. 351 { 352 BundleReader reader(Env::Default(), Prefix("foo")); 353 TF_ASSERT_OK(reader.status()); 354 355 std::vector<TensorSlice> slices; 356 TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices)); 357 358 EXPECT_EQ(2, slices.size()); 359 EXPECT_EQ(slice1.DebugString(), slices[0].DebugString()); 360 EXPECT_EQ(slice2.DebugString(), slices[1].DebugString()); 361 } 362 // Reads a slice consisting of first two columns, "cutting" both slices. 363 { 364 BundleReader reader(Env::Default(), Prefix("foo")); 365 TF_ASSERT_OK(reader.status()); 366 367 // First two columns, "cutting" both slices. 368 const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:0,2"); 369 Tensor expected_val(DT_FLOAT, TensorShape({5, 2})); 370 test::FillFn<float>(&expected_val, [](int offset) -> float { 371 if (offset % 2 == 0) { 372 return 0; // First column zeros. 373 } 374 return 1; // Other columns ones. 375 }); 376 377 Tensor val(DT_FLOAT, TensorShape({5, 2})); 378 TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val)); 379 test::ExpectTensorEqual<float>(val, expected_val); 380 } 381 // Reads a slice consisting of columns 2-4, "cutting" the second slice only. 382 { 383 BundleReader reader(Env::Default(), Prefix("foo")); 384 TF_ASSERT_OK(reader.status()); 385 386 const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:2,2"); 387 Tensor val(DT_FLOAT, TensorShape({5, 2})); 388 TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val)); 389 test::ExpectTensorEqual<float>(val, 390 Constant<float>(1., TensorShape({5, 2}))); 391 } 392 } 393 394 TEST(TensorBundleTest, EquivalentSliceTest) { 395 const TensorShape kFullShape({5, 10}); 396 const Tensor kExpected(Constant<float>(1., kFullShape)); 397 { 398 BundleWriter writer(Env::Default(), Prefix("foo")); 399 TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape, 400 TensorSlice::ParseOrDie("-:-"), kExpected)); 401 TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape, 402 TensorSlice::ParseOrDie("0,5:0,10"), 403 kExpected)); 404 TF_ASSERT_OK(writer.Finish()); 405 } 406 // Slices match exactly and are fully abbreviated. 407 { 408 BundleReader reader(Env::Default(), Prefix("foo")); 409 TF_ASSERT_OK(reader.status()); 410 const TensorSlice slice = TensorSlice::ParseOrDie("-:-"); 411 Tensor val(DT_FLOAT, TensorShape(kFullShape)); 412 TF_ASSERT_OK(reader.LookupSlice("no_extents", slice, &val)); 413 test::ExpectTensorEqual<float>(val, kExpected); 414 } 415 // Slice match exactly and are fully specified. 416 { 417 BundleReader reader(Env::Default(), Prefix("foo")); 418 TF_ASSERT_OK(reader.status()); 419 const TensorSlice slice = TensorSlice::ParseOrDie("0,5:0,10"); 420 Tensor val(DT_FLOAT, TensorShape(kFullShape)); 421 TF_ASSERT_OK(reader.LookupSlice("both_extents", slice, &val)); 422 test::ExpectTensorEqual<float>(val, kExpected); 423 } 424 // Stored slice has no extents, spec has extents. 425 { 426 BundleReader reader(Env::Default(), Prefix("foo")); 427 TF_ASSERT_OK(reader.status()); 428 const TensorSlice slice = TensorSlice::ParseOrDie("0,5:0,10"); 429 Tensor val(DT_FLOAT, TensorShape(kFullShape)); 430 TF_ASSERT_OK(reader.LookupSlice("no_extents", slice, &val)); 431 test::ExpectTensorEqual<float>(val, kExpected); 432 } 433 // Stored slice has both extents, spec has no extents. 434 { 435 BundleReader reader(Env::Default(), Prefix("foo")); 436 TF_ASSERT_OK(reader.status()); 437 const TensorSlice slice = TensorSlice::ParseOrDie("-:-"); 438 Tensor val(DT_FLOAT, TensorShape(kFullShape)); 439 TF_ASSERT_OK(reader.LookupSlice("both_extents", slice, &val)); 440 test::ExpectTensorEqual<float>(val, kExpected); 441 } 442 } 443 444 TEST(TensorBundleTest, NonStandardShapes) { 445 TestNonStandardShapes<float>(); 446 TestNonStandardShapes<double>(); 447 TestNonStandardShapes<int32>(); 448 TestNonStandardShapes<uint8>(); 449 TestNonStandardShapes<int16>(); 450 TestNonStandardShapes<int8>(); 451 TestNonStandardShapes<complex64>(); 452 TestNonStandardShapes<complex128>(); 453 TestNonStandardShapes<int64>(); 454 TestNonStandardShapes<bool>(); 455 TestNonStandardShapes<qint32>(); 456 TestNonStandardShapes<quint8>(); 457 TestNonStandardShapes<qint8>(); 458 } 459 460 TEST(TensorBundleTest, StringTensors) { 461 { 462 BundleWriter writer(Env::Default(), Prefix("foo")); 463 TF_EXPECT_OK(writer.Add("string_tensor", 464 Tensor(DT_STRING, TensorShape({1})))); // Empty. 465 TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<string>({"hello"}))); 466 TF_EXPECT_OK(writer.Add( 467 "strs", 468 test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}))); 469 // Mixes in some floats. 470 TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18))); 471 TF_ASSERT_OK(writer.Finish()); 472 } 473 { 474 BundleReader reader(Env::Default(), Prefix("foo")); 475 TF_ASSERT_OK(reader.status()); 476 EXPECT_EQ( 477 AllTensorKeys(&reader), 478 std::vector<string>({"floats", "scalar", "string_tensor", "strs"})); 479 480 Expect<string>(&reader, "string_tensor", 481 Tensor(DT_STRING, TensorShape({1}))); 482 Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"})); 483 Expect<string>( 484 &reader, "strs", 485 test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})); 486 Expect<float>(&reader, "floats", Constant_2x3<float>(16.18)); 487 } 488 } 489 490 class VariantObject { 491 public: 492 VariantObject() {} 493 VariantObject(const string& metadata, int64 value) 494 : metadata_(metadata), value_(value) {} 495 496 string TypeName() const { return "TEST VariantObject"; } 497 void Encode(VariantTensorData* data) const { 498 data->set_type_name(TypeName()); 499 data->set_metadata(metadata_); 500 Tensor val_t = Tensor(DT_INT64, TensorShape({})); 501 val_t.scalar<int64>()() = value_; 502 *(data->add_tensors()) = val_t; 503 } 504 bool Decode(const VariantTensorData& data) { 505 EXPECT_EQ(data.type_name(), TypeName()); 506 data.get_metadata(&metadata_); 507 EXPECT_EQ(data.tensors_size(), 1); 508 value_ = data.tensors(0).scalar<int64>()(); 509 return true; 510 } 511 bool operator==(const VariantObject other) const { 512 return metadata_ == other.metadata_ && value_ == other.value_; 513 } 514 string metadata_; 515 int64 value_; 516 }; 517 518 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantObject, "TEST VariantObject"); 519 520 TEST(TensorBundleTest, VariantTensors) { 521 { 522 BundleWriter writer(Env::Default(), Prefix("foo")); 523 TF_EXPECT_OK( 524 writer.Add("variant_tensor", 525 test::AsTensor<Variant>({VariantObject("test", 10), 526 VariantObject("test1", 20)}))); 527 TF_ASSERT_OK(writer.Finish()); 528 } 529 { 530 BundleReader reader(Env::Default(), Prefix("foo")); 531 TF_ASSERT_OK(reader.status()); 532 ExpectVariant<VariantObject>( 533 &reader, "variant_tensor", 534 test::AsTensor<Variant>( 535 {VariantObject("test", 10), VariantObject("test1", 20)})); 536 } 537 } 538 539 TEST(TensorBundleTest, DirectoryStructure) { 540 Env* env = Env::Default(); 541 // Writes two bundles. 542 const std::vector<string> kBundlePrefixes = {Prefix("worker0"), 543 Prefix("worker1")}; 544 for (int i = 0; i < 2; ++i) { 545 BundleWriter writer(env, kBundlePrefixes[i]); 546 TF_EXPECT_OK( 547 writer.Add(strings::StrCat("tensor", i), Constant_2x3<float>(0.))); 548 TF_ASSERT_OK(writer.Finish()); 549 } 550 551 // Ensures we have the expected files. 552 auto CheckDirFiles = [env](const string& bundle_prefix, 553 gtl::ArraySlice<string> expected_files) { 554 StringPiece dir = io::Dirname(bundle_prefix); 555 for (const string& expected_file : expected_files) { 556 TF_EXPECT_OK(env->FileExists(io::JoinPath(dir, expected_file))); 557 } 558 }; 559 560 // Check we have: 561 // worker<i>.index 562 // worker<i>.data-00000-of-00001 563 CheckDirFiles(kBundlePrefixes[0], 564 {"worker0.index", "worker0.data-00000-of-00001"}); 565 CheckDirFiles(kBundlePrefixes[1], 566 {"worker1.index", "worker1.data-00000-of-00001"}); 567 568 // Trivially "merge" one bundle to some other location (i.e., a renaming). 569 const string kAnotherPrefix = Prefix("another"); 570 TF_ASSERT_OK(MergeBundles(env, {kBundlePrefixes[0]}, kAnotherPrefix)); 571 CheckDirFiles(kAnotherPrefix, 572 {"another.index", "another.data-00000-of-00001"}); 573 574 // Performs actual merge of the two bundles. Check we have: 575 // merged.index 576 // merged.data-00000-of-00002 577 // merged.data-00001-of-00002 578 const string kMerged = Prefix("merged"); 579 TF_ASSERT_OK( 580 MergeBundles(env, {kAnotherPrefix, kBundlePrefixes[1]}, kMerged)); 581 CheckDirFiles(kMerged, {"merged.index", "merged.data-00000-of-00002", 582 "merged.data-00001-of-00002"}); 583 } 584 585 TEST(TensorBundleTest, Error) { 586 { // Dup keys. 587 BundleWriter writer(Env::Default(), Prefix("dup")); 588 TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f))); 589 EXPECT_FALSE(writer.Add("foo", Constant_2x3(2.f)).ok()); 590 EXPECT_TRUE( 591 StringPiece(writer.status().ToString()).contains("duplicate key")); 592 EXPECT_FALSE(writer.Finish().ok()); 593 } 594 { // Double finish 595 BundleWriter writer(Env::Default(), Prefix("bad")); 596 EXPECT_TRUE(writer.Finish().ok()); 597 EXPECT_FALSE(writer.Finish().ok()); 598 } 599 { // Not found. 600 BundleReader reader(Env::Default(), Prefix("nonexist")); 601 EXPECT_TRUE(StringPiece(reader.status().ToString()).contains("Not found")); 602 } 603 } 604 605 TEST(TensorBundleTest, Checksum) { 606 // Randomly flips a byte in [pos_lhs, end of data file), or exactly byte 607 // pos_lhs if exact_pos == True. 608 auto FlipByte = [](const string& prefix, int pos_lhs, 609 bool exact_pos = false) { 610 DCHECK_GE(pos_lhs, 0); 611 const string& datafile = DataFilename(Prefix(prefix), 0, 1); 612 string data; 613 TF_ASSERT_OK(ReadFileToString(Env::Default(), datafile, &data)); 614 615 int byte_pos = 0; 616 if (!exact_pos) { 617 std::mt19937 rng; 618 std::uniform_int_distribution<int> dist(pos_lhs, data.size() - 1); 619 byte_pos = dist(rng); 620 } else { 621 byte_pos = pos_lhs; 622 } 623 data[byte_pos] = ~data[byte_pos]; 624 TF_ASSERT_OK(WriteStringToFile(Env::Default(), datafile, data)); 625 }; 626 // The lookup should fail with a checksum-related message. 627 auto ExpectLookupFails = [](const string& prefix, const string& key, 628 const string& expected_msg, Tensor& val) { 629 BundleReader reader(Env::Default(), Prefix(prefix)); 630 Status status = reader.Lookup(key, &val); 631 EXPECT_TRUE(errors::IsDataLoss(status)); 632 EXPECT_TRUE(StringPiece(status.ToString()).contains(expected_msg)); 633 }; 634 635 // Corrupts a float tensor. 636 { 637 BundleWriter writer(Env::Default(), Prefix("singleton")); 638 TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f))); 639 TF_ASSERT_OK(writer.Finish()); 640 641 FlipByte("singleton", 0 /* corrupts any byte */); 642 Tensor val(DT_FLOAT, TensorShape({2, 3})); 643 ExpectLookupFails("singleton", "foo", 644 "Checksum does not match" /* expected fail msg */, val); 645 } 646 // Corrupts a string tensor. 647 { 648 auto WriteStrings = []() { 649 BundleWriter writer(Env::Default(), Prefix("strings")); 650 TF_EXPECT_OK( 651 writer.Add("foo", test::AsTensor<string>({"hello", "world"}))); 652 TF_ASSERT_OK(writer.Finish()); 653 }; 654 // Corrupts the first two bytes, which are the varint32-encoded lengths 655 // of the two string elements. Should hit mismatch on length cksum. 656 for (int i = 0; i < 2; ++i) { 657 WriteStrings(); 658 FlipByte("strings", i, true /* corrupts exactly byte i */); 659 Tensor val(DT_STRING, TensorShape({2})); 660 ExpectLookupFails( 661 "strings", "foo", 662 "length checksum does not match" /* expected fail msg */, val); 663 } 664 // Corrupts the string bytes, should hit an overall cksum mismatch. 665 WriteStrings(); 666 FlipByte("strings", 2 /* corrupts starting from byte 2 */); 667 Tensor val(DT_STRING, TensorShape({2})); 668 ExpectLookupFails("strings", "foo", 669 "Checksum does not match" /* expected fail msg */, val); 670 } 671 } 672 673 TEST(TensorBundleTest, Endianness) { 674 BundleWriter writer(Env::Default(), Prefix("end")); 675 TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0))); 676 TF_ASSERT_OK(writer.Finish()); 677 678 // Flips the endianness bit. 679 TF_ASSERT_OK(FlipEndiannessBit(Prefix("end"))); 680 681 BundleReader reader(Env::Default(), Prefix("end")); 682 EXPECT_TRUE(errors::IsUnimplemented(reader.status())); 683 EXPECT_TRUE(StringPiece(reader.status().ToString()) 684 .contains("different endianness from the reader")); 685 } 686 687 TEST(TensorBundleTest, TruncatedTensorContents) { 688 Env* env = Env::Default(); 689 BundleWriter writer(env, Prefix("end")); 690 TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0))); 691 TF_ASSERT_OK(writer.Finish()); 692 693 // Truncates the data file by one byte, so that we hit EOF. 694 const string datafile = DataFilename(Prefix("end"), 0, 1); 695 string data; 696 TF_ASSERT_OK(ReadFileToString(env, datafile, &data)); 697 ASSERT_TRUE(!data.empty()); 698 TF_ASSERT_OK(WriteStringToFile(env, datafile, 699 StringPiece(data.data(), data.size() - 1))); 700 701 BundleReader reader(env, Prefix("end")); 702 TF_ASSERT_OK(reader.status()); 703 Tensor val(DT_FLOAT, TensorShape({2, 3})); 704 EXPECT_TRUE(errors::IsOutOfRange(reader.Lookup("key", &val))); 705 } 706 707 TEST(TensorBundleTest, HeaderEntry) { 708 { 709 BundleWriter writer(Env::Default(), Prefix("b")); 710 TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0))); 711 TF_ASSERT_OK(writer.Finish()); 712 } 713 714 // Extracts out the header. 715 BundleHeaderProto header; 716 { 717 BundleReader reader(Env::Default(), Prefix("b")); 718 TF_ASSERT_OK(reader.status()); 719 reader.Seek(kHeaderEntryKey); 720 ASSERT_TRUE(reader.Valid()); 721 ASSERT_TRUE(ParseProtoUnlimited(&header, reader.value().data(), 722 reader.value().size())); 723 } 724 725 // num_shards 726 EXPECT_EQ(1, header.num_shards()); 727 // endianness 728 if (port::kLittleEndian) { 729 EXPECT_EQ(BundleHeaderProto::LITTLE, header.endianness()); 730 } else { 731 EXPECT_EQ(BundleHeaderProto::BIG, header.endianness()); 732 } 733 // version 734 EXPECT_GT(kTensorBundleVersion, 0); 735 EXPECT_EQ(kTensorBundleVersion, header.version().producer()); 736 EXPECT_EQ(kTensorBundleMinConsumer, header.version().min_consumer()); 737 } 738 739 TEST(TensorBundleTest, VersionTest) { 740 // Min consumer. 741 { 742 VersionDef versions; 743 versions.set_producer(kTensorBundleVersion + 1); 744 versions.set_min_consumer(kTensorBundleVersion + 1); 745 VersionTest( 746 versions, 747 strings::StrCat("Checkpoint min consumer version ", 748 kTensorBundleVersion + 1, " above current version ", 749 kTensorBundleVersion, " for TensorFlow")); 750 } 751 // Min producer. 752 { 753 VersionDef versions; 754 versions.set_producer(kTensorBundleMinProducer - 1); 755 VersionTest( 756 versions, 757 strings::StrCat("Checkpoint producer version ", 758 kTensorBundleMinProducer - 1, " below min producer ", 759 kTensorBundleMinProducer, " supported by TensorFlow")); 760 } 761 // Bad consumer. 762 { 763 VersionDef versions; 764 versions.set_producer(kTensorBundleVersion + 1); 765 versions.add_bad_consumers(kTensorBundleVersion); 766 VersionTest( 767 versions, 768 strings::StrCat( 769 "Checkpoint disallows consumer version ", kTensorBundleVersion, 770 ". Please upgrade TensorFlow: this version is likely buggy.")); 771 } 772 } 773 774 class TensorBundleAlignmentTest : public ::testing::Test { 775 protected: 776 template <typename T> 777 void ExpectAlignment(BundleReader* reader, const string& key, int alignment) { 778 BundleEntryProto full_tensor_entry; 779 TF_ASSERT_OK(reader->GetBundleEntryProto(key, &full_tensor_entry)); 780 EXPECT_EQ(0, full_tensor_entry.offset() % alignment); 781 } 782 }; 783 784 TEST_F(TensorBundleAlignmentTest, AlignmentTest) { 785 { 786 BundleWriter::Options opts; 787 opts.data_alignment = 42; 788 BundleWriter writer(Env::Default(), Prefix("foo"), opts); 789 TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3<float>(3))); 790 TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3<float>(0))); 791 TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3<float>(2))); 792 TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3<float>(1))); 793 TF_ASSERT_OK(writer.Finish()); 794 } 795 { 796 BundleReader reader(Env::Default(), Prefix("foo")); 797 TF_ASSERT_OK(reader.status()); 798 EXPECT_EQ( 799 AllTensorKeys(&reader), 800 std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"})); 801 Expect<float>(&reader, "foo_000", Constant_2x3<float>(0)); 802 Expect<float>(&reader, "foo_001", Constant_2x3<float>(1)); 803 Expect<float>(&reader, "foo_002", Constant_2x3<float>(2)); 804 Expect<float>(&reader, "foo_003", Constant_2x3<float>(3)); 805 } 806 { 807 BundleReader reader(Env::Default(), Prefix("foo")); 808 TF_ASSERT_OK(reader.status()); 809 ExpectNext<float>(&reader, Constant_2x3<float>(0)); 810 ExpectNext<float>(&reader, Constant_2x3<float>(1)); 811 ExpectNext<float>(&reader, Constant_2x3<float>(2)); 812 ExpectNext<float>(&reader, Constant_2x3<float>(3)); 813 EXPECT_TRUE(reader.Valid()); 814 reader.Next(); 815 EXPECT_FALSE(reader.Valid()); 816 } 817 { 818 BundleReader reader(Env::Default(), Prefix("foo")); 819 TF_ASSERT_OK(reader.status()); 820 ExpectAlignment<float>(&reader, "foo_000", 42); 821 ExpectAlignment<float>(&reader, "foo_001", 42); 822 ExpectAlignment<float>(&reader, "foo_002", 42); 823 ExpectAlignment<float>(&reader, "foo_003", 42); 824 } 825 } 826 827 static void BM_BundleAlignmentByteOff(int iters, int alignment, 828 int tensor_size) { 829 testing::StopTiming(); 830 { 831 BundleWriter::Options opts; 832 opts.data_alignment = alignment; 833 BundleWriter writer(Env::Default(), Prefix("foo"), opts); 834 TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1})))); 835 TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size})))); 836 TF_CHECK_OK(writer.Finish()); 837 } 838 BundleReader reader(Env::Default(), Prefix("foo")); 839 TF_CHECK_OK(reader.status()); 840 testing::StartTiming(); 841 for (int i = 0; i < iters; ++i) { 842 Tensor t; 843 TF_CHECK_OK(reader.Lookup("big", &t)); 844 } 845 testing::StopTiming(); 846 } 847 848 #define BM_BundleAlignment(ALIGN, SIZE) \ 849 static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \ 850 BM_BundleAlignmentByteOff(iters, ALIGN, SIZE); \ 851 } \ 852 BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE) 853 854 BM_BundleAlignment(1, 512); 855 BM_BundleAlignment(1, 4096); 856 BM_BundleAlignment(1, 1048576); 857 BM_BundleAlignment(4096, 512); 858 BM_BundleAlignment(4096, 4096); 859 BM_BundleAlignment(4096, 1048576); 860 861 } // namespace tensorflow 862