# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Non-deterministic dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation


class ParallelInterleaveDataset(dataset_ops.Dataset):
  """A `Dataset` that maps a function over its input and flattens the result.

  Unlike the core `Dataset.interleave()`, this dataset draws elements from
  `cycle_length` nested datasets in parallel, and (when `sloppy` is true) may
  produce them out of the deterministic round-robin order.
  """

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               sloppy, buffer_output_elements, prefetch_input_elements):
    """See `tf.contrib.data.parallel_interleave()` for details."""
    super(ParallelInterleaveDataset, self).__init__()
    self._input_dataset = input_dataset

    # The Defun's signature is the flattened, densified element structure of
    # `input_dataset` (sparse components are carried as serialized dense
    # tensors and reconstituted inside the function body below).
    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset: Defun arguments
      # arrive with unknown shapes, so re-attach the statically known ones.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      # Restore the nested element structure, then rebuild any sparse tensors
      # from their serialized dense representation.
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      # Multi-component elements are splatted into `map_func`'s positional
      # parameters, matching `Dataset.map()` calling conventions.
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      # Side effect of tracing: record the output structure of the nested
      # dataset so this dataset's output_* properties can report it.
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      # The op-level interface exchanges nested datasets as variant tensors.
      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._map_func = tf_map_func
    # NOTE(review): adding the Defun to the graph forces it to be traced,
    # which is what populates self._output_{classes,types,shapes} above — the
    # properties below would raise AttributeError without this call.
    self._map_func.add_to_graph(ops.get_default_graph())

    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._sloppy = ops.convert_to_tensor(
        sloppy, dtype=dtypes.bool, name="sloppy")
    # Optional knobs default to 2x the block/cycle sizes when unspecified
    # (`optional_param_to_tensor` substitutes `argument_default` for None).
    self._buffer_output_elements = convert.optional_param_to_tensor(
        "buffer_output_elements",
        buffer_output_elements,
        argument_default=2 * block_length)
    self._prefetch_input_elements = convert.optional_param_to_tensor(
        "prefetch_input_elements",
        prefetch_input_elements,
        argument_default=2 * cycle_length)

  def _as_variant_tensor(self):
    # Build the ParallelInterleaveDataset op; output types/shapes are the
    # flattened dense structure of the nested datasets produced by map_func.
    return gen_dataset_ops.parallel_interleave_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._map_func.captured_inputs,
        self._cycle_length,
        self._block_length,
        self._sloppy,
        self._buffer_output_elements,
        self._prefetch_input_elements,
        f=self._map_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    # Populated as a side effect of tracing tf_map_func in __init__.
    return self._output_classes

  @property
  def output_shapes(self):
    # Populated as a side effect of tracing tf_map_func in __init__.
    return self._output_shapes

  @property
  def output_types(self):
    # Populated as a side effect of tracing tf_map_func in __init__.
    return self._output_types


def parallel_interleave(map_func,
                        cycle_length,
                        block_length=1,
                        sloppy=False,
                        buffer_output_elements=None,
                        prefetch_input_elements=None):
  """A parallel version of the `Dataset.interleave()` transformation.

  `parallel_interleave()` maps `map_func` across its input to produce nested
  datasets, and outputs their elements interleaved. Unlike
  @{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested
  datasets in parallel, which increases the throughput, especially in the
  presence of stragglers. Furthermore, the `sloppy` argument can be used to
  improve performance, by relaxing the requirement that the outputs are produced
  in a deterministic order, and allowing the implementation to skip over nested
  datasets whose elements are not readily available when requested.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.contrib.data.parallel_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```

  WARNING: If `sloppy` is `True`, the order of produced elements is not
  deterministic.

  Args:
    map_func: A function mapping a nested structure of tensors to a `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`.
    sloppy: If false, elements are produced in deterministic order. Otherwise,
      the implementation is allowed, for the sake of expediency, to produce
      elements in a non-deterministic order.
    buffer_output_elements: The number of elements each iterator being
      interleaved should buffer (similar to the `.prefetch()` transformation for
      each interleaved iterator).
    prefetch_input_elements: The number of input elements to transform to
      iterators before they are needed for interleaving.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """
  def _apply_fn(dataset):
    return ParallelInterleaveDataset(
        dataset, map_func, cycle_length, block_length, sloppy,
        buffer_output_elements, prefetch_input_elements)

  return _apply_fn


@deprecation.deprecated(
    None, "Use `tf.contrib.data.parallel_interleave(..., sloppy=True)`.")
def sloppy_interleave(map_func, cycle_length, block_length=1):
  """A non-deterministic version of the `Dataset.interleave()` transformation.

  `sloppy_interleave()` maps `map_func` across `dataset`, and
  non-deterministically interleaves the results.

  The resulting dataset is almost identical to `interleave`. The key
  difference is that if retrieving a value from a given output iterator would
  cause `get_next` to block, that iterator will be skipped, and consumed
  when next available. If consuming from all iterators would cause the
  `get_next` call to block, the `get_next` call blocks until the first value is
  available.

  If the underlying datasets produce elements as fast as they are consumed, the
  `sloppy_interleave` transformation behaves identically to `interleave`.
  However, if an underlying dataset would block the consumer,
  `sloppy_interleave` can violate the round-robin order (that `interleave`
  strictly obeys), producing an element from a different underlying
  dataset instead.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.contrib.data.sloppy_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```

  WARNING: The order of elements in the resulting dataset is not
  deterministic. Use `Dataset.interleave()` if you want the elements to have a
  deterministic order.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to a
      `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`. Note:
      `sloppy_interleave` will skip the remainder of elements in the
      `block_length` in order to avoid blocking.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """
  def _apply_fn(dataset):
    # Deprecated alias: identical to parallel_interleave with sloppy=True and
    # default buffering parameters.
    return ParallelInterleaveDataset(
        dataset,
        map_func,
        cycle_length,
        block_length,
        sloppy=True,
        buffer_output_elements=None,
        prefetch_input_elements=None)

  return _apply_fn