# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Batching dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops


def dense_to_sparse_batch(batch_size, row_shape):
  """A transformation that batches ragged elements into `tf.SparseTensor`s.

  Like `Dataset.padded_batch()`, this transformation combines multiple
  consecutive elements of the dataset, which might have different
  shapes, into a single element. The resulting element has three
  components (`indices`, `values`, and `dense_shape`), which
  comprise a `tf.SparseTensor` that represents the same data. The
  `row_shape` represents the dense shape of each row in the
  resulting `tf.SparseTensor`, to which the effective batch size is
  prepended. For example:

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.apply(tf.contrib.data.dense_to_sparse_batch(batch_size=2, row_shape=[6])) ==
  {
      ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
       ['a', 'b', 'c', 'a', 'b'],                 # values
       [2, 6]),                                   # dense_shape
      ([[0, 0], [0, 1], [0, 2], [0, 3]],
       ['a', 'b', 'c', 'd'],
       [1, 6])
  }
  ```

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the
      number of consecutive elements of this dataset to combine in a
      single batch.
    row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like
      object representing the equivalent dense shape of a row in the
      resulting `tf.SparseTensor`. Each element of this dataset must
      have the same rank as `row_shape`, and must have size less
      than or equal to `row_shape` in each dimension.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    return DenseToSparseBatchDataset(dataset, batch_size, row_shape)

  return _apply_fn


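# An illustrative sketch (not part of the library): one way to apply
# `dense_to_sparse_batch` to a ragged text dataset. The file name "data.txt",
# the tokenization, and the bound of 100 tokens per line are assumptions
# chosen for illustration only.
#
#   dataset = tf.data.TextLineDataset("data.txt")
#   dataset = dataset.map(lambda line: tf.string_split([line]).values)
#   dataset = dataset.apply(
#       tf.contrib.data.dense_to_sparse_batch(batch_size=32, row_shape=[100]))
#   # Each element is now a `tf.SparseTensor` with dense shape `[32, 100]`
#   # (the first dimension may be smaller for the final batch), provided no
#   # line has more than 100 tokens.

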
def unbatch():
  """A transformation which splits the elements of a dataset.

  For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
  where `B` may vary from element to element, then for each element in
  the dataset, the unbatched dataset will contain `B` consecutive elements
  of shape `[a0, a1, ...]`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):

    def unbatch_map(arg, *rest):
      if rest:
        return dataset_ops.Dataset.from_tensor_slices((arg,) + rest)
      else:
        return dataset_ops.Dataset.from_tensor_slices(arg)

    return dataset.flat_map(map_func=unbatch_map)

  return _apply_fn


def filter_irregular_batches(batch_size):
  """Transformation that filters out batches that are not of size batch_size."""

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    tensor_batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")

    flattened = _RestructuredDataset(
        dataset,
        tuple(nest.flatten(dataset.output_types)),
        output_classes=tuple(nest.flatten(dataset.output_classes)))

    def _predicate(*xs):
      """Return `True` if this element is a full batch."""
      # Extract the dynamic batch size from the first component of the
      # flattened batched element.
      first_component = xs[0]
      first_component_batch_size = array_ops.shape(
          first_component, out_type=dtypes.int64)[0]

      return math_ops.equal(first_component_batch_size, tensor_batch_size)

    filtered = flattened.filter(_predicate)

    maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)

    def _set_first_dimension(shape):
      return shape.merge_with(
          tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))

    known_shapes = nest.map_structure(_set_first_dimension,
                                      dataset.output_shapes)
    return _RestructuredDataset(
        filtered,
        dataset.output_types,
        known_shapes,
        output_classes=dataset.output_classes)

  return _apply_fn


def batch_and_drop_remainder(batch_size):
  """A batching transformation that omits the final small batch (if present).

  Like @{tf.data.Dataset.batch}, this transformation combines
  consecutive elements of this dataset into batches. However, if the batch
  size does not evenly divide the input dataset size, this transformation will
  drop the final smaller element.

  The following example illustrates the difference between this
  transformation and `Dataset.batch()`:

  ```python
  dataset = tf.data.Dataset.range(200)
  batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128))
  print(batched.output_shapes)  # ==> "(128,)" (the batch dimension is known)
  ```

  By contrast, `dataset.batch(128)` would yield a two-element dataset with
  shapes `(128,)` and `(72,)`, so the batch dimension would not be statically
  known.

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    batched = dataset.batch(batch_size)
    return filter_irregular_batches(batch_size)(batched)

  return _apply_fn


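# An illustrative sketch (not part of the library): `unbatch()` and
# `batch_and_drop_remainder()` composed to re-batch data whose existing batch
# dimension varies from element to element. The shapes and batch size below
# are assumptions chosen for illustration only.
#
#   # `dataset` has elements of shape `[B, 28, 28]`, where `B` varies.
#   dataset = dataset.apply(tf.contrib.data.unbatch())
#   dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(64))
#   # `dataset.output_shapes` is now statically `(64, 28, 28)`.

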
def padded_batch_and_drop_remainder(batch_size,
                                    padded_shapes,
                                    padding_values=None):
  """A batching and padding transformation that omits the final small batch.

  Like @{tf.data.Dataset.padded_batch}, this transformation combines
  consecutive elements of this dataset into batches. However, if the batch
  size does not evenly divide the input dataset size, this transformation will
  drop the final smaller element.

  See `@{tf.contrib.data.batch_and_drop_remainder}` for more details.

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    padded_shapes: A nested structure of `tf.TensorShape` or
      `tf.int64` vector tensor-like objects. See
      @{tf.data.Dataset.padded_batch} for details.
    padding_values: (Optional.) A nested structure of scalar-shaped
      `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    batched = dataset.padded_batch(
        batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
    return filter_irregular_batches(batch_size)(batched)

  return _apply_fn


class DenseToSparseBatchDataset(dataset_ops.Dataset):
  """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""

  def __init__(self, input_dataset, batch_size, row_shape):
    """See `Dataset.dense_to_sparse_batch()` for more details."""
    super(DenseToSparseBatchDataset, self).__init__()
    if not isinstance(input_dataset.output_types, dtypes.DType):
      raise TypeError("DenseToSparseBatchDataset requires an input whose "
                      "elements have a single component, whereas the input "
                      "has %r." % input_dataset.output_types)
    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape

  def _as_variant_tensor(self):
    return gen_dataset_ops.dense_to_sparse_batch_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._batch_size,
        row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape),  # pylint: disable=protected-access
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)))

  @property
  def output_classes(self):
    return sparse_tensor.SparseTensor

  @property
  def output_shapes(self):
    return tensor_shape.vector(None).concatenate(self._row_shape)

  @property
  def output_types(self):
    return self._input_dataset.output_types


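# An illustrative sketch (not part of the library): `padded_batch_and_drop_remainder`
# applied to a dataset of variable-length integer sequences. The padded shape
# and batch size below are assumptions chosen for illustration only.
#
#   # `dataset` has 1-D `tf.int32` elements of varying length.
#   dataset = dataset.apply(
#       tf.contrib.data.padded_batch_and_drop_remainder(
#           32, padded_shapes=[None]))
#   # Every element is a batch of shape `[32, longest_sequence_in_batch]`; a
#   # final partial batch, if any, is dropped, so the batch dimension is
#   # statically known.

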
class _RestructuredDataset(dataset_ops.Dataset):
  """An internal helper for changing the structure and shape of a dataset."""

  def __init__(self,
               dataset,
               output_types,
               output_shapes=None,
               output_classes=None):
    """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:
    * `dataset.output_types` must be the same as `output_types` modulo nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects. If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types.
        If omitted, the class types will be inherited from `dataset`.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not compatible
        with the structure of `dataset`.
    """
    super(_RestructuredDataset, self).__init__()
    self._dataset = dataset

    # Validate that the types are compatible.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    flat_original_types = nest.flatten(dataset.output_types)
    flat_new_types = nest.flatten(output_types)
    if flat_original_types != flat_new_types:
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have output "
          "types %r" % (dataset.output_types, output_types))

    self._output_types = output_types

    if output_shapes is None:
      # Inherit shapes from the original `dataset`.
      self._output_shapes = nest.pack_sequence_as(output_types,
                                                  nest.flatten(
                                                      dataset.output_shapes))
    else:
      # Validate that the shapes are compatible.
      nest.assert_same_structure(output_types, output_shapes)
      flat_original_shapes = nest.flatten(dataset.output_shapes)
      flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)

      for original_shape, new_shape in zip(flat_original_shapes,
                                           flat_new_shapes):
        if not original_shape.is_compatible_with(new_shape):
          raise ValueError(
              "Dataset with output shapes %r cannot be restructured to have "
              "incompatible output shapes %r" % (dataset.output_shapes,
                                                 output_shapes))
      self._output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if output_classes is None:
      # Inherit class types from the original `dataset`.
      self._output_classes = nest.pack_sequence_as(output_types,
                                                   nest.flatten(
                                                       dataset.output_classes))
    else:
      self._output_classes = output_classes

  def _as_variant_tensor(self):
    return self._dataset._as_variant_tensor()  # pylint: disable=protected-access

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_types(self):
    return self._output_types

  @property
  def output_shapes(self):
    return self._output_shapes


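# An illustrative sketch (not part of the library): `_RestructuredDataset` can
# assert a statically known batch dimension, much as `filter_irregular_batches`
# does above. The range and batch size below are assumptions for illustration.
#
#   dataset = dataset_ops.Dataset.range(1000).batch(25)  # output_shapes: (?,)
#   dataset = _RestructuredDataset(
#       dataset,
#       dataset.output_types,
#       output_shapes=tensor_shape.vector(25))  # output_shapes: (25,)

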
class _MapAndBatchDataset(dataset_ops.MapDataset):
  """A `Dataset` that maps a function over a batch of elements."""

  def __init__(self, input_dataset, map_func, batch_size,
               num_parallel_batches):
    """See `Dataset.map()` for details."""
    super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._num_parallel_batches = ops.convert_to_tensor(
        num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches")

  def _as_variant_tensor(self):
    # pylint: disable=protected-access
    input_resource = self._input_dataset._as_variant_tensor()
    return gen_dataset_ops.map_and_batch_dataset(
        input_resource,
        self._map_func.captured_inputs,
        f=self._map_func,
        batch_size=self._batch_size,
        num_parallel_batches=self._num_parallel_batches,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
    # pylint: enable=protected-access

  @property
  def output_shapes(self):
    return nest.pack_sequence_as(self._output_shapes, [
        tensor_shape.vector(tensor_util.constant_value(
            self._batch_size)).concatenate(s)
        for s in nest.flatten(self._output_shapes)
    ])

  @property
  def output_types(self):
    return self._output_types


def map_and_batch(map_func, batch_size, num_parallel_batches=1):
  """Fused implementation of `map` and `batch`.

  Maps `map_func` across `batch_size` consecutive elements of this dataset
  and then combines them into a batch. Functionally, it is equivalent to `map`
  followed by `batch`. However, by fusing the two transformations together,
  the implementation can be more efficient. Surfacing this transformation in
  the API is temporary. Once automatic input pipeline optimization is
  implemented, the fusing of `map` and `batch` will happen automatically and
  this API will be deprecated.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the
      number of batches to create in parallel. On one hand, higher values can
      help mitigate the effect of stragglers. On the other hand, higher values
      can increase contention if CPU is scarce.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_batches)

  return _apply_fn


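# An illustrative sketch (not part of the library): `map_and_batch` as a fused
# alternative to a `map` followed by a `batch`. The record file and `parse_fn`
# are assumptions chosen for illustration only.
#
#   dataset = tf.data.TFRecordDataset("input.tfrecord")
#   # Roughly equivalent to `dataset.map(parse_fn).batch(128)`, but the map
#   # and batch steps are fused and up to two batches are prepared in parallel.
#   dataset = dataset.apply(
#       tf.contrib.data.map_and_batch(parse_fn, batch_size=128,
#                                     num_parallel_batches=2))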