# (code-browser navigation header removed from scrape)
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Non-deterministic dataset transformations."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.data.util import convert
     22 from tensorflow.python.data.util import nest
     23 from tensorflow.python.data.util import sparse
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import function
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import gen_dataset_ops
     28 from tensorflow.python.util import deprecation
     29 
     30 
class ParallelInterleaveDataset(dataset_ops.Dataset):
  """A `Dataset` that maps a function over its input and flattens the result.

  Unlike `Dataset.interleave()`, this dataset draws elements from
  `cycle_length` nested datasets, and when `sloppy` is true it may produce
  elements out of order.  See `tf.contrib.data.parallel_interleave()` in this
  module for the user-facing documentation.
  """

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               sloppy, buffer_output_elements, prefetch_input_elements):
    """See `tf.contrib.data.parallel_interleave()` for details."""
    super(ParallelInterleaveDataset, self).__init__()
    self._input_dataset = input_dataset

    # Wrap `map_func` in a Defun whose argument signature is the flattened
    # (dense) type list of one input element; sparse components arrive in
    # their serialized dense-variant form and are re-hydrated below.
    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      # Restore the element's nested structure, then deserialize any sparse
      # tensors before handing the element to the user's `map_func`.
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      # Tuple elements are unpacked into positional arguments, mirroring
      # `Dataset.map()` calling conventions.
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      # Side effect of tracing: record the output signature of the dataset
      # returned by `map_func` so the `output_*` properties can report it.
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._map_func = tf_map_func
    # NOTE: adding the Defun to the graph traces `tf_map_func`, which
    # populates `self._output_*` above; this must run before the `output_*`
    # properties are first read.
    self._map_func.add_to_graph(ops.get_default_graph())

    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._sloppy = ops.convert_to_tensor(
        sloppy, dtype=dtypes.bool, name="sloppy")
    # `None` for either optional parameter selects a default of twice the
    # block length / cycle length, respectively.
    self._buffer_output_elements = convert.optional_param_to_tensor(
        "buffer_output_elements",
        buffer_output_elements,
        argument_default=2 * block_length)
    self._prefetch_input_elements = convert.optional_param_to_tensor(
        "prefetch_input_elements",
        prefetch_input_elements,
        argument_default=2 * cycle_length)

  def _as_variant_tensor(self):
    # Emits the `ParallelInterleaveDataset` op, passing the traced map
    # function plus the flattened (dense) output signature recorded during
    # tracing in `__init__`.
    return gen_dataset_ops.parallel_interleave_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._map_func.captured_inputs,
        self._cycle_length,
        self._block_length,
        self._sloppy,
        self._buffer_output_elements,
        self._prefetch_input_elements,
        f=self._map_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    """Classes of the elements produced by `map_func`'s returned dataset."""
    return self._output_classes

  @property
  def output_shapes(self):
    """Shapes of the elements produced by `map_func`'s returned dataset."""
    return self._output_shapes

  @property
  def output_types(self):
    """Types of the elements produced by `map_func`'s returned dataset."""
    return self._output_types
    114 
    115 def parallel_interleave(map_func,
    116                         cycle_length,
    117                         block_length=1,
    118                         sloppy=False,
    119                         buffer_output_elements=None,
    120                         prefetch_input_elements=None):
    121   """A parallel version of the `Dataset.interleave()` transformation.
    122 
    123   `parallel_interleave()` maps `map_func` across its input to produce nested
    124   datasets, and outputs their elements interleaved. Unlike
    125   @{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested
    126   datasets in parallel, which increases the throughput, especially in the
    127   presence of stragglers. Furthermore, the `sloppy` argument can be used to
    128   improve performance, by relaxing the requirement that the outputs are produced
    129   in a deterministic order, and allowing the implementation to skip over nested
    130   datasets whose elements are not readily available when requested.
    131 
    132   Example usage:
    133 
    134   ```python
    135   # Preprocess 4 files concurrently.
    136   filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
    137   dataset = filenames.apply(
    138       tf.contrib.data.parallel_interleave(
    139           lambda filename: tf.data.TFRecordDataset(filename),
    140           cycle_length=4))
    141   ```
    142 
    143   WARNING: If `sloppy` is `True`, the order of produced elements is not
    144   deterministic.
    145 
    146   Args:
    147     map_func: A function mapping a nested structure of tensors to a `Dataset`.
    148     cycle_length: The number of input `Dataset`s to interleave from in parallel.
    149     block_length: The number of consecutive elements to pull from an input
    150       `Dataset` before advancing to the next input `Dataset`.
    151     sloppy: If false, elements are produced in deterministic order. Otherwise,
    152       the implementation is allowed, for the sake of expediency, to produce
    153       elements in a non-deterministic order.
    154     buffer_output_elements: The number of elements each iterator being
    155       interleaved should buffer (similar to the `.prefetch()` transformation for
    156       each interleaved iterator).
    157     prefetch_input_elements: The number of input elements to transform to
    158       iterators before they are needed for interleaving.
    159 
    160   Returns:
    161     A `Dataset` transformation function, which can be passed to
    162     @{tf.data.Dataset.apply}.
    163   """
    164   def _apply_fn(dataset):
    165     return ParallelInterleaveDataset(
    166         dataset, map_func, cycle_length, block_length, sloppy,
    167         buffer_output_elements, prefetch_input_elements)
    168 
    169   return _apply_fn
    170 
    171 
    172 @deprecation.deprecated(
    173     None, "Use `tf.contrib.data.parallel_interleave(..., sloppy=True)`.")
    174 def sloppy_interleave(map_func, cycle_length, block_length=1):
    175   """A non-deterministic version of the `Dataset.interleave()` transformation.
    176 
    177   `sloppy_interleave()` maps `map_func` across `dataset`, and
    178   non-deterministically interleaves the results.
    179 
    180   The resulting dataset is almost identical to `interleave`. The key
    181   difference is that if retrieving a value from a given output iterator would
    182   cause `get_next` to block, that iterator will be skipped, and consumed
    183   when next available. If consuming from all iterators would cause the
    184   `get_next` call to block, the `get_next` call blocks until the first value is
    185   available.
    186 
    187   If the underlying datasets produce elements as fast as they are consumed, the
    188   `sloppy_interleave` transformation behaves identically to `interleave`.
    189   However, if an underlying dataset would block the consumer,
    190   `sloppy_interleave` can violate the round-robin order (that `interleave`
    191   strictly obeys), producing an element from a different underlying
    192   dataset instead.
    193 
    194   Example usage:
    195 
    196   ```python
    197   # Preprocess 4 files concurrently.
    198   filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
    199   dataset = filenames.apply(
    200       tf.contrib.data.sloppy_interleave(
    201           lambda filename: tf.data.TFRecordDataset(filename),
    202           cycle_length=4))
    203   ```
    204 
    205   WARNING: The order of elements in the resulting dataset is not
    206   deterministic. Use `Dataset.interleave()` if you want the elements to have a
    207   deterministic order.
    208 
    209   Args:
    210     map_func: A function mapping a nested structure of tensors (having shapes
    211       and types defined by `self.output_shapes` and `self.output_types`) to a
    212       `Dataset`.
    213     cycle_length: The number of input `Dataset`s to interleave from in parallel.
    214     block_length: The number of consecutive elements to pull from an input
    215       `Dataset` before advancing to the next input `Dataset`. Note:
    216       `sloppy_interleave` will skip the remainder of elements in the
    217       `block_length` in order to avoid blocking.
    218 
    219   Returns:
    220     A `Dataset` transformation function, which can be passed to
    221     @{tf.data.Dataset.apply}.
    222   """
    223   def _apply_fn(dataset):
    224     return ParallelInterleaveDataset(
    225         dataset,
    226         map_func,
    227         cycle_length,
    228         block_length,
    229         sloppy=True,
    230         buffer_output_elements=None,
    231         prefetch_input_elements=None)
    232 
    233   return _apply_fn
    234