# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for reader Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile


def read_batch_features(file_pattern,
                        batch_size,
                        features,
                        reader,
                        reader_args=None,
                        randomize_input=True,
                        num_epochs=None,
                        capacity=10000):
  """Reads batches of Examples.

  Example:

  ```
  serialized_examples = [
    features {
      feature { key: "age" value { int64_list { value: [ 0 ] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
      feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
    },
    features {
      feature { key: "age" value { int64_list { value: [] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
      feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
    }
  ]
  ```

  We can use arguments:

  ```
  features: {
    "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "gender": FixedLenFeature([], dtype=tf.string),
    "kws": VarLenFeature(dtype=tf.string),
  }
  ```

  And the expected output is:

  ```python
  {
    "age": [[0], [-1]],
    "gender": [["f"], ["f"]],
    "kws": SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=["code", "art", "sports"],
      dense_shape=[2, 2]),
  }
  ```

  Args:
    file_pattern: List of files or patterns of file paths containing
      `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int representing the number of consecutive elements of this
      dataset to combine in a single batch.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values. See `tf.parse_example`.
    reader: A function or class that can be called with a `filenames` tensor
      and (optional) `reader_args` and returns a `Dataset` of Examples.
    reader_args: Additional arguments to pass to the reader class.
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever.
    capacity: Capacity of the ShuffleDataset. A large capacity ensures better
      shuffling but would increase memory usage and startup time.

  Returns:
    A dict from keys in features to `Tensor` or `SparseTensor` objects.
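
  For illustration, a call to this function might look like the sketch below.
  The file pattern and feature spec here are hypothetical, and
  `tf.data.TFRecordDataset` is just one possible `reader`:

  ```python
  # Illustrative sketch only; paths and feature names are made up.
  features = {
      "age": tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
      "gender": tf.FixedLenFeature([], dtype=tf.string),
      "kws": tf.VarLenFeature(dtype=tf.string),
  }
  feature_batch = tf.contrib.data.read_batch_features(
      file_pattern="/path/to/examples-*.tfrecord",
      batch_size=2,
      features=features,
      reader=tf.data.TFRecordDataset,
      randomize_input=False,
      num_epochs=1)
  # `feature_batch` is a dict of `Tensor`/`SparseTensor` values that can be
  # evaluated with `sess.run(feature_batch)` inside a `tf.Session`.
  ```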
  """
  filenames = _get_file_names(file_pattern, randomize_input)
  if reader_args:
    dataset = reader(filenames, *reader_args)
  else:
    dataset = reader(filenames)
  # If the reader produces (key, value) string pairs, keep only the values
  # (the serialized `Example` protos).
  if dataset.output_types == (dtypes.string, dtypes.string):
    dataset = dataset.map(lambda _, v: v)
  # Repeat for `num_epochs` epochs (`None` repeats indefinitely).
  if num_epochs != 1:
    dataset = dataset.repeat(num_epochs)
  if randomize_input:
    dataset = dataset.shuffle(capacity)
  dataset = dataset.batch(batch_size)
  # Parse each batch of serialized protos into a dict of tensors.
  dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features))
  dataset = dataset.prefetch(1)
  iterator = dataset.make_one_shot_iterator()
  outputs = iterator.get_next()
  return outputs


def _get_file_names(file_pattern, randomize_input):
  """Parse list of file names from pattern, optionally shuffled.

  Args:
    file_pattern: File glob pattern, or list of glob patterns.
    randomize_input: Whether to shuffle the order of file names.

  Returns:
    List of file names matching `file_pattern`.

  Raises:
    ValueError: If `file_pattern` is empty, or pattern matches no files.
  """
  if isinstance(file_pattern, list):
    if not file_pattern:
      raise ValueError("File pattern is empty.")
    file_names = []
    for entry in file_pattern:
      file_names.extend(gfile.Glob(entry))
  else:
    file_names = list(gfile.Glob(file_pattern))

  if not file_names:
    raise ValueError("No files match %s." % file_pattern)

  # Sort files so it will be deterministic for unit tests.
  if not randomize_input:
    file_names = sorted(file_names)
  return file_names


class SqlDataset(dataset_ops.Dataset):
  """A `Dataset` consisting of the results from a SQL query."""

  def __init__(self, driver_name, data_source_name, query, output_types):
    """Creates a `SqlDataset`.

    `SqlDataset` allows a user to read data from the result set of a SQL
    query. For example:

    ```python
    dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3",
                                         "SELECT name, age FROM people",
                                         (tf.string, tf.int32))
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    # Prints the rows of the result set of the above query.
    while True:
      try:
        print(sess.run(next_element))
      except tf.errors.OutOfRangeError:
        break
    ```

    Args:
      driver_name: A 0-D `tf.string` tensor containing the database type.
        Currently, the only supported value is 'sqlite'.
      data_source_name: A 0-D `tf.string` tensor containing a connection string
        to connect to the database.
      query: A 0-D `tf.string` tensor containing the SQL query to execute.
      output_types: A tuple of `tf.DType` objects representing the types of the
        columns returned by `query`.
    """
    super(SqlDataset, self).__init__()
    self._driver_name = ops.convert_to_tensor(
        driver_name, dtype=dtypes.string, name="driver_name")
    self._data_source_name = ops.convert_to_tensor(
        data_source_name, dtype=dtypes.string, name="data_source_name")
    self._query = ops.convert_to_tensor(
        query, dtype=dtypes.string, name="query")
    self._output_types = output_types

  def _as_variant_tensor(self):
    return gen_dataset_ops.sql_dataset(self._driver_name,
                                       self._data_source_name, self._query,
                                       nest.flatten(self.output_types),
                                       nest.flatten(self.output_shapes))

  @property
  def output_classes(self):
    return nest.map_structure(lambda _: ops.Tensor, self._output_types)

  @property
  def output_shapes(self):
    # Each column in a result row is a scalar, so every component shape is [].
    return nest.map_structure(lambda _: tensor_shape.TensorShape([]),
                              self._output_types)

  @property
  def output_types(self):
    return self._output_types