Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Batching dataset transformations."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.data.util import nest
     22 from tensorflow.python.data.util import sparse
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.framework import sparse_tensor
     26 from tensorflow.python.framework import tensor_shape
     27 from tensorflow.python.framework import tensor_util
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import gen_dataset_ops
     30 from tensorflow.python.ops import math_ops
     31 
     32 
     33 def dense_to_sparse_batch(batch_size, row_shape):
     34   """A transformation that batches ragged elements into `tf.SparseTensor`s.
     35 
     36   Like `Dataset.padded_batch()`, this transformation combines multiple
     37   consecutive elements of the dataset, which might have different
     38   shapes, into a single element. The resulting element has three
     39   components (`indices`, `values`, and `dense_shape`), which
     40   comprise a `tf.SparseTensor` that represents the same data. The
     41   `row_shape` represents the dense shape of each row in the
     42   resulting `tf.SparseTensor`, to which the effective batch size is
     43   prepended. For example:
     44 
     45   ```python
     46   # NOTE: The following examples use `{ ... }` to represent the
     47   # contents of a dataset.
     48   a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
     49 
     50   a.apply(tf.contrib.data.dense_to_sparse_batch(batch_size=2, row_shape=[6])) ==
     51   {
     52       ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
     53        ['a', 'b', 'c', 'a', 'b'],                 # values
     54        [2, 6]),                                   # dense_shape
     55       ([[0, 0], [0, 1], [0, 2], [0, 3]],
     56        ['a', 'b', 'c', 'd'],
     57        [1, 6])
     58   }
     59   ```
     60 
     61   Args:
     62     batch_size: A `tf.int64` scalar `tf.Tensor`, representing the
     63       number of consecutive elements of this dataset to combine in a
     64       single batch.
     65     row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like
     66       object representing the equivalent dense shape of a row in the
     67       resulting `tf.SparseTensor`. Each element of this dataset must
     68       have the same rank as `row_shape`, and must have size less
     69       than or equal to `row_shape` in each dimension.
     70 
     71   Returns:
     72     A `Dataset` transformation function, which can be passed to
     73     @{tf.data.Dataset.apply}.
     74   """
     75 
     76   def _apply_fn(dataset):
     77     return DenseToSparseBatchDataset(dataset, batch_size, row_shape)
     78 
     79   return _apply_fn
     80 
     81 
     82 def unbatch():
     83   """A Transformation which splits the elements of a dataset.
     84 
     85   For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
     86   where `B` may vary from element to element, then for each element in
     87   the dataset, the unbatched dataset will contain `B` consecutive elements
     88   of shape `[a0, a1, ...]`.
     89 
     90   Returns:
     91     A `Dataset` transformation function, which can be passed to
     92     @{tf.data.Dataset.apply}.
     93   """
     94 
     95   def _apply_fn(dataset):
     96 
     97     def unbatch_map(arg, *rest):
     98       if rest:
     99         return dataset_ops.Dataset.from_tensor_slices((arg,) + rest)
    100       else:
    101         return dataset_ops.Dataset.from_tensor_slices(arg)
    102 
    103     return dataset.flat_map(map_func=unbatch_map)
    104 
    105   return _apply_fn
    106 
    107 
    108 def filter_irregular_batches(batch_size):
    109   """Transformation that filters out batches that are not of size batch_size."""
    110 
    111   def _apply_fn(dataset):
    112     """Function from `Dataset` to `Dataset` that applies the transformation."""
    113     tensor_batch_size = ops.convert_to_tensor(
    114         batch_size, dtype=dtypes.int64, name="batch_size")
    115 
    116     flattened = _RestructuredDataset(
    117         dataset,
    118         tuple(nest.flatten(dataset.output_types)),
    119         output_classes=tuple(nest.flatten(dataset.output_classes)))
    120 
    121     def _predicate(*xs):
    122       """Return `True` if this element is a full batch."""
    123       # Extract the dynamic batch size from the first component of the flattened
    124       # batched element.
    125       first_component = xs[0]
    126       first_component_batch_size = array_ops.shape(
    127           first_component, out_type=dtypes.int64)[0]
    128 
    129       return math_ops.equal(first_component_batch_size, tensor_batch_size)
    130 
    131     filtered = flattened.filter(_predicate)
    132 
    133     maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
    134 
    135     def _set_first_dimension(shape):
    136       return shape.merge_with(
    137           tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
    138 
    139     known_shapes = nest.map_structure(_set_first_dimension,
    140                                       dataset.output_shapes)
    141     return _RestructuredDataset(
    142         filtered,
    143         dataset.output_types,
    144         known_shapes,
    145         output_classes=dataset.output_classes)
    146 
    147   return _apply_fn
    148 
    149 
    150 def batch_and_drop_remainder(batch_size):
    151   """A batching transformation that omits the final small batch (if present).
    152 
    153   Like @{tf.data.Dataset.batch}, this transformation combines
    154   consecutive elements of this dataset into batches. However, if the batch
    155   size does not evenly divide the input dataset size, this transformation will
    156   drop the final smaller element.
    157 
    158   The following example illustrates the difference between this
    159   transformation and `Dataset.batch()`:
    160 
    161   ```python
    162   dataset = tf.data.Dataset.range(200)
    163   batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128))
    164   print(batched.output_shapes)  # ==> "(128,)" (the batch dimension is known)
    165   ```
    166 
    167   By contrast, `dataset.batch(128)` would yield a two-element dataset with
    168   shapes `(128,)` and `(72,)`, so the batch dimension would not be statically
    169   known.
    170 
    171   Args:
    172     batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
    173         consecutive elements of this dataset to combine in a single batch.
    174 
    175   Returns:
    176     A `Dataset` transformation function, which can be passed to
    177     @{tf.data.Dataset.apply}
    178   """
    179 
    180   def _apply_fn(dataset):
    181     """Function from `Dataset` to `Dataset` that applies the transformation."""
    182     batched = dataset.batch(batch_size)
    183     return filter_irregular_batches(batch_size)(batched)
    184 
    185   return _apply_fn
    186 
    187 
    188 def padded_batch_and_drop_remainder(batch_size,
    189                                     padded_shapes,
    190                                     padding_values=None):
    191   """A batching and padding transformation that omits the final small batch.
    192 
    193   Like @{tf.data.Dataset.padded_batch}, this transformation combines
    194   consecutive elements of this dataset into batches. However, if the batch
    195   size does not evenly divide the input dataset size, this transformation will
    196   drop the final smaller element.
    197 
    198   See `@{tf.contrib.data.batch_and_drop_remainder}` for more details.
    199 
    200   Args:
    201     batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
    202       consecutive elements of this dataset to combine in a single batch.
    203     padded_shapes: A nested structure of `tf.TensorShape` or
    204       `tf.int64` vector tensor-like objects. See
    205       @{tf.data.Dataset.padded_batch} for details.
    206     padding_values: (Optional.) A nested structure of scalar-shaped
    207       `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details.
    208 
    209   Returns:
    210     A `Dataset` transformation function, which can be passed to
    211     @{tf.data.Dataset.apply}
    212   """
    213 
    214   def _apply_fn(dataset):
    215     """Function from `Dataset` to `Dataset` that applies the transformation."""
    216     batched = dataset.padded_batch(
    217         batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
    218     return filter_irregular_batches(batch_size)(batched)
    219 
    220   return _apply_fn
    221 
    222 
    223 class DenseToSparseBatchDataset(dataset_ops.Dataset):
    224   """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
    225 
    226   def __init__(self, input_dataset, batch_size, row_shape):
    227     """See `Dataset.dense_to_sparse_batch()` for more details."""
    228     super(DenseToSparseBatchDataset, self).__init__()
    229     if not isinstance(input_dataset.output_types, dtypes.DType):
    230       raise TypeError("DenseToSparseDataset requires an input whose elements "
    231                       "have a single component, whereas the input has %r." %
    232                       input_dataset.output_types)
    233     self._input_dataset = input_dataset
    234     self._batch_size = batch_size
    235     self._row_shape = row_shape
    236 
    237   def _as_variant_tensor(self):
    238     return gen_dataset_ops.dense_to_sparse_batch_dataset(
    239         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
    240         self._batch_size,
    241         row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape),  # pylint: disable=protected-access
    242         output_shapes=nest.flatten(
    243             sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
    244         output_types=nest.flatten(
    245             sparse.as_dense_types(self.output_types, self.output_classes)))
    246 
    247   @property
    248   def output_classes(self):
    249     return sparse_tensor.SparseTensor
    250 
    251   @property
    252   def output_shapes(self):
    253     return tensor_shape.vector(None).concatenate(self._row_shape)
    254 
    255   @property
    256   def output_types(self):
    257     return self._input_dataset.output_types
    258 
    259 
    260 class _RestructuredDataset(dataset_ops.Dataset):
    261   """An internal helper for changing the structure and shape of a dataset."""
    262 
    263   def __init__(self,
    264                dataset,
    265                output_types,
    266                output_shapes=None,
    267                output_classes=None):
    268     """Creates a new dataset with the given output types and shapes.
    269 
    270     The given `dataset` must have a structure that is convertible:
    271     * `dataset.output_types` must be the same as `output_types` module nesting.
    272     * Each shape in `dataset.output_shapes` must be compatible with each shape
    273       in `output_shapes` (if given).
    274 
    275     Note: This helper permits "unsafe casts" for shapes, equivalent to using
    276     `tf.Tensor.set_shape()` where domain-specific knowledge is available.
    277 
    278     Args:
    279       dataset: A `Dataset` object.
    280       output_types: A nested structure of `tf.DType` objects.
    281       output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
    282         If omitted, the shapes will be inherited from `dataset`.
    283       output_classes: (Optional.) A nested structure of class types.
    284         If omitted, the class types will be inherited from `dataset`.
    285 
    286     Raises:
    287       ValueError: If either `output_types` or `output_shapes` is not compatible
    288         with the structure of `dataset`.
    289     """
    290     super(_RestructuredDataset, self).__init__()
    291     self._dataset = dataset
    292 
    293     # Validate that the types are compatible.
    294     output_types = nest.map_structure(dtypes.as_dtype, output_types)
    295     flat_original_types = nest.flatten(dataset.output_types)
    296     flat_new_types = nest.flatten(output_types)
    297     if flat_original_types != flat_new_types:
    298       raise ValueError(
    299           "Dataset with output types %r cannot be restructured to have output "
    300           "types %r" % (dataset.output_types, output_types))
    301 
    302     self._output_types = output_types
    303 
    304     if output_shapes is None:
    305       # Inherit shapes from the original `dataset`.
    306       self._output_shapes = nest.pack_sequence_as(output_types,
    307                                                   nest.flatten(
    308                                                       dataset.output_shapes))
    309     else:
    310       # Validate that the shapes are compatible.
    311       nest.assert_same_structure(output_types, output_shapes)
    312       flat_original_shapes = nest.flatten(dataset.output_shapes)
    313       flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
    314 
    315       for original_shape, new_shape in zip(flat_original_shapes,
    316                                            flat_new_shapes):
    317         if not original_shape.is_compatible_with(new_shape):
    318           raise ValueError(
    319               "Dataset with output shapes %r cannot be restructured to have "
    320               "incompatible output shapes %r" % (dataset.output_shapes,
    321                                                  output_shapes))
    322       self._output_shapes = nest.map_structure_up_to(
    323           output_types, tensor_shape.as_shape, output_shapes)
    324     if output_classes is None:
    325       # Inherit class types from the original `dataset`.
    326       self._output_classes = nest.pack_sequence_as(output_types,
    327                                                    nest.flatten(
    328                                                        dataset.output_classes))
    329     else:
    330       self._output_classes = output_classes
    331 
    332   def _as_variant_tensor(self):
    333     return self._dataset._as_variant_tensor()  # pylint: disable=protected-access
    334 
    335   @property
    336   def output_classes(self):
    337     return self._output_classes
    338 
    339   @property
    340   def output_types(self):
    341     return self._output_types
    342 
    343   @property
    344   def output_shapes(self):
    345     return self._output_shapes
    346 
    347 
    348 class _MapAndBatchDataset(dataset_ops.MapDataset):
    349   """A `Dataset` that maps a function over a batch of elements."""
    350 
    351   def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches):
    352     """See `Dataset.map()` for details."""
    353     super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
    354     self._batch_size = ops.convert_to_tensor(
    355         batch_size, dtype=dtypes.int64, name="batch_size")
    356     self._num_parallel_batches = ops.convert_to_tensor(
    357         num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches")
    358 
    359   def _as_variant_tensor(self):
    360     # pylint: disable=protected-access
    361     input_resource = self._input_dataset._as_variant_tensor()
    362     return gen_dataset_ops.map_and_batch_dataset(
    363         input_resource,
    364         self._map_func.captured_inputs,
    365         f=self._map_func,
    366         batch_size=self._batch_size,
    367         num_parallel_batches=self._num_parallel_batches,
    368         output_types=nest.flatten(
    369             sparse.as_dense_types(self.output_types, self.output_classes)),
    370         output_shapes=nest.flatten(
    371             sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
    372     # pylint: enable=protected-access
    373 
    374   @property
    375   def output_shapes(self):
    376     return nest.pack_sequence_as(self._output_shapes, [
    377         tensor_shape.vector(tensor_util.constant_value(
    378             self._batch_size)).concatenate(s)
    379         for s in nest.flatten(self._output_shapes)
    380     ])
    381 
    382   @property
    383   def output_types(self):
    384     return self._output_types
    385 
    386 
    387 def map_and_batch(map_func, batch_size, num_parallel_batches=1):
    388   """Fused implementation of `map` and `batch`.
    389 
    390   Maps `map_func` across `batch_size` consecutive elements of this dataset
    391   and then combines them into a batch. Functionally, it is equivalent to `map`
    392   followed by `batch`. However, by fusing the two transformations together, the
    393   implementation can be more efficient. Surfacing this transformation in the API
    394   is temporary. Once automatic input pipeline optimization is implemented,
    395   the fusing of `map` and `batch` will happen automatically and this API will be
    396   deprecated.
    397 
    398   Args:
    399     map_func: A function mapping a nested structure of tensors to another
    400       nested structure of tensors.
    401     batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
    402       consecutive elements of this dataset to combine in a single batch.
    403     num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the
    404       number of batches to create in parallel. On one hand, higher values can
    405       help mitigate the effect of stragglers. On the other hand, higher values
    406       can increase contention if CPU is scarce.
    407 
    408   Returns:
    409     A `Dataset` transformation function, which can be passed to
    410     @{tf.data.Dataset.apply}.
    411   """
    412 
    413   def _apply_fn(dataset):
    414     return _MapAndBatchDataset(dataset, map_func, batch_size,
    415                                num_parallel_batches)
    416 
    417   return _apply_fn
    418