# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for testing reader datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import zlib

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import parsing_ops
from tensorflow.python.util import compat


class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
  """Base class for setting up and testing FixedLengthRecordDataset."""

  def setUp(self):
    super(FixedLengthRecordDatasetTestBase, self).setUp()
    # Fixture geometry shared by _record() and _createFiles(): each of the
    # `_num_files` files holds a header, `_num_records` fixed-size records,
    # and a footer.
    self._num_files = 2
    self._num_records = 7
    self._header_bytes = 5
    self._record_bytes = 3
    self._footer_bytes = 2

  def _record(self, f, r):
    """Returns the expected bytes of record `r` in file `f`.

    The record is the digit string of `f * 2 + r` repeated until it fills
    exactly `_record_bytes` bytes (the multiplier `f * 2` makes records
    unique across files for these small fixture sizes).
    """
    return compat.as_bytes(str(f * 2 + r) * self._record_bytes)

  def _createFiles(self):
    """Writes the fixture files and returns the list of their paths.

    Each file is laid out as: `_header_bytes` of b"H", then `_num_records`
    records produced by _record(), then `_footer_bytes` of b"F".
    """
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        for j in range(self._num_records):
          f.write(self._record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames


class MakeBatchedFeaturesDatasetTestBase(test_base.DatasetTestBase):
  """Base class for setting up and testing `make_batched_features_dataset`."""

  def setUp(self):
    super(MakeBatchedFeaturesDatasetTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7
    # TFRecord files of serialized tf.Example protos, written eagerly so
    # every test in the subclass can read them.
    self.test_filenames = self._createFiles()

  def make_batch_feature(self,
                         filenames,
                         num_epochs,
                         batch_size,
                         label_key=None,
                         reader_num_threads=1,
                         parser_num_threads=1,
                         shuffle=False,
                         shuffle_seed=None,
                         drop_final_batch=False):
    """Builds a `make_batched_features_dataset` over `filenames`.

    The feature spec matches the Examples written by _record(): scalar int64
    "file" and "record", variable-length string "keywords", and scalar string
    "label". If `label_key` is given it is forwarded so the dataset yields
    (features, label) tuples. The arguments are also stashed on `self` for
    use by the verification helpers.
    """
    self.filenames = filenames
    self.num_epochs = num_epochs
    self.batch_size = batch_size

    return readers.make_batched_features_dataset(
        file_pattern=self.filenames,
        batch_size=self.batch_size,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
            "record": parsing_ops.FixedLenFeature([], dtypes.int64),
            "keywords": parsing_ops.VarLenFeature(dtypes.string),
            "label": parsing_ops.FixedLenFeature([], dtypes.string),
        },
        label_key=label_key,
        reader=core_readers.TFRecordDataset,
        num_epochs=self.num_epochs,
        shuffle=shuffle,
        shuffle_seed=shuffle_seed,
        reader_num_threads=reader_num_threads,
        parser_num_threads=parser_num_threads,
        drop_final_batch=drop_final_batch)

  def _record(self, f, r, l):
    """Returns a serialized tf.Example for record `r` of file `f`.

    The Example carries the file index `f`, the record index `r`, the
    deterministic keyword list from _get_keywords(f, r), and label `l`.
    """
    example = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                "file":
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[f])),
                "record":
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[r])),
                "keywords":
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=self._get_keywords(f, r))),
                "label":
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=[compat.as_bytes(l)]))
            }))
    return example.SerializeToString()

  def _get_keywords(self, f, r):
    """Returns the keyword bytes for record (f, r).

    Yields one or two entries (b"keyword0"[, b"keyword1"]) depending on the
    parity of f + r, so the VarLenFeature has varying lengths across records.
    """
    num_keywords = 1 + (f + r) % 2
    keywords = []
    for index in range(num_keywords):
      keywords.append(compat.as_bytes("keyword%d" % index))
    return keywords

  def _sum_keywords(self, num_files):
    """Returns the total keyword count over the first `num_files` files.

    Mirrors the 1 + (f + r) % 2 count used by _get_keywords().
    """
    sum_keywords = 0
    for i in range(num_files):
      for j in range(self._num_records):
        sum_keywords += 1 + (i + j) % 2
    return sum_keywords

  def _createFiles(self):
    """Writes one TFRecord file per file index and returns their paths.

    Every record uses the fixed label "fake-label", which the expected-batch
    generator below reproduces.
    """
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j, "fake-label"))
      writer.close()
    return filenames

  def _run_actual_batch(self, outputs, label_key_provided=False):
    """Evaluates one batch from `outputs` and flattens it for comparison.

    Args:
      outputs: a zero-arg callable returning the next batch's tensors.
      label_key_provided: if True, `outputs()` yields a (features, label)
        tuple; otherwise the label is read from features["label"].

    Returns:
      [file, keywords.indices, keywords.values, keywords.dense_shape,
       record, label] — the same order `_next_expected_batch` yields.
    """
    if label_key_provided:
      # outputs would be a tuple of (feature dict, label)
      features, label = self.evaluate(outputs())
    else:
      features = self.evaluate(outputs())
      label = features["label"]
    file_out = features["file"]
    # "keywords" is a sparse tensor; compare its three components separately.
    keywords_indices = features["keywords"].indices
    keywords_values = features["keywords"].values
    keywords_dense_shape = features["keywords"].dense_shape
    record = features["record"]
    return ([
        file_out, keywords_indices, keywords_values, keywords_dense_shape,
        record, label
    ])

  def _next_actual_batch(self, label_key_provided=False):
    """Evaluates the next batch from `self.outputs` (set by the subclass)."""
    return self._run_actual_batch(self.outputs, label_key_provided)

  def _interleave(self, iterators, cycle_length):
    """Round-robins over `iterators`, `cycle_length` at a time.

    Keeps up to `cycle_length` iterators open, yields one element from each
    in turn, and replaces an exhausted iterator with the next pending one
    (or a None placeholder once none remain). Models the ordering of
    tf.data interleave with block_length 1.
    """
    pending_iterators = iterators
    open_iterators = []
    num_open = 0
    for i in range(cycle_length):
      if pending_iterators:
        open_iterators.append(pending_iterators.pop(0))
        num_open += 1

    while num_open:
      for i in range(min(cycle_length, len(open_iterators))):
        if open_iterators[i] is None:
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          if pending_iterators:
            open_iterators[i] = pending_iterators.pop(0)
          else:
            open_iterators[i] = None
            num_open -= 1

  def _next_expected_batch(self,
                           file_indices,
                           batch_size,
                           num_epochs,
                           cycle_length=1):
    """Yields the expected flattened batches for the given read order.

    Each yielded value is [file, keywords_indices, keywords_values,
    keywords_dense_shape, record, label] in dense Python-list form, matching
    _run_actual_batch(). With cycle_length > 1 the records are interleaved
    across files the same way _interleave() does. A final partial batch is
    yielded if the record count does not divide evenly.
    """

    def _next_record(file_indices):
      # Sequential read order: all records of each file, files in order.
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i, compat.as_bytes("fake-label")

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    file_batch = []
    keywords_batch_indices = []
    keywords_batch_values = []
    keywords_batch_max_len = 0
    record_batch = []
    # Row index within the current batch; used for the sparse indices of
    # "keywords" and reset whenever a batch is emitted.
    batch_index = 0
    label_batch = []
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for record in next_records:
        f = record[0]
        r = record[1]
        label_batch.append(record[2])
        file_batch.append(f)
        record_batch.append(r)
        keywords = self._get_keywords(f, r)
        keywords_batch_values.extend(keywords)
        keywords_batch_indices.extend(
            [[batch_index, i] for i in range(len(keywords))])
        batch_index += 1
        keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
        if len(file_batch) == batch_size:
          yield [
              file_batch, keywords_batch_indices, keywords_batch_values,
              [batch_size, keywords_batch_max_len], record_batch, label_batch
          ]
          file_batch = []
          keywords_batch_indices = []
          keywords_batch_values = []
          keywords_batch_max_len = 0
          record_batch = []
          batch_index = 0
          label_batch = []
    # Trailing partial batch (only when batch_size does not divide the total).
    if file_batch:
      yield [
          file_batch, keywords_batch_indices, keywords_batch_values,
          [len(file_batch), keywords_batch_max_len], record_batch, label_batch
      ]

  def verify_records(self,
                     batch_size,
                     file_index=None,
                     num_epochs=1,
                     label_key_provided=False,
                     interleave_cycle_length=1):
    """Asserts the actual batches equal the expected ones, in order.

    Args:
      batch_size: batch size the dataset under test was built with.
      file_index: if not None, restrict verification to that single file;
        otherwise all `_num_files` files are expected.
      num_epochs: number of passes over the files to expect.
      label_key_provided: forwarded to _next_actual_batch (tuple vs dict
        output shape).
      interleave_cycle_length: cycle length used to model interleaved reads.
    """
    if file_index is not None:
      file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices,
        batch_size,
        num_epochs,
        cycle_length=interleave_cycle_length):
      actual_batch = self._next_actual_batch(
          label_key_provided=label_key_provided)
      for i in range(len(expected_batch)):
        self.assertAllEqual(expected_batch[i], actual_batch[i])


class TextLineDatasetTestBase(test_base.DatasetTestBase):
  """Base class for setting up and testing TextLineDataset."""

  def _lineText(self, f, l):
    """Returns the expected bytes of line `l` in file `f` ("f: l")."""
    return compat.as_bytes("%d: %d" % (f, l))

  def _createFiles(self,
                   num_files,
                   num_lines,
                   crlf=False,
                   compression_type=None):
    """Writes `num_files` text files of `num_lines` lines each.

    Args:
      num_files: number of files to create.
      num_lines: number of lines per file.
      crlf: use b"\\r\\n" line endings instead of b"\\n".
      compression_type: None/"" for plain text, "GZIP", or "ZLIB".

    Returns:
      The list of created file paths.

    Raises:
      ValueError: for an unsupported `compression_type`.
    """
    filenames = []
    for i in range(num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(num_lines):
        contents.append(self._lineText(i, j))
        # Always include a newline after the record unless it is
        # at the end of the file, in which case we include it only in the
        # first file (i == 0) — so both a file with and without a trailing
        # newline are produced. NOTE(review): the asymmetry looks deliberate
        # (exercising both layouts); confirm before "fixing".
        if j + 1 != num_lines or i == 0:
          contents.append(b"\r\n" if crlf else b"\n")
      contents = b"".join(contents)

      if not compression_type:
        with open(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "GZIP":
        with gzip.GzipFile(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "ZLIB":
        # zlib has no file wrapper analogous to gzip.GzipFile, so compress
        # the whole payload in memory and write the raw bytes.
        contents = zlib.compress(contents)
        with open(fn, "wb") as f:
          f.write(contents)
      else:
        raise ValueError("Unsupported compression_type", compression_type)

    return filenames


class TFRecordDatasetTestBase(test_base.DatasetTestBase):
  """Base class for setting up and testing TFRecordDataset."""

  def _interleave(self, iterators, cycle_length):
    """Round-robins over `iterators`, `cycle_length` at a time.

    Same algorithm as MakeBatchedFeaturesDatasetTestBase._interleave: keeps
    up to `cycle_length` iterators open, yields one element from each in
    turn, and refills an exhausted slot from the pending list (or marks it
    None once nothing is pending).
    """
    pending_iterators = iterators
    open_iterators = []
    num_open = 0
    for i in range(cycle_length):
      if pending_iterators:
        open_iterators.append(pending_iterators.pop(0))
        num_open += 1

    while num_open:
      for i in range(min(cycle_length, len(open_iterators))):
        if open_iterators[i] is None:
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          if pending_iterators:
            open_iterators[i] = pending_iterators.pop(0)
          else:
            open_iterators[i] = None
            num_open -= 1

  def _next_expected_batch(self, file_indices, batch_size, num_epochs,
                           cycle_length, drop_final_batch, use_parser_fn):
    """Yields the expected record batches for the given read configuration.

    Args:
      file_indices: file indices to read, in order.
      batch_size: records per yielded batch.
      num_epochs: number of passes over `file_indices`.
      cycle_length: 1 for sequential reads; >1 to model interleaved reads.
      drop_final_batch: if True, a trailing partial batch is not yielded.
      use_parser_fn: if True, the first byte of each record is dropped
        (mirroring a parser function applied by the test).
    """

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    record_batch = []
    # NOTE(review): batch_index is incremented and reset but never read —
    # dead code, kept for byte-identical behavior.
    batch_index = 0
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for f, r in next_records:
        record = self._record(f, r)
        if use_parser_fn:
          record = record[1:]
        record_batch.append(record)
        batch_index += 1
        if len(record_batch) == batch_size:
          yield record_batch
          record_batch = []
          batch_index = 0
    if record_batch and not drop_final_batch:
      yield record_batch

  def _verify_records(self, outputs, batch_size, file_index, num_epochs,
                      interleave_cycle_length, drop_final_batch, use_parser_fn):
    """Asserts batches from `outputs` match the expected batches, in order.

    Args:
      outputs: zero-arg callable returning the next batch's tensors.
      batch_size: batch size the dataset under test was built with.
      file_index: None for all files, a single index, or a list of indices.
      num_epochs: number of expected passes over the files.
      interleave_cycle_length: cycle length used to model interleaved reads.
      drop_final_batch: whether the dataset drops a trailing partial batch.
      use_parser_fn: whether a first-byte-dropping parser was applied.
    """
    if file_index is not None:
      if isinstance(file_index, list):
        file_indices = file_index
      else:
        file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices, batch_size, num_epochs, interleave_cycle_length,
        drop_final_batch, use_parser_fn):
      actual_batch = self.evaluate(outputs())
      self.assertAllEqual(expected_batch, actual_batch)

  def setUp(self):
    super(TFRecordDatasetTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7

    self.test_filenames = self._createFiles()

  def _record(self, f, r):
    """Returns the expected bytes of record `r` in file `f`."""
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _createFiles(self):
    """Writes one TFRecord file per file index and returns their paths."""
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j))
      writer.close()
    return filenames