# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for reader Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util.tf_export import tf_export


# Default read-buffer size used by all reader datasets when the caller does
# not supply `buffer_size`.
# TODO(b/64974358): Increase default buffer size to 256 MB.
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB


@tf_export("data.TextLineDataset")
class TextLineDataset(Dataset):
  """A `Dataset` comprising lines from one or more text files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    """Creates a `TextLineDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer. A value of 0 results in the default buffering values chosen
        based on the compression type.
    """
    super(TextLineDataset, self).__init__()
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    # `None` maps to the module-level default; `0` is passed through and lets
    # the kernel pick a compression-dependent default.
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)

  def _as_variant_tensor(self):
    return gen_dataset_ops.text_line_dataset(
        self._filenames, self._compression_type, self._buffer_size)

  @property
  def output_classes(self):
    return ops.Tensor

  @property
  def output_shapes(self):
    # Each element is a single line of text, i.e. a scalar string.
    return tensor_shape.scalar()

  @property
  def output_types(self):
    return dtypes.string


@tf_export("data.TFRecordDataset")
class TFRecordDataset(Dataset):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    """Creates a `TFRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
    """
    super(TFRecordDataset, self).__init__()
    # Force the type to string even if filenames is an empty list.
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)

  def _as_variant_tensor(self):
    return gen_dataset_ops.tf_record_dataset(
        self._filenames, self._compression_type, self._buffer_size)

  @property
  def output_classes(self):
    return ops.Tensor

  @property
  def output_shapes(self):
    # Each element is one serialized record, i.e. a scalar string.
    # (`tensor_shape.scalar()` is equivalent to `TensorShape([])`; used here
    # for consistency with the sibling reader datasets.)
    return tensor_shape.scalar()

  @property
  def output_types(self):
    return dtypes.string


@tf_export("data.FixedLengthRecordDataset")
class FixedLengthRecordDataset(Dataset):
  """A `Dataset` of fixed-length records from one or more binary files."""

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None):
    """Creates a `FixedLengthRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_bytes: A `tf.int64` scalar representing the number of bytes in
        each record.
      header_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to skip at the start of a file.
      footer_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to ignore at the end of a file.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes to buffer when reading.
    """
    super(FixedLengthRecordDataset, self).__init__()
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    self._record_bytes = ops.convert_to_tensor(
        record_bytes, dtype=dtypes.int64, name="record_bytes")

    # `header_bytes` and `footer_bytes` default to 0 (the
    # `optional_param_to_tensor` default) when not supplied.
    self._header_bytes = convert.optional_param_to_tensor(
        "header_bytes", header_bytes)
    self._footer_bytes = convert.optional_param_to_tensor(
        "footer_bytes", footer_bytes)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)

  def _as_variant_tensor(self):
    # NOTE: the kernel expects header bytes before record bytes.
    return gen_dataset_ops.fixed_length_record_dataset(
        self._filenames, self._header_bytes, self._record_bytes,
        self._footer_bytes, self._buffer_size)

  @property
  def output_classes(self):
    return ops.Tensor

  @property
  def output_shapes(self):
    # Each element is one fixed-length record, i.e. a scalar string.
    return tensor_shape.scalar()

  @property
  def output_types(self):
    return dtypes.string