Home | History | Annotate | Download | only in tensor_bundle
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7 http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
     17 
     18 #include <random>
     19 #include <vector>
     20 
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/framework/types.pb.h"
     23 #include "tensorflow/core/framework/variant.h"
     24 #include "tensorflow/core/framework/variant_op_registry.h"
     25 #include "tensorflow/core/framework/versions.pb.h"
     26 #include "tensorflow/core/lib/core/status_test_util.h"
     27 #include "tensorflow/core/lib/io/path.h"
     28 #include "tensorflow/core/lib/io/table_builder.h"
     29 #include "tensorflow/core/lib/strings/strcat.h"
     30 #include "tensorflow/core/platform/test.h"
     31 #include "tensorflow/core/platform/test_benchmark.h"
     32 
     33 namespace tensorflow {
     34 
     35 namespace {
     36 
     37 string Prefix(const string& prefix) {
     38   return strings::StrCat(testing::TmpDir(), "/", prefix);
     39 }
     40 
     41 template <typename T>
     42 Tensor Constant(T v, TensorShape shape) {
     43   Tensor ret(DataTypeToEnum<T>::value, shape);
     44   ret.flat<T>().setConstant(v);
     45   return ret;
     46 }
     47 
     48 template <typename T>
     49 Tensor Constant_2x3(T v) {
     50   return Constant(v, TensorShape({2, 3}));
     51 }
     52 
     53 template <typename T>
     54 void Expect(BundleReader* reader, const string& key,
     55             const Tensor& expected_val) {
     56   // Tests for Contains().
     57   EXPECT_TRUE(reader->Contains(key));
     58   // Tests for LookupDtypeAndShape().
     59   DataType dtype;
     60   TensorShape shape;
     61   TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
     62   EXPECT_EQ(expected_val.dtype(), dtype);
     63   EXPECT_EQ(expected_val.shape(), shape);
     64   // Tests for Lookup(), checking tensor contents.
     65   Tensor val(expected_val.dtype(), shape);
     66   TF_ASSERT_OK(reader->Lookup(key, &val));
     67   test::ExpectTensorEqual<T>(val, expected_val);
     68 }
     69 
     70 template <class T>
     71 void ExpectVariant(BundleReader* reader, const string& key,
     72                    const Tensor& expected_t) {
     73   // Tests for Contains().
     74   EXPECT_TRUE(reader->Contains(key));
     75   // Tests for LookupDtypeAndShape().
     76   DataType dtype;
     77   TensorShape shape;
     78   TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
     79   // Tests for Lookup(), checking tensor contents.
     80   EXPECT_EQ(expected_t.dtype(), dtype);
     81   EXPECT_EQ(expected_t.shape(), shape);
     82   Tensor actual_t(dtype, shape);
     83   TF_ASSERT_OK(reader->Lookup(key, &actual_t));
     84   for (int i = 0; i < expected_t.NumElements(); i++) {
     85     Variant actual_var = actual_t.flat<Variant>()(i);
     86     Variant expected_var = expected_t.flat<Variant>()(i);
     87     EXPECT_EQ(actual_var.TypeName(), expected_var.TypeName());
     88     auto* actual_val = actual_var.get<T>();
     89     auto* expected_val = expected_var.get<T>();
     90     EXPECT_EQ(*expected_val, *actual_val);
     91   }
     92 }
     93 
     94 template <typename T>
     95 void ExpectNext(BundleReader* reader, const Tensor& expected_val) {
     96   EXPECT_TRUE(reader->Valid());
     97   reader->Next();
     98   TF_ASSERT_OK(reader->status());
     99   Tensor val;
    100   TF_ASSERT_OK(reader->ReadCurrent(&val));
    101   test::ExpectTensorEqual<T>(val, expected_val);
    102 }
    103 
    104 std::vector<string> AllTensorKeys(BundleReader* reader) {
    105   std::vector<string> ret;
    106   reader->Seek(kHeaderEntryKey);
    107   reader->Next();
    108   for (; reader->Valid(); reader->Next()) {
    109     ret.push_back(reader->key().ToString());
    110   }
    111   return ret;
    112 }
    113 
    114 // Writes out the metadata file of a bundle again, with the endianness marker
    115 // bit flipped.
    116 Status FlipEndiannessBit(const string& prefix) {
    117   Env* env = Env::Default();
    118   const string metadata_tmp_path = Prefix("some_tmp_path");
    119   std::unique_ptr<WritableFile> file;
    120   TF_RETURN_IF_ERROR(env->NewWritableFile(metadata_tmp_path, &file));
    121   table::TableBuilder builder(table::Options(), file.get());
    122 
    123   // Reads the existing metadata file, and fills the builder.
    124   {
    125     const string filename = MetaFilename(prefix);
    126     uint64 file_size;
    127     TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
    128     std::unique_ptr<RandomAccessFile> file;
    129     TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
    130 
    131     table::Table* table = nullptr;
    132     TF_RETURN_IF_ERROR(
    133         table::Table::Open(table::Options(), file.get(), file_size, &table));
    134     std::unique_ptr<table::Table> table_deleter(table);
    135     std::unique_ptr<table::Iterator> iter(table->NewIterator());
    136 
    137     // Reads the header entry.
    138     iter->Seek(kHeaderEntryKey);
    139     CHECK(iter->Valid());
    140     BundleHeaderProto header;
    141     CHECK(header.ParseFromArray(iter->value().data(), iter->value().size()));
    142     // Flips the endianness.
    143     if (header.endianness() == BundleHeaderProto::LITTLE) {
    144       header.set_endianness(BundleHeaderProto::BIG);
    145     } else {
    146       header.set_endianness(BundleHeaderProto::LITTLE);
    147     }
    148     builder.Add(iter->key(), header.SerializeAsString());
    149     iter->Next();
    150 
    151     // Adds the non-header entries unmodified.
    152     for (; iter->Valid(); iter->Next()) builder.Add(iter->key(), iter->value());
    153   }
    154   TF_RETURN_IF_ERROR(builder.Finish());
    155   TF_RETURN_IF_ERROR(env->RenameFile(metadata_tmp_path, MetaFilename(prefix)));
    156   return file->Close();
    157 }
    158 
    159 template <typename T>
    160 void TestBasic() {
    161   {
    162     BundleWriter writer(Env::Default(), Prefix("foo"));
    163     TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3<T>(3)));
    164     TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3<T>(0)));
    165     TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3<T>(2)));
    166     TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3<T>(1)));
    167     TF_ASSERT_OK(writer.Finish());
    168   }
    169   {
    170     BundleReader reader(Env::Default(), Prefix("foo"));
    171     TF_ASSERT_OK(reader.status());
    172     EXPECT_EQ(
    173         AllTensorKeys(&reader),
    174         std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
    175     Expect<T>(&reader, "foo_000", Constant_2x3<T>(0));
    176     Expect<T>(&reader, "foo_001", Constant_2x3<T>(1));
    177     Expect<T>(&reader, "foo_002", Constant_2x3<T>(2));
    178     Expect<T>(&reader, "foo_003", Constant_2x3<T>(3));
    179   }
    180   {
    181     BundleReader reader(Env::Default(), Prefix("foo"));
    182     TF_ASSERT_OK(reader.status());
    183     ExpectNext<T>(&reader, Constant_2x3<T>(0));
    184     ExpectNext<T>(&reader, Constant_2x3<T>(1));
    185     ExpectNext<T>(&reader, Constant_2x3<T>(2));
    186     ExpectNext<T>(&reader, Constant_2x3<T>(3));
    187     EXPECT_TRUE(reader.Valid());
    188     reader.Next();
    189     EXPECT_FALSE(reader.Valid());
    190   }
    191   {
    192     BundleWriter writer(Env::Default(), Prefix("bar"));
    193     TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3<T>(3)));
    194     TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3<T>(0)));
    195     TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3<T>(2)));
    196     TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3<T>(1)));
    197     TF_ASSERT_OK(writer.Finish());
    198   }
    199   {
    200     BundleReader reader(Env::Default(), Prefix("bar"));
    201     TF_ASSERT_OK(reader.status());
    202     EXPECT_EQ(
    203         AllTensorKeys(&reader),
    204         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
    205     Expect<T>(&reader, "bar_003", Constant_2x3<T>(3));
    206     Expect<T>(&reader, "bar_002", Constant_2x3<T>(2));
    207     Expect<T>(&reader, "bar_001", Constant_2x3<T>(1));
    208     Expect<T>(&reader, "bar_000", Constant_2x3<T>(0));
    209   }
    210   {
    211     BundleReader reader(Env::Default(), Prefix("bar"));
    212     TF_ASSERT_OK(reader.status());
    213     ExpectNext<T>(&reader, Constant_2x3<T>(0));
    214     ExpectNext<T>(&reader, Constant_2x3<T>(1));
    215     ExpectNext<T>(&reader, Constant_2x3<T>(2));
    216     ExpectNext<T>(&reader, Constant_2x3<T>(3));
    217     EXPECT_TRUE(reader.Valid());
    218     reader.Next();
    219     EXPECT_FALSE(reader.Valid());
    220   }
    221   TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
    222                             Prefix("merged")));
    223   {
    224     BundleReader reader(Env::Default(), Prefix("merged"));
    225     TF_ASSERT_OK(reader.status());
    226     EXPECT_EQ(
    227         AllTensorKeys(&reader),
    228         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
    229                              "foo_000", "foo_001", "foo_002", "foo_003"}));
    230     Expect<T>(&reader, "bar_000", Constant_2x3<T>(0));
    231     Expect<T>(&reader, "bar_001", Constant_2x3<T>(1));
    232     Expect<T>(&reader, "bar_002", Constant_2x3<T>(2));
    233     Expect<T>(&reader, "bar_003", Constant_2x3<T>(3));
    234     Expect<T>(&reader, "foo_000", Constant_2x3<T>(0));
    235     Expect<T>(&reader, "foo_001", Constant_2x3<T>(1));
    236     Expect<T>(&reader, "foo_002", Constant_2x3<T>(2));
    237     Expect<T>(&reader, "foo_003", Constant_2x3<T>(3));
    238   }
    239   {
    240     BundleReader reader(Env::Default(), Prefix("merged"));
    241     TF_ASSERT_OK(reader.status());
    242     ExpectNext<T>(&reader, Constant_2x3<T>(0));
    243     ExpectNext<T>(&reader, Constant_2x3<T>(1));
    244     ExpectNext<T>(&reader, Constant_2x3<T>(2));
    245     ExpectNext<T>(&reader, Constant_2x3<T>(3));
    246     ExpectNext<T>(&reader, Constant_2x3<T>(0));
    247     ExpectNext<T>(&reader, Constant_2x3<T>(1));
    248     ExpectNext<T>(&reader, Constant_2x3<T>(2));
    249     ExpectNext<T>(&reader, Constant_2x3<T>(3));
    250     EXPECT_TRUE(reader.Valid());
    251     reader.Next();
    252     EXPECT_FALSE(reader.Valid());
    253   }
    254 }
    255 
    256 template <typename T>
    257 void TestNonStandardShapes() {
    258   {
    259     BundleWriter writer(Env::Default(), Prefix("nonstandard"));
    260     TF_EXPECT_OK(writer.Add("scalar", Constant<T>(0, TensorShape())));
    261     TF_EXPECT_OK(
    262         writer.Add("non_standard0", Constant<T>(0, TensorShape({0, 1618}))));
    263     TF_EXPECT_OK(
    264         writer.Add("non_standard1", Constant<T>(0, TensorShape({16, 0, 18}))));
    265     TF_ASSERT_OK(writer.Finish());
    266   }
    267   {
    268     BundleReader reader(Env::Default(), Prefix("nonstandard"));
    269     TF_ASSERT_OK(reader.status());
    270     Expect<T>(&reader, "scalar", Constant<T>(0, TensorShape()));
    271     Expect<T>(&reader, "non_standard0", Constant<T>(0, TensorShape({0, 1618})));
    272     Expect<T>(&reader, "non_standard1",
    273               Constant<T>(0, TensorShape({16, 0, 18})));
    274   }
    275 }
    276 
    277 // Writes a bundle to disk with a bad "version"; checks for "expected_error".
    278 void VersionTest(const VersionDef& version, StringPiece expected_error) {
    279   const string path = Prefix("version_test");
    280   {
    281     // Prepare an empty bundle with the given version information.
    282     BundleHeaderProto header;
    283     *header.mutable_version() = version;
    284 
    285     // Write the metadata file to disk.
    286     std::unique_ptr<WritableFile> file;
    287     TF_ASSERT_OK(Env::Default()->NewWritableFile(MetaFilename(path), &file));
    288     table::TableBuilder builder(table::Options(), file.get());
    289     builder.Add(kHeaderEntryKey, header.SerializeAsString());
    290     TF_ASSERT_OK(builder.Finish());
    291   }
    292   // Read it back in and verify that we get the expected error.
    293   BundleReader reader(Env::Default(), path);
    294   EXPECT_TRUE(errors::IsInvalidArgument(reader.status()));
    295   EXPECT_TRUE(
    296       StringPiece(reader.status().error_message()).starts_with(expected_error));
    297 }
    298 
    299 }  // namespace
    300 
    301 TEST(TensorBundleTest, Basic) {
    302   TestBasic<float>();
    303   TestBasic<double>();
    304   TestBasic<int32>();
    305   TestBasic<uint8>();
    306   TestBasic<int16>();
    307   TestBasic<int8>();
    308   TestBasic<complex64>();
    309   TestBasic<complex128>();
    310   TestBasic<int64>();
    311   TestBasic<bool>();
    312   TestBasic<qint32>();
    313   TestBasic<quint8>();
    314   TestBasic<qint8>();
    315 }
    316 
    317 TEST(TensorBundleTest, PartitionedVariables) {
    318   const TensorShape kFullShape({5, 10});
    319   // Adds two slices.
    320   // First slice: column 0, all zeros.
    321   // Second slice: column 1 to rest, all ones.
    322   TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1");
    323   TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9");
    324   {
    325     BundleWriter writer(Env::Default(), Prefix("foo"));
    326 
    327     TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1,
    328                                  Constant<float>(0., TensorShape({5, 1}))));
    329     TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2,
    330                                  Constant<float>(1., TensorShape({5, 9}))));
    331     TF_ASSERT_OK(writer.Finish());
    332   }
    333   // Reads in full.
    334   {
    335     BundleReader reader(Env::Default(), Prefix("foo"));
    336     TF_ASSERT_OK(reader.status());
    337 
    338     Tensor expected_val(DT_FLOAT, kFullShape);
    339     test::FillFn<float>(&expected_val, [](int offset) -> float {
    340       if (offset % 10 == 0) {
    341         return 0;  // First column zeros.
    342       }
    343       return 1;  // Other columns ones.
    344     });
    345 
    346     Tensor val(DT_FLOAT, kFullShape);
    347     TF_ASSERT_OK(reader.Lookup("foo", &val));
    348     test::ExpectTensorEqual<float>(val, expected_val);
    349   }
    350   // Reads all slices.
    351   {
    352     BundleReader reader(Env::Default(), Prefix("foo"));
    353     TF_ASSERT_OK(reader.status());
    354 
    355     std::vector<TensorSlice> slices;
    356     TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices));
    357 
    358     EXPECT_EQ(2, slices.size());
    359     EXPECT_EQ(slice1.DebugString(), slices[0].DebugString());
    360     EXPECT_EQ(slice2.DebugString(), slices[1].DebugString());
    361   }
    362   // Reads a slice consisting of first two columns, "cutting" both slices.
    363   {
    364     BundleReader reader(Env::Default(), Prefix("foo"));
    365     TF_ASSERT_OK(reader.status());
    366 
    367     // First two columns, "cutting" both slices.
    368     const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:0,2");
    369     Tensor expected_val(DT_FLOAT, TensorShape({5, 2}));
    370     test::FillFn<float>(&expected_val, [](int offset) -> float {
    371       if (offset % 2 == 0) {
    372         return 0;  // First column zeros.
    373       }
    374       return 1;  // Other columns ones.
    375     });
    376 
    377     Tensor val(DT_FLOAT, TensorShape({5, 2}));
    378     TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
    379     test::ExpectTensorEqual<float>(val, expected_val);
    380   }
    381   // Reads a slice consisting of columns 2-4, "cutting" the second slice only.
    382   {
    383     BundleReader reader(Env::Default(), Prefix("foo"));
    384     TF_ASSERT_OK(reader.status());
    385 
    386     const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:2,2");
    387     Tensor val(DT_FLOAT, TensorShape({5, 2}));
    388     TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
    389     test::ExpectTensorEqual<float>(val,
    390                                    Constant<float>(1., TensorShape({5, 2})));
    391   }
    392 }
    393 
    394 TEST(TensorBundleTest, EquivalentSliceTest) {
    395   const TensorShape kFullShape({5, 10});
    396   const Tensor kExpected(Constant<float>(1., kFullShape));
    397   {
    398     BundleWriter writer(Env::Default(), Prefix("foo"));
    399     TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape,
    400                                  TensorSlice::ParseOrDie("-:-"), kExpected));
    401     TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape,
    402                                  TensorSlice::ParseOrDie("0,5:0,10"),
    403                                  kExpected));
    404     TF_ASSERT_OK(writer.Finish());
    405   }
    406   // Slices match exactly and are fully abbreviated.
    407   {
    408     BundleReader reader(Env::Default(), Prefix("foo"));
    409     TF_ASSERT_OK(reader.status());
    410     const TensorSlice slice = TensorSlice::ParseOrDie("-:-");
    411     Tensor val(DT_FLOAT, TensorShape(kFullShape));
    412     TF_ASSERT_OK(reader.LookupSlice("no_extents", slice, &val));
    413     test::ExpectTensorEqual<float>(val, kExpected);
    414   }
    415   // Slice match exactly and are fully specified.
    416   {
    417     BundleReader reader(Env::Default(), Prefix("foo"));
    418     TF_ASSERT_OK(reader.status());
    419     const TensorSlice slice = TensorSlice::ParseOrDie("0,5:0,10");
    420     Tensor val(DT_FLOAT, TensorShape(kFullShape));
    421     TF_ASSERT_OK(reader.LookupSlice("both_extents", slice, &val));
    422     test::ExpectTensorEqual<float>(val, kExpected);
    423   }
    424   // Stored slice has no extents, spec has extents.
    425   {
    426     BundleReader reader(Env::Default(), Prefix("foo"));
    427     TF_ASSERT_OK(reader.status());
    428     const TensorSlice slice = TensorSlice::ParseOrDie("0,5:0,10");
    429     Tensor val(DT_FLOAT, TensorShape(kFullShape));
    430     TF_ASSERT_OK(reader.LookupSlice("no_extents", slice, &val));
    431     test::ExpectTensorEqual<float>(val, kExpected);
    432   }
    433   // Stored slice has both extents, spec has no extents.
    434   {
    435     BundleReader reader(Env::Default(), Prefix("foo"));
    436     TF_ASSERT_OK(reader.status());
    437     const TensorSlice slice = TensorSlice::ParseOrDie("-:-");
    438     Tensor val(DT_FLOAT, TensorShape(kFullShape));
    439     TF_ASSERT_OK(reader.LookupSlice("both_extents", slice, &val));
    440     test::ExpectTensorEqual<float>(val, kExpected);
    441   }
    442 }
    443 
    444 TEST(TensorBundleTest, NonStandardShapes) {
    445   TestNonStandardShapes<float>();
    446   TestNonStandardShapes<double>();
    447   TestNonStandardShapes<int32>();
    448   TestNonStandardShapes<uint8>();
    449   TestNonStandardShapes<int16>();
    450   TestNonStandardShapes<int8>();
    451   TestNonStandardShapes<complex64>();
    452   TestNonStandardShapes<complex128>();
    453   TestNonStandardShapes<int64>();
    454   TestNonStandardShapes<bool>();
    455   TestNonStandardShapes<qint32>();
    456   TestNonStandardShapes<quint8>();
    457   TestNonStandardShapes<qint8>();
    458 }
    459 
    460 TEST(TensorBundleTest, StringTensors) {
    461   {
    462     BundleWriter writer(Env::Default(), Prefix("foo"));
    463     TF_EXPECT_OK(writer.Add("string_tensor",
    464                             Tensor(DT_STRING, TensorShape({1}))));  // Empty.
    465     TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<string>({"hello"})));
    466     TF_EXPECT_OK(writer.Add(
    467         "strs",
    468         test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})));
    469     // Mixes in some floats.
    470     TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
    471     TF_ASSERT_OK(writer.Finish());
    472   }
    473   {
    474     BundleReader reader(Env::Default(), Prefix("foo"));
    475     TF_ASSERT_OK(reader.status());
    476     EXPECT_EQ(
    477         AllTensorKeys(&reader),
    478         std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
    479 
    480     Expect<string>(&reader, "string_tensor",
    481                    Tensor(DT_STRING, TensorShape({1})));
    482     Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
    483     Expect<string>(
    484         &reader, "strs",
    485         test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}));
    486     Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
    487   }
    488 }
    489 
    490 class VariantObject {
    491  public:
    492   VariantObject() {}
    493   VariantObject(const string& metadata, int64 value)
    494       : metadata_(metadata), value_(value) {}
    495 
    496   string TypeName() const { return "TEST VariantObject"; }
    497   void Encode(VariantTensorData* data) const {
    498     data->set_type_name(TypeName());
    499     data->set_metadata(metadata_);
    500     Tensor val_t = Tensor(DT_INT64, TensorShape({}));
    501     val_t.scalar<int64>()() = value_;
    502     *(data->add_tensors()) = val_t;
    503   }
    504   bool Decode(const VariantTensorData& data) {
    505     EXPECT_EQ(data.type_name(), TypeName());
    506     data.get_metadata(&metadata_);
    507     EXPECT_EQ(data.tensors_size(), 1);
    508     value_ = data.tensors(0).scalar<int64>()();
    509     return true;
    510   }
    511   bool operator==(const VariantObject other) const {
    512     return metadata_ == other.metadata_ && value_ == other.value_;
    513   }
    514   string metadata_;
    515   int64 value_;
    516 };
    517 
    518 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantObject, "TEST VariantObject");
    519 
    520 TEST(TensorBundleTest, VariantTensors) {
    521   {
    522     BundleWriter writer(Env::Default(), Prefix("foo"));
    523     TF_EXPECT_OK(
    524         writer.Add("variant_tensor",
    525                    test::AsTensor<Variant>({VariantObject("test", 10),
    526                                             VariantObject("test1", 20)})));
    527     TF_ASSERT_OK(writer.Finish());
    528   }
    529   {
    530     BundleReader reader(Env::Default(), Prefix("foo"));
    531     TF_ASSERT_OK(reader.status());
    532     ExpectVariant<VariantObject>(
    533         &reader, "variant_tensor",
    534         test::AsTensor<Variant>(
    535             {VariantObject("test", 10), VariantObject("test1", 20)}));
    536   }
    537 }
    538 
    539 TEST(TensorBundleTest, DirectoryStructure) {
    540   Env* env = Env::Default();
    541   // Writes two bundles.
    542   const std::vector<string> kBundlePrefixes = {Prefix("worker0"),
    543                                                Prefix("worker1")};
    544   for (int i = 0; i < 2; ++i) {
    545     BundleWriter writer(env, kBundlePrefixes[i]);
    546     TF_EXPECT_OK(
    547         writer.Add(strings::StrCat("tensor", i), Constant_2x3<float>(0.)));
    548     TF_ASSERT_OK(writer.Finish());
    549   }
    550 
    551   // Ensures we have the expected files.
    552   auto CheckDirFiles = [env](const string& bundle_prefix,
    553                              gtl::ArraySlice<string> expected_files) {
    554     StringPiece dir = io::Dirname(bundle_prefix);
    555     for (const string& expected_file : expected_files) {
    556       TF_EXPECT_OK(env->FileExists(io::JoinPath(dir, expected_file)));
    557     }
    558   };
    559 
    560   // Check we have:
    561   //   worker<i>.index
    562   //   worker<i>.data-00000-of-00001
    563   CheckDirFiles(kBundlePrefixes[0],
    564                 {"worker0.index", "worker0.data-00000-of-00001"});
    565   CheckDirFiles(kBundlePrefixes[1],
    566                 {"worker1.index", "worker1.data-00000-of-00001"});
    567 
    568   // Trivially "merge" one bundle to some other location (i.e., a renaming).
    569   const string kAnotherPrefix = Prefix("another");
    570   TF_ASSERT_OK(MergeBundles(env, {kBundlePrefixes[0]}, kAnotherPrefix));
    571   CheckDirFiles(kAnotherPrefix,
    572                 {"another.index", "another.data-00000-of-00001"});
    573 
    574   // Performs actual merge of the two bundles.  Check we have:
    575   //   merged.index
    576   //   merged.data-00000-of-00002
    577   //   merged.data-00001-of-00002
    578   const string kMerged = Prefix("merged");
    579   TF_ASSERT_OK(
    580       MergeBundles(env, {kAnotherPrefix, kBundlePrefixes[1]}, kMerged));
    581   CheckDirFiles(kMerged, {"merged.index", "merged.data-00000-of-00002",
    582                           "merged.data-00001-of-00002"});
    583 }
    584 
    585 TEST(TensorBundleTest, Error) {
    586   {  // Dup keys.
    587     BundleWriter writer(Env::Default(), Prefix("dup"));
    588     TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
    589     EXPECT_FALSE(writer.Add("foo", Constant_2x3(2.f)).ok());
    590     EXPECT_TRUE(
    591         StringPiece(writer.status().ToString()).contains("duplicate key"));
    592     EXPECT_FALSE(writer.Finish().ok());
    593   }
    594   {  // Double finish
    595     BundleWriter writer(Env::Default(), Prefix("bad"));
    596     EXPECT_TRUE(writer.Finish().ok());
    597     EXPECT_FALSE(writer.Finish().ok());
    598   }
    599   {  // Not found.
    600     BundleReader reader(Env::Default(), Prefix("nonexist"));
    601     EXPECT_TRUE(StringPiece(reader.status().ToString()).contains("Not found"));
    602   }
    603 }
    604 
    605 TEST(TensorBundleTest, Checksum) {
    606   // Randomly flips a byte in [pos_lhs, end of data file), or exactly byte
    607   // pos_lhs if exact_pos == True.
    608   auto FlipByte = [](const string& prefix, int pos_lhs,
    609                      bool exact_pos = false) {
    610     DCHECK_GE(pos_lhs, 0);
    611     const string& datafile = DataFilename(Prefix(prefix), 0, 1);
    612     string data;
    613     TF_ASSERT_OK(ReadFileToString(Env::Default(), datafile, &data));
    614 
    615     int byte_pos = 0;
    616     if (!exact_pos) {
    617       std::mt19937 rng;
    618       std::uniform_int_distribution<int> dist(pos_lhs, data.size() - 1);
    619       byte_pos = dist(rng);
    620     } else {
    621       byte_pos = pos_lhs;
    622     }
    623     data[byte_pos] = ~data[byte_pos];
    624     TF_ASSERT_OK(WriteStringToFile(Env::Default(), datafile, data));
    625   };
    626   // The lookup should fail with a checksum-related message.
    627   auto ExpectLookupFails = [](const string& prefix, const string& key,
    628                               const string& expected_msg, Tensor& val) {
    629     BundleReader reader(Env::Default(), Prefix(prefix));
    630     Status status = reader.Lookup(key, &val);
    631     EXPECT_TRUE(errors::IsDataLoss(status));
    632     EXPECT_TRUE(StringPiece(status.ToString()).contains(expected_msg));
    633   };
    634 
    635   // Corrupts a float tensor.
    636   {
    637     BundleWriter writer(Env::Default(), Prefix("singleton"));
    638     TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
    639     TF_ASSERT_OK(writer.Finish());
    640 
    641     FlipByte("singleton", 0 /* corrupts any byte */);
    642     Tensor val(DT_FLOAT, TensorShape({2, 3}));
    643     ExpectLookupFails("singleton", "foo",
    644                       "Checksum does not match" /* expected fail msg */, val);
    645   }
    646   // Corrupts a string tensor.
    647   {
    648     auto WriteStrings = []() {
    649       BundleWriter writer(Env::Default(), Prefix("strings"));
    650       TF_EXPECT_OK(
    651           writer.Add("foo", test::AsTensor<string>({"hello", "world"})));
    652       TF_ASSERT_OK(writer.Finish());
    653     };
    654     // Corrupts the first two bytes, which are the varint32-encoded lengths
    655     // of the two string elements.  Should hit mismatch on length cksum.
    656     for (int i = 0; i < 2; ++i) {
    657       WriteStrings();
    658       FlipByte("strings", i, true /* corrupts exactly byte i */);
    659       Tensor val(DT_STRING, TensorShape({2}));
    660       ExpectLookupFails(
    661           "strings", "foo",
    662           "length checksum does not match" /* expected fail msg */, val);
    663     }
    664     // Corrupts the string bytes, should hit an overall cksum mismatch.
    665     WriteStrings();
    666     FlipByte("strings", 2 /* corrupts starting from byte 2 */);
    667     Tensor val(DT_STRING, TensorShape({2}));
    668     ExpectLookupFails("strings", "foo",
    669                       "Checksum does not match" /* expected fail msg */, val);
    670   }
    671 }
    672 
    673 TEST(TensorBundleTest, Endianness) {
    674   BundleWriter writer(Env::Default(), Prefix("end"));
    675   TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
    676   TF_ASSERT_OK(writer.Finish());
    677 
    678   // Flips the endianness bit.
    679   TF_ASSERT_OK(FlipEndiannessBit(Prefix("end")));
    680 
    681   BundleReader reader(Env::Default(), Prefix("end"));
    682   EXPECT_TRUE(errors::IsUnimplemented(reader.status()));
    683   EXPECT_TRUE(StringPiece(reader.status().ToString())
    684                   .contains("different endianness from the reader"));
    685 }
    686 
    687 TEST(TensorBundleTest, TruncatedTensorContents) {
    688   Env* env = Env::Default();
    689   BundleWriter writer(env, Prefix("end"));
    690   TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
    691   TF_ASSERT_OK(writer.Finish());
    692 
    693   // Truncates the data file by one byte, so that we hit EOF.
    694   const string datafile = DataFilename(Prefix("end"), 0, 1);
    695   string data;
    696   TF_ASSERT_OK(ReadFileToString(env, datafile, &data));
    697   ASSERT_TRUE(!data.empty());
    698   TF_ASSERT_OK(WriteStringToFile(env, datafile,
    699                                  StringPiece(data.data(), data.size() - 1)));
    700 
    701   BundleReader reader(env, Prefix("end"));
    702   TF_ASSERT_OK(reader.status());
    703   Tensor val(DT_FLOAT, TensorShape({2, 3}));
    704   EXPECT_TRUE(errors::IsOutOfRange(reader.Lookup("key", &val)));
    705 }
    706 
    707 TEST(TensorBundleTest, HeaderEntry) {
    708   {
    709     BundleWriter writer(Env::Default(), Prefix("b"));
    710     TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
    711     TF_ASSERT_OK(writer.Finish());
    712   }
    713 
    714   // Extracts out the header.
    715   BundleHeaderProto header;
    716   {
    717     BundleReader reader(Env::Default(), Prefix("b"));
    718     TF_ASSERT_OK(reader.status());
    719     reader.Seek(kHeaderEntryKey);
    720     ASSERT_TRUE(reader.Valid());
    721     ASSERT_TRUE(ParseProtoUnlimited(&header, reader.value().data(),
    722                                     reader.value().size()));
    723   }
    724 
    725   // num_shards
    726   EXPECT_EQ(1, header.num_shards());
    727   // endianness
    728   if (port::kLittleEndian) {
    729     EXPECT_EQ(BundleHeaderProto::LITTLE, header.endianness());
    730   } else {
    731     EXPECT_EQ(BundleHeaderProto::BIG, header.endianness());
    732   }
    733   // version
    734   EXPECT_GT(kTensorBundleVersion, 0);
    735   EXPECT_EQ(kTensorBundleVersion, header.version().producer());
    736   EXPECT_EQ(kTensorBundleMinConsumer, header.version().min_consumer());
    737 }
    738 
    739 TEST(TensorBundleTest, VersionTest) {
    740   // Min consumer.
    741   {
    742     VersionDef versions;
    743     versions.set_producer(kTensorBundleVersion + 1);
    744     versions.set_min_consumer(kTensorBundleVersion + 1);
    745     VersionTest(
    746         versions,
    747         strings::StrCat("Checkpoint min consumer version ",
    748                         kTensorBundleVersion + 1, " above current version ",
    749                         kTensorBundleVersion, " for TensorFlow"));
    750   }
    751   // Min producer.
    752   {
    753     VersionDef versions;
    754     versions.set_producer(kTensorBundleMinProducer - 1);
    755     VersionTest(
    756         versions,
    757         strings::StrCat("Checkpoint producer version ",
    758                         kTensorBundleMinProducer - 1, " below min producer ",
    759                         kTensorBundleMinProducer, " supported by TensorFlow"));
    760   }
    761   // Bad consumer.
    762   {
    763     VersionDef versions;
    764     versions.set_producer(kTensorBundleVersion + 1);
    765     versions.add_bad_consumers(kTensorBundleVersion);
    766     VersionTest(
    767         versions,
    768         strings::StrCat(
    769             "Checkpoint disallows consumer version ", kTensorBundleVersion,
    770             ".  Please upgrade TensorFlow: this version is likely buggy."));
    771   }
    772 }
    773 
    774 class TensorBundleAlignmentTest : public ::testing::Test {
    775  protected:
    776   template <typename T>
    777   void ExpectAlignment(BundleReader* reader, const string& key, int alignment) {
    778     BundleEntryProto full_tensor_entry;
    779     TF_ASSERT_OK(reader->GetBundleEntryProto(key, &full_tensor_entry));
    780     EXPECT_EQ(0, full_tensor_entry.offset() % alignment);
    781   }
    782 };
    783 
    784 TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
    785   {
    786     BundleWriter::Options opts;
    787     opts.data_alignment = 42;
    788     BundleWriter writer(Env::Default(), Prefix("foo"), opts);
    789     TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3<float>(3)));
    790     TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3<float>(0)));
    791     TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3<float>(2)));
    792     TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3<float>(1)));
    793     TF_ASSERT_OK(writer.Finish());
    794   }
    795   {
    796     BundleReader reader(Env::Default(), Prefix("foo"));
    797     TF_ASSERT_OK(reader.status());
    798     EXPECT_EQ(
    799         AllTensorKeys(&reader),
    800         std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
    801     Expect<float>(&reader, "foo_000", Constant_2x3<float>(0));
    802     Expect<float>(&reader, "foo_001", Constant_2x3<float>(1));
    803     Expect<float>(&reader, "foo_002", Constant_2x3<float>(2));
    804     Expect<float>(&reader, "foo_003", Constant_2x3<float>(3));
    805   }
    806   {
    807     BundleReader reader(Env::Default(), Prefix("foo"));
    808     TF_ASSERT_OK(reader.status());
    809     ExpectNext<float>(&reader, Constant_2x3<float>(0));
    810     ExpectNext<float>(&reader, Constant_2x3<float>(1));
    811     ExpectNext<float>(&reader, Constant_2x3<float>(2));
    812     ExpectNext<float>(&reader, Constant_2x3<float>(3));
    813     EXPECT_TRUE(reader.Valid());
    814     reader.Next();
    815     EXPECT_FALSE(reader.Valid());
    816   }
    817   {
    818     BundleReader reader(Env::Default(), Prefix("foo"));
    819     TF_ASSERT_OK(reader.status());
    820     ExpectAlignment<float>(&reader, "foo_000", 42);
    821     ExpectAlignment<float>(&reader, "foo_001", 42);
    822     ExpectAlignment<float>(&reader, "foo_002", 42);
    823     ExpectAlignment<float>(&reader, "foo_003", 42);
    824   }
    825 }
    826 
    827 static void BM_BundleAlignmentByteOff(int iters, int alignment,
    828                                       int tensor_size) {
    829   testing::StopTiming();
    830   {
    831     BundleWriter::Options opts;
    832     opts.data_alignment = alignment;
    833     BundleWriter writer(Env::Default(), Prefix("foo"), opts);
    834     TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1}))));
    835     TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size}))));
    836     TF_CHECK_OK(writer.Finish());
    837   }
    838   BundleReader reader(Env::Default(), Prefix("foo"));
    839   TF_CHECK_OK(reader.status());
    840   testing::StartTiming();
    841   for (int i = 0; i < iters; ++i) {
    842     Tensor t;
    843     TF_CHECK_OK(reader.Lookup("big", &t));
    844   }
    845   testing::StopTiming();
    846 }
    847 
    848 #define BM_BundleAlignment(ALIGN, SIZE)                        \
    849   static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \
    850     BM_BundleAlignmentByteOff(iters, ALIGN, SIZE);             \
    851   }                                                            \
    852   BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
    853 
    854 BM_BundleAlignment(1, 512);
    855 BM_BundleAlignment(1, 4096);
    856 BM_BundleAlignment(1, 1048576);
    857 BM_BundleAlignment(4096, 512);
    858 BM_BundleAlignment(4096, 4096);
    859 BM_BundleAlignment(4096, 1048576);
    860 
    861 }  // namespace tensorflow
    862