# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Python wrappers for reader Datasets."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.data.util import nest
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.framework import tensor_shape
     25 from tensorflow.python.ops import gen_dataset_ops
     26 from tensorflow.python.ops import parsing_ops
     27 from tensorflow.python.platform import gfile
     28 
     29 
     30 def read_batch_features(file_pattern,
     31                         batch_size,
     32                         features,
     33                         reader,
     34                         reader_args=None,
     35                         randomize_input=True,
     36                         num_epochs=None,
     37                         capacity=10000):
     38   """Reads batches of Examples.
     39 
     40   Example:
     41 
     42   ```
     43   serialized_examples = [
     44     features {
     45       feature { key: "age" value { int64_list { value: [ 0 ] } } }
     46       feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
     47       feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
     48     },
     49     features {
     50       feature { key: "age" value { int64_list { value: [] } } }
     51       feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
     52       feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
     53     }
     54   ]
     55   ```
     56 
     57   We can use arguments:
     58 
     59   ```
     60   features: {
     61     "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
     62     "gender": FixedLenFeature([], dtype=tf.string),
     63     "kws": VarLenFeature(dtype=tf.string),
     64   }
     65   ```
     66 
     67   And the expected output is:
     68 
     69   ```python
     70   {
     71     "age": [[0], [-1]],
     72     "gender": [["f"], ["f"]],
     73     "kws": SparseTensor(
     74       indices=[[0, 0], [0, 1], [1, 0]],
     75       values=["code", "art", "sports"]
     76       dense_shape=[2, 2]),
     77   }
     78   ```
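
  A minimal end-to-end sketch (the file pattern and the
  `tf.contrib.data.TFRecordDataset` reader below are illustrative
  assumptions, not part of the data above):

  ```python
  features = tf.contrib.data.read_batch_features(
      file_pattern="/path/to/examples-*.tfrecord",  # hypothetical files
      batch_size=2,
      features={
          "age": tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
          "gender": tf.FixedLenFeature([], dtype=tf.string),
          "kws": tf.VarLenFeature(dtype=tf.string),
      },
      reader=tf.contrib.data.TFRecordDataset)
  # `features` is a dict of batched `Tensor`/`SparseTensor` objects; evaluate
  # it with `sess.run(features)` inside a `tf.Session`.
  ```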

  Args:
    file_pattern: List of files or patterns of file paths containing
      `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int representing the number of consecutive elements of this
      dataset to combine in a single batch.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values. See `tf.parse_example`.
    reader: A function or class that can be called with a `filenames` tensor
      and (optional) `reader_args` and returns a `Dataset` of Examples.
    reader_args: Additional arguments to pass to the reader class.
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever.
    capacity: Capacity of the ShuffleDataset. A larger capacity ensures better
      shuffling but increases memory usage and startup time.

  Returns:
    A dict from keys in `features` to `Tensor` or `SparseTensor` objects.
  """
  filenames = _get_file_names(file_pattern, randomize_input)
  if reader_args:
    dataset = reader(filenames, *reader_args)
  else:
    dataset = reader(filenames)
  # Readers that emit (key, value) string pairs are reduced to just the
  # values, i.e. the serialized `Example` protos.
  if dataset.output_types == (dtypes.string, dtypes.string):
    dataset = dataset.map(lambda _, v: v)
  if num_epochs != 1:
    dataset = dataset.repeat(num_epochs)
  if randomize_input:
    dataset = dataset.shuffle(capacity)
  # Batch before parsing so that `parse_example` receives a vector of
  # serialized protos and parses the whole batch in a single op.
  dataset = dataset.batch(batch_size)
  dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features))
  dataset = dataset.prefetch(1)
  iterator = dataset.make_one_shot_iterator()
  outputs = iterator.get_next()
  return outputs


def _get_file_names(file_pattern, randomize_input):
  """Parse list of file names from pattern, optionally shuffled.

  Args:
    file_pattern: File glob pattern, or list of glob patterns.
    randomize_input: Whether to shuffle the order of file names.

  Returns:
    List of file names matching `file_pattern`.

  Raises:
    ValueError: If `file_pattern` is empty, or pattern matches no files.
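
  For example, matching two hypothetical patterns without shuffling:

  ```python
  _get_file_names(["/data/part-*", "/data/extra-00001"],
                  randomize_input=False)
  # ==> a sorted list of all file names matching either entry.
  ```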
  """
  if isinstance(file_pattern, list):
    if not file_pattern:
      raise ValueError("File pattern is empty.")
    file_names = []
    for entry in file_pattern:
      file_names.extend(gfile.Glob(entry))
  else:
    file_names = list(gfile.Glob(file_pattern))

  if not file_names:
    raise ValueError("No files match %s." % file_pattern)

  # Sort files so it will be deterministic for unit tests.
  if not randomize_input:
    file_names = sorted(file_names)
  return file_names


class SqlDataset(dataset_ops.Dataset):
  """A `Dataset` consisting of the results from a SQL query."""

  def __init__(self, driver_name, data_source_name, query, output_types):
    """Creates a `SqlDataset`.

    `SqlDataset` allows a user to read data from the result set of a SQL
    query. For example:

    ```python
    dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3",
                                         "SELECT name, age FROM people",
                                         (tf.string, tf.int32))
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    # Prints the rows of the result set of the above query.
    with tf.Session() as sess:
      while True:
        try:
          print(sess.run(next_element))
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      driver_name: A 0-D `tf.string` tensor containing the database type.
        Currently, the only supported value is 'sqlite'.
      data_source_name: A 0-D `tf.string` tensor containing a connection string
        to connect to the database.
      query: A 0-D `tf.string` tensor containing the SQL query to execute.
      output_types: A tuple of `tf.DType` objects representing the types of the
        columns returned by `query`.
    """
    super(SqlDataset, self).__init__()
    self._driver_name = ops.convert_to_tensor(
        driver_name, dtype=dtypes.string, name="driver_name")
    self._data_source_name = ops.convert_to_tensor(
        data_source_name, dtype=dtypes.string, name="data_source_name")
    self._query = ops.convert_to_tensor(
        query, dtype=dtypes.string, name="query")
    self._output_types = output_types

  def _as_variant_tensor(self):
    return gen_dataset_ops.sql_dataset(self._driver_name,
                                       self._data_source_name, self._query,
                                       nest.flatten(self.output_types),
                                       nest.flatten(self.output_shapes))

  @property
  def output_classes(self):
    return nest.map_structure(lambda _: ops.Tensor, self._output_types)

  @property
  def output_shapes(self):
    # Each column of each row is a scalar, so every component shape is [].
    return nest.map_structure(lambda _: tensor_shape.TensorShape([]),
                              self._output_types)

  @property
  def output_types(self):
    return self._output_types