# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Python wrappers for Datasets and Iterators."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.data.util import nest
     22 from tensorflow.python.data.util import sparse
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.framework import tensor_shape
     26 from tensorflow.python.framework import tensor_util
     27 from tensorflow.python.ops import gen_dataset_ops
     28 from tensorflow.python.util import nest as tf_nest
     29 
     30 
     31 class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
     32   """A `Dataset` that prepends a queue to another `Dataset`.
     33 
     34   A vector of handles to the queue is returned as the first component of
     35   the associated iterator.  This vector can be passed to
     36   `enqueue_in_queue_dataset` to add new elements to the queue.
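
  For example, the associated iterator yields elements whose first component
  is the vector of queue handles (a minimal sketch; the names here are
  illustrative, not part of the API):

  ```python
  # `queue`: [batch_size] vector of `variant` handles; `batched_components`:
  # the padded batch of the wrapped dataset's components.
  queue, batched_components = iterator.get_next()
  ```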
     37   """
     38 
     39   def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
     40     """Initialize `PrependFromQueueAndPaddedBatchDataset`."""
     41     super(_PrependFromQueueAndPaddedBatchDataset, self).__init__()
     42     if sparse.any_sparse(input_dataset.output_classes):
     43       raise TypeError(
     44           "Batching of padded sparse tensors is not currently supported")
     45     self._input_dataset = input_dataset
     46     self._batch_size = ops.convert_to_tensor(
     47         batch_size, dtype=dtypes.int64, name="batch_size")
     48     # pylint: disable=protected-access
     49     if padded_shapes is None:
     50       self._padded_shapes = nest.map_structure(
     51           dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes)
     52     else:
     53       self._padded_shapes = nest.map_structure_up_to(
     54           input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor,
     55           padded_shapes)
     56     padding_values = (
     57         padding_values if padding_values is not None else
     58         dataset_ops._default_padding(input_dataset))
     59     self._padding_values = nest.map_structure_up_to(
     60         input_dataset.output_shapes, dataset_ops._padding_value_to_tensor,
     61         padding_values, input_dataset.output_types)
     62     # pylint: enable=protected-access
     63 
     64   def _as_variant_tensor(self):
     65     # pylint: disable=protected-access
     66     return gen_dataset_ops.prepend_from_queue_and_padded_batch_dataset(
     67         self._input_dataset._as_variant_tensor(),
     68         batch_size=self._batch_size,
     69         padded_shapes=[
     70             ops.convert_to_tensor(s, dtype=dtypes.int64)
     71             for s in nest.flatten(self._padded_shapes)
     72         ],
     73         padding_values=nest.flatten(self._padding_values),
     74         output_shapes=nest.flatten(
     75             sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
     76     # pylint: enable=protected-access
     77 
     78   @property
     79   def output_classes(self):
     80     return (ops.Tensor, self._input_dataset.output_classes)
     81 
     82   def _as_batch_shape(self, shape_like):
     83     return tensor_shape.vector(None).concatenate(
     84         tensor_util.constant_value_as_shape(shape_like))
     85 
     86   @property
     87   def output_shapes(self):
     88     # First output is a variant representing the Queue
     89     return (tensor_shape.vector(None),
     90             nest.map_structure(self._as_batch_shape, self._padded_shapes))
     91 
     92   @property
     93   def output_types(self):
     94     # First output is a variant representing the Queue
     95     return (dtypes.variant, self._input_dataset.output_types)
     96 
     97 
     98 def prepend_from_queue_and_padded_batch_dataset(batch_size,
     99                                                 padding_values=None,
    100                                                 padded_shapes=None):
    101   """A transformation that prepends a queue to a `Dataset` and batches results.
    102 
    103   A vector of handles to the queue is returned as the first component of the
    104   associated iterator.  This vector can be passed to `enqueue_in_queue_dataset`
    105   to add new elements to the queue.
    106 
    107   Below is an example of how this dataset might be used to split incoming
    108   variable-length sequences into "head" and "rest" parts, where "rest" parts
    109   are re-enqueued back into the dataset.  A more realistic example would
    110   perform some calculation on the "head" and modify some components of "rest"
    111   with the result (before re-enqueueing).
    112 
  ```python
  dataset = tf.data.Dataset.from_tensor_slices([2*x for x in range(10)])
  # Make a dataset of variable-length vectors and their lengths.
  dataset = dataset.map(lambda count: (count, tf.ones((count,))))
  # Emit a queue we can prepend to, and counts/values as padded batch.
  dataset = dataset.apply(
      tf.contrib.training.prepend_from_queue_and_padded_batch_dataset(
        batch_size=10))
  dataset = dataset.prefetch(1)

  iterator = dataset.make_one_shot_iterator()
  queue, (count, padded_value) = iterator.get_next()

  # Split the padded_value into two pieces: head and rest.
  rest_indices = tf.squeeze(tf.where(count > 3), axis=1)
  bound = tf.minimum(3, tf.reduce_max(count))
  value_head = padded_value[:, :bound]
  count_rest = tf.gather(count - 3, rest_indices)
  value_rest = tf.gather(padded_value[:, bound:], rest_indices)
  queue_rest = tf.gather(queue, rest_indices)
  enqueue_rest_op = tf.contrib.training.enqueue_in_queue_dataset(
    queue_rest, (count_rest, value_rest))
  with tf.control_dependencies([enqueue_rest_op]):
    calculation = fn(value_head)  # `fn` is a user-provided computation.

  while True:  # Will raise OutOfRange when finished with all pieces.
    session.run(calculation)
  ```
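
  In this example, `fn` stands for an arbitrary user-provided computation on
  the head slices, and `session` is an ordinary `tf.Session`.  Each run
  re-enqueues the unfinished "rest" pieces so they flow back through the
  dataset; once every sequence has been fully consumed, the queue drains and
  the iterator raises `tf.errors.OutOfRangeError`.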

  Args:
    batch_size: `int64` scalar tensor.  The batch size to use when performing
      padded batching.
    padding_values: (optional) Nested tuple of scalar tensors.  If provided,
      the structure and dtypes of `padding_values` must match those of the
      incoming dataset's `output_types`.
    padded_shapes: (optional) Nested tuple of `int64` vector tensors.
      If provided, the structure must match that of the incoming dataset's
      `output_types`.  If not provided, the incoming dataset's `output_shapes`
      is used.  Any unknown (`None` or `-1`) dimensions in the shapes are
      treated as being unique per-batch: for each batch, an unknown dimension
      is replaced with the maximum value of that dimension across all tensors
      for the given component in the batch.  For example, if a component has
      padded shape `[-1]` and a batch contains vectors of lengths 3 and 5,
      that batch is padded to length 5.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    return _PrependFromQueueAndPaddedBatchDataset(
        dataset,
        batch_size=batch_size,
        padding_values=padding_values,
        padded_shapes=padded_shapes)

  return _apply_fn


def enqueue_in_queue_dataset(queue, components):
  """Enqueue components into a queue from `prepend_from_queue_and_padded_batch_dataset`.

  The components' dtypes and shapes must be compatible with the `output_types`
  and `output_shapes` attributes of the `dataset` created by
  `prepend_from_queue_and_padded_batch_dataset`.  This operation supports both
  non-batched and batched modes.

  For more details, see the example in the docstring for
  `prepend_from_queue_and_padded_batch_dataset`.
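
  For instance, the batched-mode call from the example above looks like this
  (a sketch; `queue_rest`, `count_rest`, and `value_rest` are computed as
  shown there):

  ```python
  enqueue_rest_op = tf.contrib.training.enqueue_in_queue_dataset(
      queue_rest, (count_rest, value_rest))
  ```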

  Args:
    queue: `variant` scalar or vector tensor.
      The tensor emitted by the first component of the iterator associated with
      `prepend_from_queue_and_padded_batch_dataset`.  If this is a scalar,
      then the `components` input tensors should not have a prepended batch
      dimension.
    components: Nested tuple of tensors, each with a leading batch dimension
      if `queue` is a vector.  The structure, dtypes, and shapes
      (excluding batch dimension) must match the nested tuples
      `dataset.output_types[1]` and `dataset.output_shapes[1]` (the non-queue
      output types and shapes) of the `dataset` emitted by
      the original `prepend_from_queue_and_padded_batch_dataset` call.

  Returns:
    An `Operation` that enqueues `components` into the dataset(s) associated
    with entries of `queue`.
  """
  return gen_dataset_ops.enqueue_in_queue_dataset(
      queue=queue, components=tf_nest.flatten(components))
    201