Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Experimental shuffle ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.data.util import nest
     22 from tensorflow.python.data.util import sparse
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.framework import random_seed
     27 from tensorflow.python.ops import gen_dataset_ops
     28 
     29 
     30 class _ShuffleAndRepeatDataset(dataset_ops.Dataset):
     31   """A `Dataset` that fuses `shuffle` and `repeat`."""
     32 
     33   def __init__(self,
     34                input_dataset,
     35                buffer_size,
     36                count=None,
     37                seed=None):
     38     """See `Dataset.map()` for details."""
     39     super(_ShuffleAndRepeatDataset, self).__init__()
     40     self._input_dataset = input_dataset
     41     self._buffer_size = ops.convert_to_tensor(
     42         buffer_size, dtype=dtypes.int64, name="buffer_size")
     43     if count is None:
     44       self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
     45     else:
     46       self._count = ops.convert_to_tensor(
     47           count, dtype=dtypes.int64, name="count")
     48 
     49     seed, seed2 = random_seed.get_seed(seed)
     50     if seed is None:
     51       self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed")
     52     else:
     53       self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed")
     54     if seed2 is None:
     55       self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2")
     56     else:
     57       self._seed2 = ops.convert_to_tensor(
     58           seed2, dtype=dtypes.int64, name="seed2")
     59 
     60   def _as_variant_tensor(self):
     61     # pylint: disable=protected-access
     62     input_resource = self._input_dataset._as_variant_tensor()
     63     return gen_dataset_ops.shuffle_and_repeat_dataset(
     64         input_resource,
     65         buffer_size=self._buffer_size,
     66         count=self._count,
     67         seed=self._seed,
     68         seed2=self._seed2,
     69         output_types=nest.flatten(
     70             sparse.as_dense_types(self.output_types, self.output_classes)),
     71         output_shapes=nest.flatten(
     72             sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
     73     # pylint: enable=protected-access
     74 
     75   @property
     76   def output_classes(self):
     77     return self._input_dataset.output_classes
     78 
     79   @property
     80   def output_shapes(self):
     81     return self._input_dataset.output_shapes
     82 
     83   @property
     84   def output_types(self):
     85     return self._input_dataset.output_types
     86 
     87 
     88 def shuffle_and_repeat(buffer_size, count=None, seed=None):
     89   """Shuffles and repeats a Dataset returning a new permutation for each epoch.
     90 
     91   `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))`
     92 
     93   is equivalent to
     94 
     95   `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)`
     96 
     97   The difference is that the latter dataset is not serializable. So,
     98   if you need to checkpoint an input pipeline with reshuffling you must use
     99   this implementation.
    100 
    101   Args:
    102     buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
    103       maximum number elements that will be buffered when prefetching.
    104     count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
    105       number of times the dataset should be repeated. The default behavior
    106       (if `count` is `None` or `-1`) is for the dataset be repeated
    107       indefinitely.
    108     seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
    109       random seed that will be used to create the distribution. See
    110       @{tf.set_random_seed} for behavior.
    111 
    112   Returns:
    113     A `Dataset` transformation function, which can be passed to
    114     @{tf.data.Dataset.apply}.
    115   """
    116 
    117   def _apply_fn(dataset):  # pylint: disable=missing-docstring
    118     return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)
    119 
    120   return _apply_fn
    121