# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental shuffle ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gen_dataset_ops


def _int64_tensor_or_default(value, default, name):
  """Converts `value` to a `tf.int64` tensor, or builds a `default` constant.

  Args:
    value: A Python int, a `tf.Tensor`, or `None`.
    default: Python int used to build a constant when `value` is `None`.
    name: Name given to the created op.

  Returns:
    A `tf.int64` `tf.Tensor`.
  """
  if value is None:
    return constant_op.constant(default, dtype=dtypes.int64, name=name)
  return ops.convert_to_tensor(value, dtype=dtypes.int64, name=name)


class _ShuffleAndRepeatDataset(dataset_ops.Dataset):
  """A `Dataset` that fuses `shuffle` and `repeat`."""

  def __init__(self,
               input_dataset,
               buffer_size,
               count=None,
               seed=None):
    """Creates the fused dataset. See `shuffle_and_repeat()` for details.

    Args:
      input_dataset: The `Dataset` whose elements are shuffled and repeated.
      buffer_size: Value convertible to a `tf.int64` tensor; the size of the
        shuffle buffer.
      count: (Optional.) Value convertible to a `tf.int64` tensor giving the
        number of repetitions. `None` is encoded as `-1`.
      seed: (Optional.) Op-level seed passed to `random_seed.get_seed`. Any
        `None` component of the resulting seed pair is encoded as `0`.
    """
    super(_ShuffleAndRepeatDataset, self).__init__()
    self._input_dataset = input_dataset
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    # A count of -1 means "repeat indefinitely" (see `shuffle_and_repeat`).
    self._count = _int64_tensor_or_default(count, -1, "count")

    # NOTE: keep this after the tensors above are created. `get_seed` may
    # derive an op-level seed from the current graph state, so reordering
    # these statements could change the seeds that are chosen.
    seed, seed2 = random_seed.get_seed(seed)
    self._seed = _int64_tensor_or_default(seed, 0, "seed")
    self._seed2 = _int64_tensor_or_default(seed2, 0, "seed2")

  def _as_variant_tensor(self):
    # pylint: disable=protected-access
    input_resource = self._input_dataset._as_variant_tensor()
    return gen_dataset_ops.shuffle_and_repeat_dataset(
        input_resource,
        buffer_size=self._buffer_size,
        count=self._count,
        seed=self._seed,
        seed2=self._seed2,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
    # pylint: enable=protected-access

  @property
  def output_classes(self):
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    return self._input_dataset.output_types


def shuffle_and_repeat(buffer_size, count=None, seed=None):
  """Shuffles and repeats a Dataset, returning a new permutation for each epoch.

  `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))`

  is equivalent to

  `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)`

  The difference is that the latter dataset is not serializable. So,
  if you need to checkpoint an input pipeline with reshuffling, you must use
  this implementation.

  Args:
    buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      elements from the input dataset that are buffered and sampled from when
      shuffling.
    count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      number of times the dataset should be repeated. The default behavior
      (if `count` is `None` or `-1`) is for the dataset to be repeated
      indefinitely.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      random seed that will be used to create the distribution. See
      @{tf.set_random_seed} for behavior.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    """Wraps `dataset` in a fused shuffle-and-repeat dataset."""
    return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)

  return _apply_fn