# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline reader ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import zlib

from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class TextLineDatasetTestBase(test.TestCase):

  def _lineText(self, f, l):
    return compat.as_bytes("%d: %d" % (f, l))

  def _createFiles(self,
                   num_files,
                   num_lines,
                   crlf=False,
                   compression_type=None):
    filenames = []
    for i in range(num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(num_lines):
        contents.append(self._lineText(i, j))
        # Always include a newline after the record unless it is the last
        # record in the file; in that case, include the trailing newline
        # only in the first file, so that files both with and without a
        # trailing newline are exercised.
        if j + 1 != num_lines or i == 0:
          contents.append(b"\r\n" if crlf else b"\n")
      contents = b"".join(contents)

      if not compression_type:
        with open(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "GZIP":
        with gzip.GzipFile(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "ZLIB":
        contents = zlib.compress(contents)
        with open(fn, "wb") as f:
          f.write(contents)
      else:
        raise ValueError("Unsupported compression_type: %s" % compression_type)

    return filenames


class TextLineDatasetSerializationTest(
    TextLineDatasetTestBase,
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def _build_iterator_graph(self, test_filenames, compression_type=None):
    return core_readers.TextLineDataset(
        test_filenames, compression_type=compression_type, buffer_size=10)

  def testTextLineCore(self):
    compression_types = [None, "GZIP", "ZLIB"]
    num_files = 5
    lines_per_file = 5
    num_outputs = num_files * lines_per_file
    for compression_type in compression_types:
      test_filenames = self._createFiles(
          num_files,
          lines_per_file,
          crlf=True,
          compression_type=compression_type)
      # pylint: disable=cell-var-from-loop
      self.run_core_tests(
          lambda: self._build_iterator_graph(test_filenames, compression_type),
          lambda: self._build_iterator_graph(test_filenames), num_outputs)
      # pylint: enable=cell-var-from-loop
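

# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite): how the files written by
# TextLineDatasetTestBase._createFiles can be consumed with a TextLineDataset
# in TF 1.x graph mode. The helper name and the inline Session import are
# assumptions made for this example only.
def _demo_read_text_lines(filenames, compression_type=None):
  """Yields every line of `filenames` as a bytes object."""
  from tensorflow.python.client import session as session_lib  # Illustration only.

  dataset = core_readers.TextLineDataset(
      filenames, compression_type=compression_type)
  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  with session_lib.Session() as sess:
    while True:
      try:
        yield sess.run(next_element)
      except errors.OutOfRangeError:
        return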


class FixedLengthRecordReaderTestBase(test.TestCase):

  def setUp(self):
    super(FixedLengthRecordReaderTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self._header_bytes = 5
    self._record_bytes = 3
    self._footer_bytes = 2

  def _record(self, f, r):
    return compat.as_bytes(str(f * 2 + r) * self._record_bytes)

  def _createFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        for j in range(self._num_records):
          f.write(self._record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames


class FixedLengthRecordDatasetSerializationTest(
    FixedLengthRecordReaderTestBase,
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def _build_iterator_graph(self, num_epochs, compression_type=None):
    filenames = self._createFiles()
    return core_readers.FixedLengthRecordDataset(
        filenames, self._record_bytes, self._header_bytes,
        self._footer_bytes).repeat(num_epochs)

  def testFixedLengthRecordCore(self):
    num_epochs = 5
    num_outputs = num_epochs * self._num_files * self._num_records
    self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
                        lambda: self._build_iterator_graph(num_epochs * 2),
                        num_outputs)
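

# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite): the layout consumed by
# FixedLengthRecordDataset is header_bytes, then a whole number of
# record_bytes-sized records, then footer_bytes. The helper below (its name is
# an assumption for this example) computes how many records such a file holds.
# For the files created above: 5 + 7 * 3 + 2 = 28 bytes, and
# (28 - 5 - 2) // 3 = 7 records.
def _demo_num_fixed_length_records(file_size, record_bytes, header_bytes,
                                   footer_bytes):
  """Returns the number of complete records in a fixed-length record file."""
  payload = file_size - header_bytes - footer_bytes
  return payload // record_bytes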


class TFRecordDatasetTestBase(test.TestCase):

  def setUp(self):
    super(TFRecordDatasetTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7

    self.test_filenames = self._createFiles()

    self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
    self.num_epochs = array_ops.placeholder_with_default(
        constant_op.constant(1, dtypes.int64), shape=[])
    self.compression_type = array_ops.placeholder_with_default("", shape=[])
    self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = core_readers.TFRecordDataset(
        self.filenames, self.compression_type).repeat(self.num_epochs)
    batch_dataset = repeat_dataset.batch(self.batch_size)

    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    self.init_op = iterator.make_initializer(repeat_dataset)
    self.init_batch_op = iterator.make_initializer(batch_dataset)
    self.get_next = iterator.get_next()

  def _record(self, f, r):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _createFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j))
      writer.close()
    return filenames


class TFRecordDatasetSerializationTest(
    TFRecordDatasetTestBase,
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def _build_iterator_graph(self,
                            num_epochs,
                            batch_size=1,
                            compression_type=None,
                            buffer_size=None):
    filenames = self._createFiles()
    if compression_type == "ZLIB":
      zlib_files = []
      for i, fn in enumerate(filenames):
        with open(fn, "rb") as f:
          cdata = zlib.compress(f.read())
          zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
          with open(zfn, "wb") as f:
            f.write(cdata)
          zlib_files.append(zfn)
      filenames = zlib_files

    elif compression_type == "GZIP":
      gzip_files = []
      for i, fn in enumerate(filenames):
        with open(fn, "rb") as f:
          gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
          with gzip.GzipFile(gzfn, "wb") as gzf:
            gzf.write(f.read())
          gzip_files.append(gzfn)
      filenames = gzip_files

    return core_readers.TFRecordDataset(
        filenames, compression_type,
        buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)

  def testTFRecordWithoutBufferCore(self):
    num_epochs = 5
    batch_size = num_epochs
    num_outputs = num_epochs * self._num_files * self._num_records // batch_size
    # pylint: disable=g-long-lambda
    self.run_core_tests(
        lambda: self._build_iterator_graph(num_epochs, batch_size,
                                           buffer_size=0),
        lambda: self._build_iterator_graph(num_epochs * 2, batch_size),
        num_outputs)
    self.run_core_tests(
        lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None,
        num_outputs * batch_size)
    # pylint: enable=g-long-lambda

  def testTFRecordWithBufferCore(self):
    num_epochs = 5
    num_outputs = num_epochs * self._num_files * self._num_records
    self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
                        lambda: self._build_iterator_graph(num_epochs * 2),
                        num_outputs)

  def testTFRecordWithCompressionCore(self):
    num_epochs = 5
    num_outputs = num_epochs * self._num_files * self._num_records
    self.run_core_tests(
        lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"),
        lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
    self.run_core_tests(
        lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"),
        lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
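

# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite): instead of compressing a
# finished record file by hand, as _build_iterator_graph does above,
# TFRecordWriter can compress records as they are written via TFRecordOptions.
# The helper name is an assumption for this example; the resulting file should
# be readable with core_readers.TFRecordDataset(path, compression_type="GZIP").
def _demo_write_gzip_tfrecords(path, records):
  """Writes an iterable of bytes `records` as a GZIP-compressed TFRecord file."""
  options = python_io.TFRecordOptions(python_io.TFRecordCompressionType.GZIP)
  writer = python_io.TFRecordWriter(path, options)
  for record in records:
    writer.write(record)
  writer.close()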


class ReadBatchFeaturesTest(test.TestCase):

  def setUp(self):
    super(ReadBatchFeaturesTest, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self.test_filenames = self._createFiles()

  def _read_batch_features(self, filenames, num_epochs, batch_size):
    self.filenames = filenames
    self.num_epochs = num_epochs
    self.batch_size = batch_size

    return readers.read_batch_features(
        file_pattern=self.filenames,
        batch_size=self.batch_size,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
            "record": parsing_ops.FixedLenFeature([], dtypes.int64),
            "keywords": parsing_ops.VarLenFeature(dtypes.string)
        },
        reader=core_readers.TFRecordDataset,
        randomize_input=False,
        num_epochs=self.num_epochs)

  def _record(self, f, r):
    example = example_pb2.Example(features=feature_pb2.Features(
        feature={
            "file":
                feature_pb2.Feature(int64_list=feature_pb2.Int64List(
                    value=[f])),
            "record":
                feature_pb2.Feature(int64_list=feature_pb2.Int64List(
                    value=[r])),
            "keywords":
                feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
                    value=self._get_keywords(f, r)))
        }))
    return example.SerializeToString()

  def _get_keywords(self, f, r):
    num_keywords = 1 + (f + r) % 2
    keywords = []
    for index in range(num_keywords):
      keywords.append(compat.as_bytes("keyword%d" % index))
    return keywords

  def _createFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j))
      writer.close()
    return filenames

  def _next_actual_batch(self, sess):
    file_op = self.outputs["file"]
    keywords_indices_op = self.outputs["keywords"].indices
    keywords_values_op = self.outputs["keywords"].values
    keywords_dense_shape_op = self.outputs["keywords"].dense_shape
    record_op = self.outputs["record"]
    return sess.run([
        file_op, keywords_indices_op, keywords_values_op,
        keywords_dense_shape_op, record_op
    ])

  def _next_expected_batch(self, file_indices, batch_size, num_epochs):

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i

    file_batch = []
    keywords_batch_indices = []
    keywords_batch_values = []
    keywords_batch_max_len = 0
    record_batch = []
    batch_index = 0
    for _ in range(num_epochs):
      for record in _next_record(file_indices):
        f = record[0]
        r = record[1]
        file_batch.append(f)
        record_batch.append(r)
        keywords = self._get_keywords(f, r)
        keywords_batch_values.extend(keywords)
        keywords_batch_indices.extend([[batch_index, i]
                                       for i in range(len(keywords))])
        batch_index += 1
        keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
        if len(file_batch) == batch_size:
          yield [
              file_batch, keywords_batch_indices, keywords_batch_values,
              [batch_size, keywords_batch_max_len], record_batch
          ]
          file_batch = []
          keywords_batch_indices = []
          keywords_batch_values = []
          keywords_batch_max_len = 0
          record_batch = []
          batch_index = 0
    if file_batch:
      yield [
          file_batch, keywords_batch_indices, keywords_batch_values,
          [len(file_batch), keywords_batch_max_len], record_batch
      ]

  def _verify_records(self, sess, batch_size, file_index=None, num_epochs=1):
    if file_index is not None:
      file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(file_indices, batch_size,
                                                    num_epochs):
      actual_batch = self._next_actual_batch(sess)
      for i in range(len(expected_batch)):
        self.assertAllEqual(expected_batch[i], actual_batch[i])

  def testRead(self):
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        with ops.Graph().as_default() as g:
          with self.test_session(graph=g) as sess:
            # Basic test: read from file 0.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames[0],
                num_epochs=num_epochs,
                batch_size=batch_size)
            self._verify_records(sess, batch_size, 0, num_epochs=num_epochs)
            with self.assertRaises(errors.OutOfRangeError):
              self._next_actual_batch(sess)

        with ops.Graph().as_default() as g:
          with self.test_session(graph=g) as sess:
            # Basic test: read from file 1.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames[1],
                num_epochs=num_epochs,
                batch_size=batch_size)
            self._verify_records(sess, batch_size, 1, num_epochs=num_epochs)
            with self.assertRaises(errors.OutOfRangeError):
              self._next_actual_batch(sess)

        with ops.Graph().as_default() as g:
          with self.test_session(graph=g) as sess:
            # Basic test: read from both files.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames,
                num_epochs=num_epochs,
                batch_size=batch_size)
            self._verify_records(sess, batch_size, num_epochs=num_epochs)
            with self.assertRaises(errors.OutOfRangeError):
              self._next_actual_batch(sess)
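
  # The equivalence test below cross-checks read_batch_features against a
  # hand-built pipeline (TFRecordDataset -> parse_single_example -> repeat
  # -> batch), reusing _next_expected_batch as the common source of truth
  # for the dense "file" and "record" features.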
  def testReadWithEquivalentDataset(self):
    features = {
        "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        "record": parsing_ops.FixedLenFeature([], dtypes.int64),
    }
    dataset = (core_readers.TFRecordDataset(self.test_filenames)
               .map(lambda x: parsing_ops.parse_single_example(x, features))
               .repeat(10).batch(2))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    next_element = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for file_batch, _, _, _, record_batch in self._next_expected_batch(
          range(self._num_files), 2, 10):
        actual_batch = sess.run(next_element)
        self.assertAllEqual(file_batch, actual_batch["file"])
        self.assertAllEqual(record_batch, actual_batch["record"])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)


if __name__ == "__main__":
  test.main()