Home | History | Annotate | Download | only in io
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/lib/core/coding.h"
     17 #include "tensorflow/core/lib/core/errors.h"
     18 #include "tensorflow/core/lib/core/status_test_util.h"
     19 #include "tensorflow/core/lib/hash/crc32c.h"
     20 #include "tensorflow/core/lib/io/record_reader.h"
     21 #include "tensorflow/core/lib/io/record_writer.h"
     22 #include "tensorflow/core/lib/random/simple_philox.h"
     23 #include "tensorflow/core/platform/env.h"
     24 #include "tensorflow/core/platform/test.h"
     25 
     26 namespace tensorflow {
     27 namespace io {
     28 
     29 // Construct a string of the specified length made out of the supplied
     30 // partial string.
     31 static string BigString(const string& partial_string, size_t n) {
     32   string result;
     33   while (result.size() < n) {
     34     result.append(partial_string);
     35   }
     36   result.resize(n);
     37   return result;
     38 }
     39 
     40 // Construct a string from a number
     41 static string NumberString(int n) {
     42   char buf[50];
     43   snprintf(buf, sizeof(buf), "%d.", n);
     44   return string(buf);
     45 }
     46 
     47 // Return a skewed potentially long string
     48 static string RandomSkewedString(int i, random::SimplePhilox* rnd) {
     49   return BigString(NumberString(i), rnd->Skewed(17));
     50 }
     51 
     52 class RecordioTest : public ::testing::Test {
     53  private:
     54   class StringDest : public WritableFile {
     55    public:
     56     string contents_;
     57 
     58     Status Close() override { return Status::OK(); }
     59     Status Flush() override { return Status::OK(); }
     60     Status Sync() override { return Status::OK(); }
     61     Status Append(const StringPiece& slice) override {
     62       contents_.append(slice.data(), slice.size());
     63       return Status::OK();
     64     }
     65   };
     66 
     67   class StringSource : public RandomAccessFile {
     68    public:
     69     StringPiece contents_;
     70     mutable bool force_error_;
     71     mutable bool returned_partial_;
     72     StringSource() : force_error_(false), returned_partial_(false) {}
     73 
     74     Status Read(uint64 offset, size_t n, StringPiece* result,
     75                 char* scratch) const override {
     76       EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error";
     77 
     78       if (force_error_) {
     79         force_error_ = false;
     80         returned_partial_ = true;
     81         return errors::DataLoss("read error");
     82       }
     83 
     84       if (offset >= contents_.size()) {
     85         return errors::OutOfRange("end of file");
     86       }
     87 
     88       if (contents_.size() < offset + n) {
     89         n = contents_.size() - offset;
     90         returned_partial_ = true;
     91       }
     92       *result = StringPiece(contents_.data() + offset, n);
     93       return Status::OK();
     94     }
     95   };
     96 
     97   StringDest dest_;
     98   StringSource source_;
     99   bool reading_;
    100   uint64 readpos_;
    101   RecordWriter* writer_;
    102   RecordReader* reader_;
    103 
    104  public:
    105   RecordioTest()
    106       : reading_(false),
    107         readpos_(0),
    108         writer_(new RecordWriter(&dest_)),
    109         reader_(new RecordReader(&source_)) {}
    110 
    111   ~RecordioTest() override {
    112     delete writer_;
    113     delete reader_;
    114   }
    115 
    116   void Write(const string& msg) {
    117     ASSERT_TRUE(!reading_) << "Write() after starting to read";
    118     TF_ASSERT_OK(writer_->WriteRecord(StringPiece(msg)));
    119   }
    120 
    121   size_t WrittenBytes() const { return dest_.contents_.size(); }
    122 
    123   string Read() {
    124     if (!reading_) {
    125       reading_ = true;
    126       source_.contents_ = StringPiece(dest_.contents_);
    127     }
    128     string record;
    129     Status s = reader_->ReadRecord(&readpos_, &record);
    130     if (s.ok()) {
    131       return record;
    132     } else if (errors::IsOutOfRange(s)) {
    133       return "EOF";
    134     } else {
    135       return s.ToString();
    136     }
    137   }
    138 
    139   void IncrementByte(int offset, int delta) {
    140     dest_.contents_[offset] += delta;
    141   }
    142 
    143   void SetByte(int offset, char new_byte) {
    144     dest_.contents_[offset] = new_byte;
    145   }
    146 
    147   void ShrinkSize(int bytes) {
    148     dest_.contents_.resize(dest_.contents_.size() - bytes);
    149   }
    150 
    151   void FixChecksum(int header_offset, int len) {
    152     // Compute crc of type/len/data
    153     uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len);
    154     crc = crc32c::Mask(crc);
    155     core::EncodeFixed32(&dest_.contents_[header_offset], crc);
    156   }
    157 
    158   void ForceError() { source_.force_error_ = true; }
    159 
    160   void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; }
    161 
    162   void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) {
    163     Write("foo");
    164     Write("bar");
    165     Write(BigString("x", 10000));
    166     reading_ = true;
    167     source_.contents_ = StringPiece(dest_.contents_);
    168     uint64 offset = WrittenBytes() + offset_past_end;
    169     string record;
    170     Status s = reader_->ReadRecord(&offset, &record);
    171     ASSERT_TRUE(errors::IsOutOfRange(s)) << s;
    172   }
    173 };
    174 
    175 TEST_F(RecordioTest, Empty) { ASSERT_EQ("EOF", Read()); }
    176 
    177 TEST_F(RecordioTest, ReadWrite) {
    178   Write("foo");
    179   Write("bar");
    180   Write("");
    181   Write("xxxx");
    182   ASSERT_EQ("foo", Read());
    183   ASSERT_EQ("bar", Read());
    184   ASSERT_EQ("", Read());
    185   ASSERT_EQ("xxxx", Read());
    186   ASSERT_EQ("EOF", Read());
    187   ASSERT_EQ("EOF", Read());  // Make sure reads at eof work
    188 }
    189 
    190 TEST_F(RecordioTest, ManyRecords) {
    191   for (int i = 0; i < 100000; i++) {
    192     Write(NumberString(i));
    193   }
    194   for (int i = 0; i < 100000; i++) {
    195     ASSERT_EQ(NumberString(i), Read());
    196   }
    197   ASSERT_EQ("EOF", Read());
    198 }
    199 
    200 TEST_F(RecordioTest, RandomRead) {
    201   const int N = 500;
    202   {
    203     random::PhiloxRandom philox(301, 17);
    204     random::SimplePhilox rnd(&philox);
    205     for (int i = 0; i < N; i++) {
    206       Write(RandomSkewedString(i, &rnd));
    207     }
    208   }
    209   {
    210     random::PhiloxRandom philox(301, 17);
    211     random::SimplePhilox rnd(&philox);
    212     for (int i = 0; i < N; i++) {
    213       ASSERT_EQ(RandomSkewedString(i, &rnd), Read());
    214     }
    215   }
    216   ASSERT_EQ("EOF", Read());
    217 }
    218 
// Tests of the RecordReader::ReadRecord error paths follow:
    220 static void AssertHasSubstr(StringPiece s, StringPiece expected) {
    221   EXPECT_TRUE(StringPiece(s).contains(expected))
    222       << s << " does not contain " << expected;
    223 }
    224 
    225 TEST_F(RecordioTest, ReadError) {
    226   Write("foo");
    227   ForceError();
    228   AssertHasSubstr(Read(), "Data loss");
    229 }
    230 
    231 TEST_F(RecordioTest, CorruptLength) {
    232   Write("foo");
    233   IncrementByte(6, 100);
    234   AssertHasSubstr(Read(), "Data loss");
    235 }
    236 
    237 TEST_F(RecordioTest, CorruptLengthCrc) {
    238   Write("foo");
    239   IncrementByte(10, 100);
    240   AssertHasSubstr(Read(), "Data loss");
    241 }
    242 
    243 TEST_F(RecordioTest, CorruptData) {
    244   Write("foo");
    245   IncrementByte(14, 10);
    246   AssertHasSubstr(Read(), "Data loss");
    247 }
    248 
    249 TEST_F(RecordioTest, CorruptDataCrc) {
    250   Write("foo");
    251   IncrementByte(WrittenBytes() - 1, 10);
    252   AssertHasSubstr(Read(), "Data loss");
    253 }
    254 
    255 TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); }
    256 
    257 TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); }
    258 
    259 }  // namespace io
    260 }  // namespace tensorflow
    261