1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/lib/core/coding.h" 17 #include "tensorflow/core/lib/core/errors.h" 18 #include "tensorflow/core/lib/core/status_test_util.h" 19 #include "tensorflow/core/lib/hash/crc32c.h" 20 #include "tensorflow/core/lib/io/record_reader.h" 21 #include "tensorflow/core/lib/io/record_writer.h" 22 #include "tensorflow/core/lib/random/simple_philox.h" 23 #include "tensorflow/core/platform/env.h" 24 #include "tensorflow/core/platform/test.h" 25 26 namespace tensorflow { 27 namespace io { 28 29 // Construct a string of the specified length made out of the supplied 30 // partial string. 31 static string BigString(const string& partial_string, size_t n) { 32 string result; 33 while (result.size() < n) { 34 result.append(partial_string); 35 } 36 result.resize(n); 37 return result; 38 } 39 40 // Construct a string from a number 41 static string NumberString(int n) { 42 char buf[50]; 43 snprintf(buf, sizeof(buf), "%d.", n); 44 return string(buf); 45 } 46 47 // Return a skewed potentially long string 48 static string RandomSkewedString(int i, random::SimplePhilox* rnd) { 49 return BigString(NumberString(i), rnd->Skewed(17)); 50 } 51 52 class RecordioTest : public ::testing::Test { 53 private: 54 class StringDest : public WritableFile { 55 public: 56 string contents_; 57 58 Status Close() override { return Status::OK(); } 59 Status Flush() override { return Status::OK(); } 60 Status Sync() override { return Status::OK(); } 61 Status Append(const StringPiece& slice) override { 62 contents_.append(slice.data(), slice.size()); 63 return Status::OK(); 64 } 65 }; 66 67 class StringSource : public RandomAccessFile { 68 public: 69 StringPiece contents_; 70 mutable bool force_error_; 71 mutable bool returned_partial_; 72 StringSource() : force_error_(false), returned_partial_(false) {} 73 74 Status Read(uint64 offset, size_t n, StringPiece* result, 75 char* scratch) const override { 76 EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error"; 77 78 if (force_error_) { 79 force_error_ = false; 80 returned_partial_ = true; 81 return errors::DataLoss("read error"); 82 } 83 84 if (offset >= contents_.size()) { 85 return errors::OutOfRange("end of file"); 86 } 87 88 if (contents_.size() < offset + n) { 89 n = contents_.size() - offset; 90 returned_partial_ = true; 91 } 92 *result = StringPiece(contents_.data() + offset, n); 93 return Status::OK(); 94 } 95 }; 96 97 StringDest dest_; 98 StringSource source_; 99 bool reading_; 100 uint64 readpos_; 101 RecordWriter* writer_; 102 RecordReader* reader_; 103 104 public: 105 RecordioTest() 106 : reading_(false), 107 readpos_(0), 108 writer_(new RecordWriter(&dest_)), 109 reader_(new RecordReader(&source_)) {} 110 111 ~RecordioTest() override { 112 delete writer_; 113 delete reader_; 114 } 115 116 void Write(const string& msg) { 117 ASSERT_TRUE(!reading_) << "Write() after starting to read"; 118 TF_ASSERT_OK(writer_->WriteRecord(StringPiece(msg))); 119 } 120 121 size_t WrittenBytes() const { return dest_.contents_.size(); } 122 123 string Read() { 124 if (!reading_) { 125 reading_ = true; 126 source_.contents_ = StringPiece(dest_.contents_); 127 } 128 string record; 129 Status s = reader_->ReadRecord(&readpos_, &record); 130 if (s.ok()) { 131 return record; 132 } else if (errors::IsOutOfRange(s)) { 133 return "EOF"; 134 } else { 135 return s.ToString(); 136 } 137 } 138 139 void IncrementByte(int offset, int delta) { 140 dest_.contents_[offset] += delta; 141 } 142 143 void SetByte(int offset, char new_byte) { 144 dest_.contents_[offset] = new_byte; 145 } 146 147 void ShrinkSize(int bytes) { 148 dest_.contents_.resize(dest_.contents_.size() - bytes); 149 } 150 151 void FixChecksum(int header_offset, int len) { 152 // Compute crc of type/len/data 153 uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len); 154 crc = crc32c::Mask(crc); 155 core::EncodeFixed32(&dest_.contents_[header_offset], crc); 156 } 157 158 void ForceError() { source_.force_error_ = true; } 159 160 void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; } 161 162 void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) { 163 Write("foo"); 164 Write("bar"); 165 Write(BigString("x", 10000)); 166 reading_ = true; 167 source_.contents_ = StringPiece(dest_.contents_); 168 uint64 offset = WrittenBytes() + offset_past_end; 169 string record; 170 Status s = reader_->ReadRecord(&offset, &record); 171 ASSERT_TRUE(errors::IsOutOfRange(s)) << s; 172 } 173 }; 174 175 TEST_F(RecordioTest, Empty) { ASSERT_EQ("EOF", Read()); } 176 177 TEST_F(RecordioTest, ReadWrite) { 178 Write("foo"); 179 Write("bar"); 180 Write(""); 181 Write("xxxx"); 182 ASSERT_EQ("foo", Read()); 183 ASSERT_EQ("bar", Read()); 184 ASSERT_EQ("", Read()); 185 ASSERT_EQ("xxxx", Read()); 186 ASSERT_EQ("EOF", Read()); 187 ASSERT_EQ("EOF", Read()); // Make sure reads at eof work 188 } 189 190 TEST_F(RecordioTest, ManyRecords) { 191 for (int i = 0; i < 100000; i++) { 192 Write(NumberString(i)); 193 } 194 for (int i = 0; i < 100000; i++) { 195 ASSERT_EQ(NumberString(i), Read()); 196 } 197 ASSERT_EQ("EOF", Read()); 198 } 199 200 TEST_F(RecordioTest, RandomRead) { 201 const int N = 500; 202 { 203 random::PhiloxRandom philox(301, 17); 204 random::SimplePhilox rnd(&philox); 205 for (int i = 0; i < N; i++) { 206 Write(RandomSkewedString(i, &rnd)); 207 } 208 } 209 { 210 random::PhiloxRandom philox(301, 17); 211 random::SimplePhilox rnd(&philox); 212 for (int i = 0; i < N; i++) { 213 ASSERT_EQ(RandomSkewedString(i, &rnd), Read()); 214 } 215 } 216 ASSERT_EQ("EOF", Read()); 217 } 218 219 // Tests of all the error paths in log_reader.cc follow: 220 static void AssertHasSubstr(StringPiece s, StringPiece expected) { 221 EXPECT_TRUE(StringPiece(s).contains(expected)) 222 << s << " does not contain " << expected; 223 } 224 225 TEST_F(RecordioTest, ReadError) { 226 Write("foo"); 227 ForceError(); 228 AssertHasSubstr(Read(), "Data loss"); 229 } 230 231 TEST_F(RecordioTest, CorruptLength) { 232 Write("foo"); 233 IncrementByte(6, 100); 234 AssertHasSubstr(Read(), "Data loss"); 235 } 236 237 TEST_F(RecordioTest, CorruptLengthCrc) { 238 Write("foo"); 239 IncrementByte(10, 100); 240 AssertHasSubstr(Read(), "Data loss"); 241 } 242 243 TEST_F(RecordioTest, CorruptData) { 244 Write("foo"); 245 IncrementByte(14, 10); 246 AssertHasSubstr(Read(), "Data loss"); 247 } 248 249 TEST_F(RecordioTest, CorruptDataCrc) { 250 Write("foo"); 251 IncrementByte(WrittenBytes() - 1, 10); 252 AssertHasSubstr(Read(), "Data loss"); 253 } 254 255 TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); } 256 257 TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); } 258 259 } // namespace io 260 } // namespace tensorflow 261