1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Python wrappers for Datasets.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import abc 21 import functools 22 import threading 23 import warnings 24 25 import numpy as np 26 import six 27 from six.moves import queue as Queue # pylint: disable=redefined-builtin 28 29 30 from tensorflow.python.compat import compat 31 from tensorflow.python.data.experimental.ops import optimization_options 32 from tensorflow.python.data.experimental.ops import stats_options 33 from tensorflow.python.data.experimental.ops import threading_options 34 from tensorflow.python.data.ops import iterator_ops 35 from tensorflow.python.data.util import nest 36 from tensorflow.python.data.util import options as options_lib 37 from tensorflow.python.data.util import random_seed 38 from tensorflow.python.data.util import sparse 39 from tensorflow.python.data.util import structure as structure_lib 40 from tensorflow.python.data.util import traverse 41 from tensorflow.python.eager import context 42 from tensorflow.python.eager import function as eager_function 43 from tensorflow.python.framework import constant_op 44 from tensorflow.python.framework import dtypes 45 from tensorflow.python.framework import function 46 from tensorflow.python.framework import ops 47 from tensorflow.python.framework import random_seed as core_random_seed 48 from tensorflow.python.framework import smart_cond 49 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib 50 from tensorflow.python.framework import tensor_shape 51 from tensorflow.python.framework import tensor_spec 52 from tensorflow.python.framework import tensor_util 53 from tensorflow.python.ops import array_ops 54 from tensorflow.python.ops import control_flow_ops 55 from tensorflow.python.ops import gen_dataset_ops 56 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 57 from tensorflow.python.ops import gen_io_ops 58 from tensorflow.python.ops import math_ops 59 from tensorflow.python.ops import script_ops 60 from tensorflow.python.ops import string_ops 61 from tensorflow.python.platform import tf_logging as logging 62 from tensorflow.python.training.tracking import tracking 63 from tensorflow.python.util import deprecation 64 from tensorflow.python.util import function_utils 65 from tensorflow.python.util.tf_export import tf_export 66 67 68 ops.NotDifferentiable("ReduceDataset") 69 70 71 @tf_export("data.Dataset", v1=[]) 72 @six.add_metaclass(abc.ABCMeta) 73 class DatasetV2(object): 74 """Represents a potentially large set of elements. 75 76 A `Dataset` can be used to represent an input pipeline as a 77 collection of elements (nested structures of tensors) and a "logical 78 plan" of transformations that act on those elements. 
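  For example, a pipeline can be built by chaining the transformations defined
  below (a minimal sketch; the element values shown in the comments are
  illustrative):

  ```python
  dataset = tf.data.Dataset.range(8)        # 0, 1, ..., 7
  dataset = dataset.map(lambda x: x * 2)    # 0, 2, ..., 14
  dataset = dataset.batch(4)                # [0, 2, 4, 6], [8, 10, 12, 14]
  ```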
79 """ 80 81 def __init__(self, variant_tensor): 82 """Creates a DatasetV2 object. 83 84 This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not 85 take anything in its constructor whereas in the DatasetV2, we expect 86 subclasses to create a variant_tensor and pass it in to the super() call. 87 88 Args: 89 variant_tensor: A DT_VARIANT tensor that represents the dataset. 90 """ 91 self._variant_tensor_attr = variant_tensor 92 self._graph_attr = ops.get_default_graph() 93 94 @property 95 def _variant_tensor(self): 96 return self._variant_tensor_attr 97 98 @_variant_tensor.setter 99 def _variant_tensor(self, _): 100 raise ValueError("The _variant_tensor property is read-only") 101 102 def _as_serialized_graph(self): 103 """Produces serialized graph representation of the dataset. 104 105 Returns: 106 A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a 107 serialized graph. 108 """ 109 return gen_dataset_ops.dataset_to_graph(self._variant_tensor) 110 111 @abc.abstractmethod 112 def _inputs(self): 113 """Returns a list of the input datasets of the dataset.""" 114 115 raise NotImplementedError("Dataset._inputs") 116 117 @property 118 def _graph(self): 119 return self._graph_attr 120 121 @_graph.setter 122 def _graph(self, _): 123 raise ValueError("The _graph property is read-only") 124 125 def _has_captured_ref(self): 126 """Whether this dataset uses a function that captures ref variables. 127 128 Returns: 129 A boolean, which if true indicates that the dataset or one of its inputs 130 uses a function that captures ref variables. 131 """ 132 if context.executing_eagerly(): 133 # RefVariables are not supported in eager mode 134 return False 135 136 def is_tensor_or_parent_ref(tensor): 137 if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access 138 return True 139 return any([is_tensor_or_parent_ref(x) for x in tensor.op.inputs]) 140 141 for fn in self._functions(): 142 if any([is_tensor_or_parent_ref(t) for t in fn.function.captured_inputs]): 143 return True 144 145 return any( 146 [input_dataset._has_captured_ref() for input_dataset in self._inputs()]) # pylint: disable=protected-access 147 148 # TODO(jsimsa): Change this to be the transitive closure of functions used 149 # by this dataset and its inputs. 150 def _functions(self): 151 """Returns a list of functions associated with this dataset. 152 153 Returns: 154 A list of `StructuredFunctionWrapper` objects. 155 """ 156 return [] 157 158 def options(self): 159 """Returns the options for this dataset and its inputs. 160 161 Returns: 162 A `tf.data.Options` object representing the dataset options. 
163 """ 164 options = Options() 165 for input_dataset in self._inputs(): 166 input_options = input_dataset.options() 167 if input_options is not None: 168 options = options.merge(input_options) 169 return options 170 171 def _apply_options(self): 172 """Apply options, such as optimization configuration, to the dataset.""" 173 174 dataset = self 175 options = self.options() 176 if options.experimental_threading is not None: 177 t_options = options.experimental_threading 178 if t_options.max_intra_op_parallelism is not None: 179 dataset = _MaxIntraOpParallelismDataset( 180 dataset, t_options.max_intra_op_parallelism) 181 if t_options.private_threadpool_size is not None: 182 dataset = _PrivateThreadPoolDataset(dataset, 183 t_options.private_threadpool_size) 184 static_optimizations = options._static_optimizations() # pylint: disable=protected-access 185 if static_optimizations: 186 if self._has_captured_ref(): 187 warnings.warn( 188 "tf.data static optimizations are not compatible with tf.Variable. " 189 "The following optimizations will be disabled: %s. To enable " 190 "optimizations, use resource variables instead by calling " 191 "`tf.enable_resource_variables()` at the start of the program." % 192 ", ".join(static_optimizations)) 193 else: 194 dataset = _OptimizeDataset(dataset, static_optimizations) 195 196 autotune = True 197 cpu_budget = 0 # Indicates that all CPU cores should be used. 198 if options.experimental_optimization is not None: 199 if options.experimental_optimization.autotune is False: # pylint: disable=g-bool-id-comparison 200 autotune = False 201 if options.experimental_optimization.autotune_cpu_budget is not None: 202 cpu_budget = options.experimental_optimization.autotune_cpu_budget 203 204 if autotune: 205 dataset = _ModelDataset(dataset, cpu_budget) 206 207 if options.experimental_stats and options.experimental_stats.aggregator: # pylint: disable=line-too-long 208 dataset = _SetStatsAggregatorDataset( # pylint: disable=protected-access 209 dataset, options.experimental_stats.aggregator, 210 options.experimental_stats.prefix, 211 options.experimental_stats.counter_prefix) 212 return dataset 213 214 def __iter__(self): 215 """Creates an `Iterator` for enumerating the elements of this dataset. 216 217 The returned iterator implements the Python iterator protocol and therefore 218 can only be used in eager mode. 219 220 Returns: 221 An `Iterator` over the elements of this dataset. 222 223 Raises: 224 RuntimeError: If eager execution is not enabled. 225 """ 226 if context.executing_eagerly(): 227 return iterator_ops.EagerIterator(self) 228 else: 229 raise RuntimeError("dataset.__iter__() is only supported when eager " 230 "execution is enabled.") 231 232 @abc.abstractproperty 233 def _element_structure(self): 234 """The structure of an element of this dataset. 235 236 Returns: 237 A `Structure` object representing the structure of an element of this 238 dataset. 239 """ 240 raise NotImplementedError("Dataset._element_structure") 241 242 def __repr__(self): 243 output_shapes = nest.map_structure(str, get_legacy_output_shapes(self)) 244 output_shapes = str(output_shapes).replace("'", "") 245 output_types = nest.map_structure(repr, get_legacy_output_types(self)) 246 output_types = str(output_types).replace("'", "") 247 return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes, 248 output_types)) 249 250 @staticmethod 251 def from_tensors(tensors): 252 """Creates a `Dataset` with a single element, comprising the given tensors. 
253 254 Note that if `tensors` contains a NumPy array, and eager execution is not 255 enabled, the values will be embedded in the graph as one or more 256 `tf.constant` operations. For large datasets (> 1 GB), this can waste 257 memory and run into byte limits of graph serialization. If `tensors` 258 contains one or more large NumPy arrays, consider the alternative described 259 in [this 260 guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays). 261 262 Args: 263 tensors: A nested structure of tensors. 264 265 Returns: 266 Dataset: A `Dataset`. 267 """ 268 return TensorDataset(tensors) 269 270 @staticmethod 271 def from_tensor_slices(tensors): 272 """Creates a `Dataset` whose elements are slices of the given tensors. 273 274 Note that if `tensors` contains a NumPy array, and eager execution is not 275 enabled, the values will be embedded in the graph as one or more 276 `tf.constant` operations. For large datasets (> 1 GB), this can waste 277 memory and run into byte limits of graph serialization. If `tensors` 278 contains one or more large NumPy arrays, consider the alternative described 279 in [this guide]( 280 https://tensorflow.org/guide/datasets#consuming_numpy_arrays). 281 282 Args: 283 tensors: A nested structure of tensors, each having the same size in the 284 0th dimension. 285 286 Returns: 287 Dataset: A `Dataset`. 288 """ 289 return TensorSliceDataset(tensors) 290 291 class _GeneratorState(object): 292 """Stores outstanding iterators created from a Python generator. 293 294 This class keeps track of potentially multiple iterators that may have 295 been created from a generator, e.g. in the case that the dataset is 296 repeated, or nested within a parallel computation. 297 """ 298 299 def __init__(self, generator): 300 self._generator = generator 301 self._lock = threading.Lock() 302 self._next_id = 0 # GUARDED_BY(self._lock) 303 self._args = {} 304 self._iterators = {} 305 306 def get_next_id(self, *args): 307 with self._lock: 308 ret = self._next_id 309 self._next_id += 1 310 self._args[ret] = args 311 # NOTE(mrry): Explicitly create an array of `np.int64` because implicit 312 # casting in `py_func()` will create an array of `np.int32` on Windows, 313 # leading to a runtime error. 314 return np.array(ret, dtype=np.int64) 315 316 def get_iterator(self, iterator_id): 317 try: 318 return self._iterators[iterator_id] 319 except KeyError: 320 iterator = iter(self._generator(*self._args.pop(iterator_id))) 321 self._iterators[iterator_id] = iterator 322 return iterator 323 324 def iterator_completed(self, iterator_id): 325 del self._iterators[iterator_id] 326 327 @staticmethod 328 def from_generator(generator, output_types, output_shapes=None, args=None): 329 """Creates a `Dataset` whose elements are generated by `generator`. 330 331 The `generator` argument must be a callable object that returns 332 an object that support the `iter()` protocol (e.g. a generator function). 333 The elements generated by `generator` must be compatible with the given 334 `output_types` and (optional) `output_shapes` arguments. 
    For example:

    ```python
    import itertools
    tf.enable_eager_execution()

    def gen():
      for i in itertools.count(1):
        yield (i, [1] * i)

    ds = tf.data.Dataset.from_generator(
        gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))

    for value in ds.take(2):
      print(value)
    # (1, array([1]))
    # (2, array([1, 1]))
    ```

    NOTE: The current implementation of `Dataset.from_generator()` uses
    `tf.py_func` and inherits the same constraints. In particular, it
    requires the `Dataset`- and `Iterator`-related operations to be placed
    on a device in the same process as the Python program that called
    `Dataset.from_generator()`. The body of `generator` will not be
    serialized in a `GraphDef`, and you should not use this method if you
    need to serialize your model and restore it in a different environment.

    NOTE: If `generator` depends on mutable global variables or other external
    state, be aware that the runtime may invoke `generator` multiple times
    (in order to support repeating the `Dataset`) and at any time
    between the call to `Dataset.from_generator()` and the production of the
    first element from the generator. Mutating global variables or external
    state can cause undefined behavior, and we recommend that you explicitly
    cache any external state in `generator` before calling
    `Dataset.from_generator()`.

    Args:
      generator: A callable object that returns an object that supports the
        `iter()` protocol. If `args` is not specified, `generator` must take
        no arguments; otherwise it must take as many arguments as there are
        values in `args`.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element yielded by `generator`.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects corresponding to each component of an element yielded by
        `generator`.
      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
        and passed to `generator` as NumPy-array arguments.

    Returns:
      Dataset: A `Dataset`.
    """
    if not callable(generator):
      raise TypeError("`generator` must be callable.")
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if args is None:
      args = ()
    else:
      args = tuple(ops.convert_n_to_tensor(args, name="args"))

    flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)]
    flattened_shapes = nest.flatten(output_shapes)

    generator_state = DatasetV2._GeneratorState(generator)

    def get_iterator_id_fn(unused_dummy):
      """Creates a unique `iterator_id` for each pass over the dataset.

      The returned `iterator_id` disambiguates between multiple concurrently
      existing iterators.

      Args:
        unused_dummy: Ignored value.

      Returns:
        A `tf.int64` tensor whose value uniquely identifies an iterator in
        `generator_state`.
      """
      return script_ops.py_func(
          generator_state.get_next_id, args, dtypes.int64, stateful=True)

    def generator_next_fn(iterator_id_t):
      """Generates the next element from iterator with ID `iterator_id_t`.

      We map this function across an infinite repetition of the
      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.
427 428 Args: 429 iterator_id_t: A `tf.int64` tensor whose value uniquely identifies 430 the iterator in `generator_state` from which to generate an element. 431 432 Returns: 433 A nested structure of tensors representing an element from the iterator. 434 """ 435 436 def generator_py_func(iterator_id): 437 """A `py_func` that will be called to invoke the iterator.""" 438 # `next()` raises `StopIteration` when there are no more 439 # elements remaining to be generated. 440 values = next(generator_state.get_iterator(iterator_id)) 441 442 # Use the same _convert function from the py_func() implementation to 443 # convert the returned values to arrays early, so that we can inspect 444 # their values. 445 try: 446 flattened_values = nest.flatten_up_to(output_types, values) 447 except (TypeError, ValueError): 448 raise TypeError( 449 "`generator` yielded an element that did not match the expected " 450 "structure. The expected structure was %s, but the yielded " 451 "element was %s." % (output_types, values)) 452 ret_arrays = [] 453 for ret, dtype in zip(flattened_values, flattened_types): 454 try: 455 ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access 456 ret, dtype=dtype.as_numpy_dtype)) 457 except (TypeError, ValueError): 458 raise TypeError( 459 "`generator` yielded an element that could not be converted to " 460 "the expected type. The expected type was %s, but the yielded " 461 "element was %s." % (dtype.name, ret)) 462 463 # Additional type and shape checking to ensure that the components 464 # of the generated element match the `output_types` and `output_shapes` 465 # arguments. 466 for (ret_array, expected_dtype, expected_shape) in zip( 467 ret_arrays, flattened_types, flattened_shapes): 468 if ret_array.dtype != expected_dtype.as_numpy_dtype: 469 raise TypeError( 470 "`generator` yielded an element of type %s where an element " 471 "of type %s was expected." % (ret_array.dtype, 472 expected_dtype.as_numpy_dtype)) 473 if not expected_shape.is_compatible_with(ret_array.shape): 474 raise ValueError( 475 "`generator` yielded an element of shape %s where an element " 476 "of shape %s was expected." % (ret_array.shape, expected_shape)) 477 478 return ret_arrays 479 480 flat_values = script_ops.py_func( 481 generator_py_func, [iterator_id_t], flattened_types, stateful=True) 482 483 # The `py_func()` op drops the inferred shapes, so we add them back in 484 # here. 485 if output_shapes is not None: 486 for ret_t, shape in zip(flat_values, flattened_shapes): 487 ret_t.set_shape(shape) 488 489 return nest.pack_sequence_as(output_types, flat_values) 490 491 def finalize_fn(iterator_id_t): 492 """Releases host-side state for the iterator with ID `iterator_id_t`.""" 493 494 def finalize_py_func(iterator_id): 495 generator_state.iterator_completed(iterator_id) 496 # We return a dummy value so that the `finalize_fn` has a valid 497 # signature. 498 # NOTE(mrry): Explicitly create an array of `np.int64` because implicit 499 # casting in `py_func()` will create an array of `np.int32` on Windows, 500 # leading to a runtime error. 501 return np.array(0, dtype=np.int64) 502 503 return script_ops.py_func( 504 finalize_py_func, [iterator_id_t], dtypes.int64, stateful=True) 505 506 # This function associates each traversal of `generator` with a unique 507 # iterator ID. 508 def flat_map_fn(dummy_arg): 509 # The `get_iterator_id_fn` gets a unique ID for the current instance of 510 # of the generator. 
      # The `generator_next_fn` gets the next element from the iterator with the
      # given ID, and raises StopIteration when that iterator contains no
      # more elements.
      return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
                               finalize_fn)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64
    # ID that will be used to identify the appropriate Python state, which
    # is encapsulated in `generator_state`, and captured in
    # `get_iterator_id_map_fn`.
    dummy = 0
    id_dataset = Dataset.from_tensors(dummy)

    # A dataset that contains all of the elements generated by a
    # single iterator created from `generator`, identified by the
    # iterator ID contained in `id_dataset`. Lifting the iteration
    # into a flat_map here enables multiple repetitions and/or nested
    # versions of the returned dataset to be created, because it forces
    # the generation of a new ID for each version.
    return id_dataset.flat_map(flat_map_fn)

  @staticmethod
  def range(*args):
    """Creates a `Dataset` of a step-separated range of values.

    For example:

    ```python
    Dataset.range(5) == [0, 1, 2, 3, 4]
    Dataset.range(2, 5) == [2, 3, 4]
    Dataset.range(1, 5, 2) == [1, 3]
    Dataset.range(1, 5, -2) == []
    Dataset.range(5, 1) == []
    Dataset.range(5, 1, -2) == [5, 3]
    ```

    Args:
      *args: follows the same semantics as Python's `xrange`.
        len(args) == 1 -> start = 0, stop = args[0], step = 1
        len(args) == 2 -> start = args[0], stop = args[1], step = 1
        len(args) == 3 -> start = args[0], stop = args[1], step = args[2]

    Returns:
      Dataset: A `RangeDataset`.

    Raises:
      ValueError: if len(args) == 0.
    """
    return RangeDataset(*args)

  @staticmethod
  def zip(datasets):
    """Creates a `Dataset` by zipping together the given datasets.

    This method has similar semantics to the built-in `zip()` function
    in Python, with the main difference being that the `datasets`
    argument can be an arbitrary nested structure of `Dataset` objects.
    For example:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3 }
    b = { 4, 5, 6 }
    c = { (7, 8), (9, 10), (11, 12) }
    d = { 13, 14 }

    # The nested structure of the `datasets` argument determines the
    # structure of elements in the resulting dataset.
    Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
    Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }

    # The `datasets` argument may contain an arbitrary number of
    # datasets.
    Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
                                (2, 5, (9, 10)),
                                (3, 6, (11, 12)) }

    # The number of elements in the resulting dataset is the same as
    # the size of the smallest dataset in `datasets`.
    Dataset.zip((a, d)) == { (1, 13), (2, 14) }
    ```

    Args:
      datasets: A nested structure of datasets.

    Returns:
      Dataset: A `Dataset`.
    """
    return ZipDataset(datasets)

  def concatenate(self, dataset):
    """Creates a `Dataset` by concatenating the given dataset with this dataset.

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3 }
    b = { 4, 5, 6, 7 }

    # The input dataset and the dataset to be concatenated should have the same
    # nested structures and output types.
614 # c = { (8, 9), (10, 11), (12, 13) } 615 # d = { 14.0, 15.0, 16.0 } 616 # a.concatenate(c) and a.concatenate(d) would result in error. 617 618 a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 } 619 ``` 620 621 Args: 622 dataset: `Dataset` to be concatenated. 623 624 Returns: 625 Dataset: A `Dataset`. 626 """ 627 return ConcatenateDataset(self, dataset) 628 629 def prefetch(self, buffer_size): 630 """Creates a `Dataset` that prefetches elements from this dataset. 631 632 Args: 633 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the 634 maximum number of elements that will be buffered when prefetching. 635 636 Returns: 637 Dataset: A `Dataset`. 638 """ 639 return PrefetchDataset(self, buffer_size) 640 641 @staticmethod 642 def list_files(file_pattern, shuffle=None, seed=None): 643 """A dataset of all files matching one or more glob patterns. 644 645 NOTE: The default behavior of this method is to return filenames in 646 a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False` 647 to get results in a deterministic order. 648 649 Example: 650 If we had the following files on our filesystem: 651 - /path/to/dir/a.txt 652 - /path/to/dir/b.py 653 - /path/to/dir/c.py 654 If we pass "/path/to/dir/*.py" as the directory, the dataset would 655 produce: 656 - /path/to/dir/b.py 657 - /path/to/dir/c.py 658 659 Args: 660 file_pattern: A string, a list of strings, or a `tf.Tensor` of string type 661 (scalar or vector), representing the filename glob (i.e. shell wildcard) 662 pattern(s) that will be matched. 663 shuffle: (Optional.) If `True`, the file names will be shuffled randomly. 664 Defaults to `True`. 665 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random 666 seed that will be used to create the distribution. See 667 `tf.set_random_seed` for behavior. 668 669 Returns: 670 Dataset: A `Dataset` of strings corresponding to file names. 671 """ 672 with ops.name_scope("list_files"): 673 if shuffle is None: 674 shuffle = True 675 file_pattern = ops.convert_to_tensor( 676 file_pattern, dtype=dtypes.string, name="file_pattern") 677 matching_files = gen_io_ops.matching_files(file_pattern) 678 679 # Raise an exception if `file_pattern` does not match any files. 680 condition = math_ops.greater(array_ops.shape(matching_files)[0], 0, 681 name="match_not_empty") 682 683 message = math_ops.add( 684 "No files matched pattern: ", 685 string_ops.reduce_join(file_pattern, separator=", "), name="message") 686 687 assert_not_empty = control_flow_ops.Assert( 688 condition, [message], summarize=1, name="assert_not_empty") 689 with ops.control_dependencies([assert_not_empty]): 690 matching_files = array_ops.identity(matching_files) 691 692 dataset = Dataset.from_tensor_slices(matching_files) 693 if shuffle: 694 # NOTE(mrry): The shuffle buffer size must be greater than zero, but the 695 # list of files might be empty. 696 buffer_size = math_ops.maximum( 697 array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1) 698 dataset = dataset.shuffle(buffer_size, seed=seed) 699 return dataset 700 701 def repeat(self, count=None): 702 """Repeats this dataset `count` times. 703 704 NOTE: If this dataset is a function of global state (e.g. a random number 705 generator), then different repetitions may produce different elements. 706 707 Args: 708 count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 709 number of times the dataset should be repeated. The default behavior 710 (if `count` is `None` or `-1`) is for the dataset be repeated 711 indefinitely. 
712 713 Returns: 714 Dataset: A `Dataset`. 715 """ 716 return RepeatDataset(self, count) 717 718 def _enumerate(self, start=0): 719 720 max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max 721 return Dataset.zip((Dataset.range(start, max_value), self)) 722 723 def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None): 724 """Randomly shuffles the elements of this dataset. 725 726 This dataset fills a buffer with `buffer_size` elements, then randomly 727 samples elements from this buffer, replacing the selected elements with new 728 elements. For perfect shuffling, a buffer size greater than or equal to the 729 full size of the dataset is required. 730 731 For instance, if your dataset contains 10,000 elements but `buffer_size` is 732 set to 1,000, then `shuffle` will initially select a random element from 733 only the first 1,000 elements in the buffer. Once an element is selected, 734 its space in the buffer is replaced by the next (i.e. 1,001-st) element, 735 maintaining the 1,000 element buffer. 736 737 Args: 738 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the 739 number of elements from this dataset from which the new 740 dataset will sample. 741 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 742 random seed that will be used to create the distribution. See 743 `tf.set_random_seed` for behavior. 744 reshuffle_each_iteration: (Optional.) A boolean, which if true indicates 745 that the dataset should be pseudorandomly reshuffled each time it is 746 iterated over. (Defaults to `True`.) 747 748 Returns: 749 Dataset: A `Dataset`. 750 """ 751 return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration) 752 753 def cache(self, filename=""): 754 """Caches the elements in this dataset. 755 756 Args: 757 filename: A `tf.string` scalar `tf.Tensor`, representing the name of a 758 directory on the filesystem to use for caching tensors in this Dataset. 759 If a filename is not provided, the dataset will be cached in memory. 760 761 Returns: 762 Dataset: A `Dataset`. 763 """ 764 return CacheDataset(self, filename) 765 766 def take(self, count): 767 """Creates a `Dataset` with at most `count` elements from this dataset. 768 769 Args: 770 count: A `tf.int64` scalar `tf.Tensor`, representing the number of 771 elements of this dataset that should be taken to form the new dataset. 772 If `count` is -1, or if `count` is greater than the size of this 773 dataset, the new dataset will contain all elements of this dataset. 774 775 Returns: 776 Dataset: A `Dataset`. 777 """ 778 return TakeDataset(self, count) 779 780 def skip(self, count): 781 """Creates a `Dataset` that skips `count` elements from this dataset. 782 783 Args: 784 count: A `tf.int64` scalar `tf.Tensor`, representing the number 785 of elements of this dataset that should be skipped to form the 786 new dataset. If `count` is greater than the size of this 787 dataset, the new dataset will contain no elements. If `count` 788 is -1, skips the entire dataset. 789 790 Returns: 791 Dataset: A `Dataset`. 792 """ 793 return SkipDataset(self, count) 794 795 def shard(self, num_shards, index): 796 """Creates a `Dataset` that includes only 1/`num_shards` of this dataset. 797 798 This dataset operator is very useful when running distributed training, as 799 it allows each worker to read a unique subset. 
800 801 When reading a single input file, you can skip elements as follows: 802 803 ```python 804 d = tf.data.TFRecordDataset(input_file) 805 d = d.shard(num_workers, worker_index) 806 d = d.repeat(num_epochs) 807 d = d.shuffle(shuffle_buffer_size) 808 d = d.map(parser_fn, num_parallel_calls=num_map_threads) 809 ``` 810 811 Important caveats: 812 813 - Be sure to shard before you use any randomizing operator (such as 814 shuffle). 815 - Generally it is best if the shard operator is used early in the dataset 816 pipeline. For example, when reading from a set of TFRecord files, shard 817 before converting the dataset to input samples. This avoids reading every 818 file on every worker. The following is an example of an efficient 819 sharding strategy within a complete pipeline: 820 821 ```python 822 d = Dataset.list_files(pattern) 823 d = d.shard(num_workers, worker_index) 824 d = d.repeat(num_epochs) 825 d = d.shuffle(shuffle_buffer_size) 826 d = d.interleave(tf.data.TFRecordDataset, 827 cycle_length=num_readers, block_length=1) 828 d = d.map(parser_fn, num_parallel_calls=num_map_threads) 829 ``` 830 831 Args: 832 num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of 833 shards operating in parallel. 834 index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. 835 836 Returns: 837 Dataset: A `Dataset`. 838 839 Raises: 840 InvalidArgumentError: if `num_shards` or `index` are illegal values. 841 Note: error checking is done on a best-effort basis, and errors aren't 842 guaranteed to be caught upon dataset creation. (e.g. providing in a 843 placeholder tensor bypasses the early checking, and will instead result 844 in an error during a session.run call.) 845 """ 846 return ShardDataset(self, num_shards, index) 847 848 def batch(self, batch_size, drop_remainder=False): 849 """Combines consecutive elements of this dataset into batches. 850 851 The tensors in the resulting element will have an additional outer 852 dimension, which will be `batch_size` (or `N % batch_size` for the last 853 element if `batch_size` does not divide the number of input elements `N` 854 evenly and `drop_remainder` is `False`). If your program depends on the 855 batches having the same outer dimension, you should set the `drop_remainder` 856 argument to `True` to prevent the smaller batch from being produced. 857 858 Args: 859 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 860 consecutive elements of this dataset to combine in a single batch. 861 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 862 whether the last batch should be dropped in the case it has fewer than 863 `batch_size` elements; the default behavior is not to drop the smaller 864 batch. 865 866 Returns: 867 Dataset: A `Dataset`. 868 """ 869 return BatchDataset(self, batch_size, drop_remainder) 870 871 def padded_batch(self, 872 batch_size, 873 padded_shapes, 874 padding_values=None, 875 drop_remainder=False): 876 """Combines consecutive elements of this dataset into padded batches. 877 878 This transformation combines multiple consecutive elements of the input 879 dataset into a single element. 880 881 Like `tf.data.Dataset.batch`, the tensors in the resulting element will 882 have an additional outer dimension, which will be `batch_size` (or 883 `N % batch_size` for the last element if `batch_size` does not divide the 884 number of input elements `N` evenly and `drop_remainder` is `False`). 
If your
    program depends on the batches having the same outer dimension, you should
    set the `drop_remainder` argument to `True` to prevent the smaller batch
    from being produced.

    Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
    different shapes, and this transformation will pad each component to the
    respective shape in `padded_shapes`. The `padded_shapes` argument
    determines the resulting shape for each dimension of each component in an
    output element:

    * If the dimension is a constant (e.g. `tf.Dimension(37)`), the component
      will be padded out to that length in that dimension.
    * If the dimension is unknown (e.g. `tf.Dimension(None)`), the component
      will be padded out to the maximum length of all elements in that
      dimension.

    See also `tf.data.experimental.dense_to_sparse_batch`, which combines
    elements that may have different shapes into a `tf.SparseTensor`.

    Args:
      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        consecutive elements of this dataset to combine in a single batch.
      padded_shapes: A nested structure of `tf.TensorShape` or
        `tf.int64` vector tensor-like objects representing the shape
        to which the respective component of each input element should
        be padded prior to batching. Any unknown dimensions
        (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a
        tensor-like object) will be padded to the maximum size of that
        dimension in each batch.
      padding_values: (Optional.) A nested structure of scalar-shaped
        `tf.Tensor`, representing the padding values to use for the
        respective components. Defaults are `0` for numeric types and
        the empty string for string types.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
        whether the last batch should be dropped in the case it has fewer than
        `batch_size` elements; the default behavior is not to drop the smaller
        batch.

    Returns:
      Dataset: A `Dataset`.
    """
    return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values,
                              drop_remainder)

  def map(self, map_func, num_parallel_calls=None):
    """Maps `map_func` across the elements of this dataset.

    This transformation applies `map_func` to each element of this dataset, and
    returns a new dataset containing the transformed elements, in the same
    order as they appeared in the input.

    For example:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3, 4, 5 }

    a.map(lambda x: x + 1) == { 2, 3, 4, 5, 6 }
    ```

    The input signature of `map_func` is determined by the structure of each
    element in this dataset. For example:

    ```python
    # Each element is a `tf.Tensor` object.
    a = { 1, 2, 3, 4, 5 }
    # `map_func` takes a single argument of type `tf.Tensor` with the same
    # shape and dtype.
    result = a.map(lambda x: ...)

    # Each element is a tuple containing two `tf.Tensor` objects.
    b = { (1, "foo"), (2, "bar"), (3, "baz") }
    # `map_func` takes two arguments of type `tf.Tensor`.
    result = b.map(lambda x_int, y_str: ...)

    # Each element is a dictionary mapping strings to `tf.Tensor` objects.
    c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} }
    # `map_func` takes a single argument of type `dict` with the same keys as
    # the elements.
    result = c.map(lambda d: ...)
    ```

    The value or values returned by `map_func` determine the structure of each
    element in the returned dataset.

    ```python
    # `map_func` returns a scalar `tf.Tensor` of type `tf.float32`.
    def f(...):
      return tf.constant(37.0)
    result = dataset.map(f)
    result.output_classes == tf.Tensor
    result.output_types == tf.float32
    result.output_shapes == []  # scalar

    # `map_func` returns two `tf.Tensor` objects.
    def g(...):
      return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
    result = dataset.map(g)
    result.output_classes == (tf.Tensor, tf.Tensor)
    result.output_types == (tf.float32, tf.string)
    result.output_shapes == ([], [3])

    # Python primitives, lists, and NumPy arrays are implicitly converted to
    # `tf.Tensor`.
    def h(...):
      return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0], dtype=np.float64)
    result = dataset.map(h)
    result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor)
    result.output_types == (tf.float32, tf.string, tf.float64)
    result.output_shapes == ([], [3], [2])

    # `map_func` can return nested structures.
    def i(...):
      return {"a": 37.0, "b": [42, 16]}, "foo"
    result = dataset.map(i)
    result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor)
    result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string)
    result.output_shapes == ({"a": [], "b": [2]}, [])
    ```

    In addition to `tf.Tensor` objects, `map_func` can accept as arguments and
    return `tf.SparseTensor` objects.

    Args:
      map_func: A function mapping a nested structure of tensors (having
        shapes and types defined by `self.output_shapes` and
        `self.output_types`) to another nested structure of tensors.
      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
        representing the number of elements to process asynchronously in
        parallel. If not specified, elements will be processed sequentially.
        If the value `tf.data.experimental.AUTOTUNE` is used, then the number
        of parallel calls is set dynamically based on available CPU.

    Returns:
      Dataset: A `Dataset`.
    """
    if num_parallel_calls is None:
      return MapDataset(self, map_func, preserve_cardinality=True)
    else:
      return ParallelMapDataset(
          self, map_func, num_parallel_calls, preserve_cardinality=True)

  def flat_map(self, map_func):
    """Maps `map_func` across this dataset and flattens the result.

    Use `flat_map` if you want to make sure that the order of your dataset
    stays the same. For example, to flatten a dataset of batches into a
    dataset of their elements:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset. '[...]' represents a tensor.
    a = {[1,2,3,4,5], [6,7,8,9], [10]}

    a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
      {[1,2,3,4,5,6,7,8,9,10]}
    ```

    `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
    `flat_map` produces the same output as
    `tf.data.Dataset.interleave(cycle_length=1)`.

    Args:
      map_func: A function mapping a nested structure of tensors (having shapes
        and types defined by `self.output_shapes` and `self.output_types`) to a
        `Dataset`.

    Returns:
      Dataset: A `Dataset`.
1054 """ 1055 return FlatMapDataset(self, map_func) 1056 1057 def interleave(self, 1058 map_func, 1059 cycle_length, 1060 block_length=1, 1061 num_parallel_calls=None): 1062 """Maps `map_func` across this dataset, and interleaves the results. 1063 1064 For example, you can use `Dataset.interleave()` to process many input files 1065 concurrently: 1066 1067 ```python 1068 # Preprocess 4 files concurrently, and interleave blocks of 16 records from 1069 # each file. 1070 filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...] 1071 dataset = (Dataset.from_tensor_slices(filenames) 1072 .interleave(lambda x: 1073 TextLineDataset(x).map(parse_fn, num_parallel_calls=1), 1074 cycle_length=4, block_length=16)) 1075 ``` 1076 1077 The `cycle_length` and `block_length` arguments control the order in which 1078 elements are produced. `cycle_length` controls the number of input elements 1079 that are processed concurrently. If you set `cycle_length` to 1, this 1080 transformation will handle one input element at a time, and will produce 1081 identical results to `tf.data.Dataset.flat_map`. In general, 1082 this transformation will apply `map_func` to `cycle_length` input elements, 1083 open iterators on the returned `Dataset` objects, and cycle through them 1084 producing `block_length` consecutive elements from each iterator, and 1085 consuming the next input element each time it reaches the end of an 1086 iterator. 1087 1088 For example: 1089 1090 ```python 1091 # NOTE: The following examples use `{ ... }` to represent the 1092 # contents of a dataset. 1093 a = { 1, 2, 3, 4, 5 } 1094 1095 # NOTE: New lines indicate "block" boundaries. 1096 a.interleave(lambda x: Dataset.from_tensors(x).repeat(6), 1097 cycle_length=2, block_length=4) == { 1098 1, 1, 1, 1, 1099 2, 2, 2, 2, 1100 1, 1, 1101 2, 2, 1102 3, 3, 3, 3, 1103 4, 4, 4, 4, 1104 3, 3, 1105 4, 4, 1106 5, 5, 5, 5, 1107 5, 5, 1108 } 1109 ``` 1110 1111 NOTE: The order of elements yielded by this transformation is 1112 deterministic, as long as `map_func` is a pure function. If 1113 `map_func` contains any stateful operations, the order in which 1114 that state is accessed is undefined. 1115 1116 Args: 1117 map_func: A function mapping a nested structure of tensors (having shapes 1118 and types defined by `self.output_shapes` and `self.output_types`) to a 1119 `Dataset`. 1120 cycle_length: The number of elements from this dataset that will be 1121 processed concurrently. 1122 block_length: The number of consecutive elements to produce from each 1123 input element before cycling to another input element. 1124 num_parallel_calls: (Optional.) If specified, the implementation creates 1125 a threadpool, which is used to fetch inputs from cycle elements 1126 asynchronously and in parallel. The default behavior is to fetch inputs 1127 from cycle elements synchronously with no parallelism. If the value 1128 `tf.data.experimental.AUTOTUNE` is used, then the number of parallel 1129 calls is set dynamically based on available CPU. 1130 1131 Returns: 1132 Dataset: A `Dataset`. 1133 """ 1134 if num_parallel_calls is None: 1135 return InterleaveDataset(self, map_func, cycle_length, block_length) 1136 else: 1137 return ParallelInterleaveDataset(self, map_func, cycle_length, 1138 block_length, num_parallel_calls) 1139 1140 def filter(self, predicate): 1141 """Filters this dataset according to `predicate`. 
1142 1143 ```python 1144 d = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 1145 1146 d = d.filter(lambda x: x < 3) # [1, 2] 1147 1148 # `tf.math.equal(x, y)` is required for equality comparison 1149 def filter_fn(x): 1150 return tf.math.equal(x, 1) 1151 1152 d = d.filter(filter_fn) # [1] 1153 ``` 1154 1155 Args: 1156 predicate: A function mapping a nested structure of tensors (having shapes 1157 and types defined by `self.output_shapes` and `self.output_types`) to a 1158 scalar `tf.bool` tensor. 1159 1160 Returns: 1161 Dataset: The `Dataset` containing the elements of this dataset for which 1162 `predicate` is `True`. 1163 """ 1164 return FilterDataset(self, predicate) 1165 1166 def apply(self, transformation_func): 1167 """Applies a transformation function to this dataset. 1168 1169 `apply` enables chaining of custom `Dataset` transformations, which are 1170 represented as functions that take one `Dataset` argument and return a 1171 transformed `Dataset`. 1172 1173 For example: 1174 1175 ``` 1176 dataset = (dataset.map(lambda x: x ** 2) 1177 .apply(group_by_window(key_func, reduce_func, window_size)) 1178 .map(lambda x: x ** 3)) 1179 ``` 1180 1181 Args: 1182 transformation_func: A function that takes one `Dataset` argument and 1183 returns a `Dataset`. 1184 1185 Returns: 1186 Dataset: The `Dataset` returned by applying `transformation_func` to this 1187 dataset. 1188 """ 1189 dataset = transformation_func(self) 1190 if not isinstance(dataset, DatasetV2): 1191 raise TypeError( 1192 "`transformation_func` must return a Dataset. Got {}.".format( 1193 dataset)) 1194 dataset._input_datasets = [self] # pylint: disable=protected-access 1195 return dataset 1196 1197 def window(self, size, shift=None, stride=1, drop_remainder=False): 1198 """Combines input elements into a dataset of windows. 1199 1200 Each window is a dataset itself and contains `size` elements (or 1201 possibly fewer if there are not enough input elements to fill the window 1202 and `drop_remainder` evaluates to false). 1203 1204 The `stride` argument determines the stride of the input elements, 1205 and the `shift` argument determines the shift of the window. 1206 1207 For example: 1208 - `tf.data.Dataset.range(7).window(2)` produces 1209 `{{0, 1}, {2, 3}, {4, 5}, {6}}` 1210 - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces 1211 `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}` 1212 - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces 1213 `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}` 1214 1215 Args: 1216 size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements 1217 of the input dataset to combine into a window. 1218 shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 1219 forward shift of the sliding window in each iteration. Defaults to 1220 `size`. 1221 stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 1222 stride of the input elements in the sliding window. 1223 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 1224 whether a window should be dropped in case its size is smaller than 1225 `window_size`. 1226 1227 Returns: 1228 Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with 1229 the same structure as this dataset, but a finite subsequence of its 1230 elements. 1231 """ 1232 if shift is None: 1233 shift = size 1234 return WindowDataset(self, size, shift, stride, drop_remainder) 1235 1236 def reduce(self, initial_state, reduce_func): 1237 """Reduces the input dataset to a single element. 
1238 1239 The transformation calls `reduce_func` successively on every element of 1240 the input dataset until the dataset is exhausted, aggregating information in 1241 its internal state. The `initial_state` argument is used for the initial 1242 state and the final state is returned as the result. 1243 1244 For example: 1245 - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1)` 1246 produces `5` 1247 - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)` 1248 produces `10` 1249 1250 Args: 1251 initial_state: A nested structure of tensors, representing the initial 1252 state of the transformation. 1253 reduce_func: A function that maps `(old_state, input_element)` to 1254 `new_state`. It must take two arguments and return a nested structure 1255 of tensors. The structure of `new_state` must match the structure of 1256 `initial_state`. 1257 1258 Returns: 1259 A nested structure of `tf.Tensor` objects, corresponding to the final 1260 state of the transformation. 1261 1262 """ 1263 1264 with ops.name_scope("initial_state"): 1265 # Convert any `SparseTensorValue`s to `SparseTensor`s and all other 1266 # values to tensors. 1267 initial_state = nest.pack_sequence_as(initial_state, [ 1268 sparse_tensor_lib.SparseTensor.from_value(t) 1269 if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor( 1270 t, name="component_%d" % i) 1271 for i, t in enumerate(nest.flatten(initial_state)) 1272 ]) 1273 1274 # Compute initial values for the state classes, shapes and types based on 1275 # the initial state. 1276 state_structure = structure_lib.Structure.from_value(initial_state) 1277 1278 # Iteratively rerun the reduce function until reaching a fixed point on 1279 # `state_structure`. 1280 need_to_rerun = True 1281 while need_to_rerun: 1282 1283 wrapped_func = StructuredFunctionWrapper( 1284 reduce_func, 1285 "reduce()", 1286 input_structure=structure_lib.NestedStructure( 1287 (state_structure, self._element_structure)), 1288 add_to_graph=False) 1289 1290 # Extract and validate class information from the returned values. 1291 output_classes = wrapped_func.output_classes 1292 state_classes = state_structure._to_legacy_output_classes() # pylint: disable=protected-access 1293 for new_state_class, state_class in zip( 1294 nest.flatten(output_classes), nest.flatten(state_classes)): 1295 if not issubclass(new_state_class, state_class): 1296 raise TypeError( 1297 "The element classes for the new state must match the initial " 1298 "state. Expected %s; got %s." % (state_classes, 1299 wrapped_func.output_classes)) 1300 1301 # Extract and validate type information from the returned values. 1302 output_types = wrapped_func.output_types 1303 state_types = state_structure._to_legacy_output_types() # pylint: disable=protected-access 1304 for new_state_type, state_type in zip( 1305 nest.flatten(output_types), nest.flatten(state_types)): 1306 if new_state_type != state_type: 1307 raise TypeError( 1308 "The element types for the new state must match the initial " 1309 "state. Expected %s; got %s." % (state_types, 1310 wrapped_func.output_types)) 1311 1312 # Extract shape information from the returned values. 
      output_shapes = wrapped_func.output_shapes
      state_shapes = state_structure._to_legacy_output_shapes()  # pylint: disable=protected-access
      flat_state_shapes = nest.flatten(state_shapes)
      flat_new_state_shapes = nest.flatten(output_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # TODO(b/110122868): Support a "most specific compatible structure"
        # method for combining structures, to avoid using legacy structures
        # here.
        state_structure = structure_lib.convert_legacy_structure(
            state_types,
            nest.pack_sequence_as(state_shapes, weakened_state_shapes),
            state_classes)

    reduce_func = wrapped_func.function
    reduce_func.add_to_graph(ops.get_default_graph())

    # pylint: disable=protected-access
    return state_structure._from_compatible_tensor_list(
        gen_dataset_ops.reduce_dataset(
            self._variant_tensor,
            state_structure._to_tensor_list(initial_state),
            reduce_func.captured_inputs,
            f=reduce_func,
            output_shapes=state_structure._flat_shapes,
            output_types=state_structure._flat_types))

  def with_options(self, options):
    """Returns a new `tf.data.Dataset` with the given options set.

    The options are "global" in the sense that they apply to the entire
    dataset. If options are set multiple times, they are merged as long as
    different options do not use different non-default values.

    Args:
      options: A `tf.data.Options` that identifies the options to use.

    Returns:
      Dataset: A `Dataset` with the given options.

    Raises:
      ValueError: when an option is set more than once to a non-default value.
    """
    return _OptionsDataset(self, options)


@tf_export(v1=["data.Dataset"])
class DatasetV1(DatasetV2):
  """Represents a potentially large set of elements.

  A `Dataset` can be used to represent an input pipeline as a
  collection of elements (nested structures of tensors) and a "logical
  plan" of transformations that act on those elements.
  """

  def __init__(self):
    try:
      variant_tensor = self._as_variant_tensor()
    except AttributeError as e:
      if "_as_variant_tensor" in str(e):
        raise AttributeError("Please use _variant_tensor instead of "
                             "_as_variant_tensor() to obtain the variant "
                             "associated with a dataset")
      raise AttributeError("A likely cause of this error is that the super "
                           "call for this dataset is not the last line of the "
                           "__init__ method. The base class makes the "
                           "_as_variant_tensor call in its constructor and "
                           "if that uses attributes defined in the __init__ "
                           "method, those attributes need to be defined "
                           "before the super call.")
    super(DatasetV1, self).__init__(variant_tensor)

  @abc.abstractmethod
  def _as_variant_tensor(self):
    """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.

    Returns:
      A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
1404 """ 1405 raise NotImplementedError("Dataset._as_variant_tensor") 1406 1407 @deprecation.deprecated( 1408 None, "Use `for ... in dataset:` to iterate over a dataset. If using " 1409 "`tf.estimator`, return the `Dataset` object directly from your input " 1410 "function. As a last resort, you can use " 1411 "`tf.compat.v1.data.make_one_shot_iterator(dataset)`.") 1412 def make_one_shot_iterator(self): 1413 """Creates an `Iterator` for enumerating the elements of this dataset. 1414 1415 Note: The returned iterator will be initialized automatically. 1416 A "one-shot" iterator does not currently support re-initialization. 1417 1418 Returns: 1419 An `Iterator` over the elements of this dataset. 1420 """ 1421 return self._make_one_shot_iterator() 1422 1423 def _make_one_shot_iterator(self): # pylint: disable=missing-docstring 1424 if context.executing_eagerly(): 1425 return iterator_ops.EagerIterator(self) 1426 1427 _ensure_same_dataset_graph(self) 1428 # Now that we create datasets at python object creation time, the capture 1429 # by value _make_dataset() function would try to capture these variant 1430 # tensor dataset inputs, which are marked as stateful ops and would throw 1431 # an error if we try and capture them. We therefore traverse the graph 1432 # to find all these ops and whitelist them so that the capturing 1433 # logic instead of throwing an error recreates these ops which is what was 1434 # happening before. 1435 all_ds_ops = traverse.obtain_all_variant_tensor_ops(self) 1436 graph_level_seed, op_level_seed = core_random_seed.get_seed(None) 1437 1438 # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is 1439 # a 0-argument function. 1440 @function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops) 1441 def _make_dataset(): 1442 """Factory function for a dataset.""" 1443 # NOTE(mrry): `Defun` does not capture the graph-level seed from the 1444 # enclosing graph, so if a graph-level seed is present we set the local 1445 # graph seed based on a combination of the graph- and op-level seeds. 1446 if graph_level_seed is not None: 1447 assert op_level_seed is not None 1448 core_random_seed.set_random_seed( 1449 (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1)) 1450 1451 dataset = self._apply_options() 1452 return dataset._variant_tensor # pylint: disable=protected-access 1453 1454 try: 1455 _make_dataset.add_to_graph(ops.get_default_graph()) 1456 except ValueError as err: 1457 if "Cannot capture a stateful node" in str(err): 1458 raise ValueError( 1459 "Failed to create a one-shot iterator for a dataset. " 1460 "`Dataset.make_one_shot_iterator()` does not support datasets that " 1461 "capture stateful objects, such as a `Variable` or `LookupTable`. " 1462 "In these cases, use `Dataset.make_initializable_iterator()`. " 1463 "(Original error: %s)" % err) 1464 else: 1465 six.reraise(ValueError, err) 1466 1467 # pylint: disable=protected-access 1468 return iterator_ops.Iterator( 1469 gen_dataset_ops.one_shot_iterator( 1470 dataset_factory=_make_dataset, **flat_structure(self)), 1471 None, get_legacy_output_types(self), get_legacy_output_shapes(self), 1472 get_legacy_output_classes(self)) 1473 1474 @deprecation.deprecated( 1475 None, "Use `for ... in dataset:` to iterate over a dataset. If using " 1476 "`tf.estimator`, return the `Dataset` object directly from your input " 1477 "function. 
As a last resort, you can use " 1478 "`tf.compat.v1.data.make_initializable_iterator(dataset)`.") 1479 def make_initializable_iterator(self, shared_name=None): 1480 """Creates an `Iterator` for enumerating the elements of this dataset. 1481 1482 Note: The returned iterator will be in an uninitialized state, 1483 and you must run the `iterator.initializer` operation before using it: 1484 1485 ```python 1486 dataset = ... 1487 iterator = dataset.make_initializable_iterator() 1488 # ... 1489 sess.run(iterator.initializer) 1490 ``` 1491 1492 Args: 1493 shared_name: (Optional.) If non-empty, the returned iterator will be 1494 shared under the given name across multiple sessions that share the 1495 same devices (e.g. when using a remote server). 1496 1497 Returns: 1498 An `Iterator` over the elements of this dataset. 1499 1500 Raises: 1501 RuntimeError: If eager execution is enabled. 1502 """ 1503 1504 return self._make_initializable_iterator(shared_name) 1505 1506 def _make_initializable_iterator(self, shared_name=None): # pylint: disable=missing-docstring 1507 if context.executing_eagerly(): 1508 raise RuntimeError( 1509 "dataset.make_initializable_iterator is not supported when eager " 1510 "execution is enabled.") 1511 _ensure_same_dataset_graph(self) 1512 dataset = self._apply_options() 1513 if shared_name is None: 1514 shared_name = "" 1515 if compat.forward_compatible(2018, 8, 3): 1516 iterator_resource = gen_dataset_ops.iterator_v2( 1517 container="", shared_name=shared_name, **flat_structure(self)) 1518 else: 1519 iterator_resource = gen_dataset_ops.iterator( 1520 container="", shared_name=shared_name, **flat_structure(self)) 1521 with ops.colocate_with(iterator_resource): 1522 initializer = gen_dataset_ops.make_iterator( 1523 dataset._variant_tensor, # pylint: disable=protected-access 1524 iterator_resource) 1525 # pylint: disable=protected-access 1526 return iterator_ops.Iterator( 1527 iterator_resource, initializer, get_legacy_output_types(dataset), 1528 get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset)) 1529 1530 @property 1531 def output_classes(self): 1532 """Returns the class of each component of an element of this dataset. 1533 1534 The expected values are `tf.Tensor` and `tf.SparseTensor`. 1535 1536 Returns: 1537 A nested structure of Python `type` objects corresponding to each 1538 component of an element of this dataset. 1539 """ 1540 return self._element_structure._to_legacy_output_classes() # pylint: disable=protected-access 1541 1542 @property 1543 def output_shapes(self): 1544 """Returns the shape of each component of an element of this dataset. 1545 1546 Returns: 1547 A nested structure of `tf.TensorShape` objects corresponding to each 1548 component of an element of this dataset. 1549 """ 1550 return self._element_structure._to_legacy_output_shapes() # pylint: disable=protected-access 1551 1552 @property 1553 def output_types(self): 1554 """Returns the type of each component of an element of this dataset. 1555 1556 Returns: 1557 A nested structure of `tf.DType` objects corresponding to each component 1558 of an element of this dataset. 1559 """ 1560 return self._element_structure._to_legacy_output_types() # pylint: disable=protected-access 1561 1562 @property 1563 def _element_structure(self): 1564 # TODO(b/110122868): Remove this override once all `Dataset` instances 1565 # implement `element_structure`. 
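# For illustration only (a hedged sketch, not part of the implementation):
# for a dataset of scalar `tf.int64` tensors the legacy triple would be
# roughly
#   output_types   -> tf.int64
#   output_shapes  -> tf.TensorShape([])
#   output_classes -> ops.Tensor
# and `convert_legacy_structure()` would map it to a
# `structure_lib.TensorStructure(tf.int64, [])`, matching what
# `RangeDataset._element_structure` returns further down in this file.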
1566 return structure_lib.convert_legacy_structure( 1567 self.output_types, self.output_shapes, self.output_classes) 1568 1569 @staticmethod 1570 @functools.wraps(DatasetV2.from_tensors) 1571 def from_tensors(tensors): 1572 return DatasetV1Adapter(DatasetV2.from_tensors(tensors)) 1573 1574 @staticmethod 1575 @functools.wraps(DatasetV2.from_tensor_slices) 1576 def from_tensor_slices(tensors): 1577 return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors)) 1578 1579 @staticmethod 1580 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") 1581 def from_sparse_tensor_slices(sparse_tensor): 1582 """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. 1583 1584 Args: 1585 sparse_tensor: A `tf.SparseTensor`. 1586 1587 Returns: 1588 Dataset: A `Dataset` of rank-(N-1) sparse tensors. 1589 """ 1590 return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor)) 1591 1592 @staticmethod 1593 @functools.wraps(DatasetV2.from_generator) 1594 def from_generator(generator, output_types, output_shapes=None, args=None): 1595 return DatasetV1Adapter(DatasetV2.from_generator( 1596 generator, output_types, output_shapes, args)) 1597 1598 @staticmethod 1599 @functools.wraps(DatasetV2.range) 1600 def range(*args): 1601 return DatasetV1Adapter(DatasetV2.range(*args)) 1602 1603 @staticmethod 1604 @functools.wraps(DatasetV2.zip) 1605 def zip(datasets): 1606 return DatasetV1Adapter(DatasetV2.zip(datasets)) 1607 1608 @functools.wraps(DatasetV2.concatenate) 1609 def concatenate(self, dataset): 1610 return DatasetV1Adapter(super(DatasetV1, self).concatenate(dataset)) 1611 1612 @functools.wraps(DatasetV2.prefetch) 1613 def prefetch(self, buffer_size): 1614 return DatasetV1Adapter(super(DatasetV1, self).prefetch(buffer_size)) 1615 1616 @staticmethod 1617 @functools.wraps(DatasetV2.list_files) 1618 def list_files(file_pattern, shuffle=None, seed=None): 1619 return DatasetV1Adapter(DatasetV2.list_files(file_pattern, shuffle, seed)) 1620 1621 @functools.wraps(DatasetV2.repeat) 1622 def repeat(self, count=None): 1623 return DatasetV1Adapter(super(DatasetV1, self).repeat(count)) 1624 1625 @functools.wraps(DatasetV2.shuffle) 1626 def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None): 1627 return DatasetV1Adapter(super(DatasetV1, self).shuffle( 1628 buffer_size, seed, reshuffle_each_iteration)) 1629 1630 @functools.wraps(DatasetV2.cache) 1631 def cache(self, filename=""): 1632 return DatasetV1Adapter(super(DatasetV1, self).cache(filename)) 1633 1634 @functools.wraps(DatasetV2.take) 1635 def take(self, count): 1636 return DatasetV1Adapter(super(DatasetV1, self).take(count)) 1637 1638 @functools.wraps(DatasetV2.skip) 1639 def skip(self, count): 1640 return DatasetV1Adapter(super(DatasetV1, self).skip(count)) 1641 1642 @functools.wraps(DatasetV2.shard) 1643 def shard(self, num_shards, index): 1644 return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index)) 1645 1646 @functools.wraps(DatasetV2.batch) 1647 def batch(self, batch_size, drop_remainder=False): 1648 return DatasetV1Adapter(super(DatasetV1, self).batch( 1649 batch_size, drop_remainder)) 1650 1651 @functools.wraps(DatasetV2.padded_batch) 1652 def padded_batch(self, 1653 batch_size, 1654 padded_shapes, 1655 padding_values=None, 1656 drop_remainder=False): 1657 return DatasetV1Adapter(super(DatasetV1, self).padded_batch( 1658 batch_size, padded_shapes, padding_values, drop_remainder)) 1659 1660 @functools.wraps(DatasetV2.map) 1661 def map(self, map_func, num_parallel_calls=None): 1662 if num_parallel_calls 
is None: 1663 return DatasetV1Adapter( 1664 MapDataset(self, map_func, preserve_cardinality=False)) 1665 else: 1666 return DatasetV1Adapter( 1667 ParallelMapDataset( 1668 self, map_func, num_parallel_calls, preserve_cardinality=False)) 1669 1670 @deprecation.deprecated(None, "Use `tf.data.Dataset.map()") 1671 def map_with_legacy_function(self, map_func, num_parallel_calls=None): 1672 """Maps `map_func` across the elements of this dataset. 1673 1674 NOTE: This is an escape hatch for existing uses of `map` that do not work 1675 with V2 functions. New uses are strongly discouraged and existing uses 1676 should migrate to `map` as this method will be removed in V2. 1677 1678 Args: 1679 map_func: A function mapping a nested structure of tensors (having shapes 1680 and types defined by `self.output_shapes` and `self.output_types`) to 1681 another nested structure of tensors. 1682 num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, 1683 representing the number elements to process asynchronously in parallel. 1684 If not specified, elements will be processed sequentially. If the value 1685 `tf.data.experimental.AUTOTUNE` is used, then the number of parallel 1686 calls is set dynamically based on available CPU. 1687 1688 Returns: 1689 Dataset: A `Dataset`. 1690 """ 1691 if num_parallel_calls is None: 1692 return DatasetV1Adapter( 1693 MapDataset( 1694 self, 1695 map_func, 1696 preserve_cardinality=False, 1697 use_legacy_function=True)) 1698 else: 1699 return DatasetV1Adapter( 1700 ParallelMapDataset( 1701 self, 1702 map_func, 1703 num_parallel_calls, 1704 preserve_cardinality=False, 1705 use_legacy_function=True)) 1706 1707 @functools.wraps(DatasetV2.flat_map) 1708 def flat_map(self, map_func): 1709 return DatasetV1Adapter(super(DatasetV1, self).flat_map(map_func)) 1710 1711 @functools.wraps(DatasetV2.interleave) 1712 def interleave(self, 1713 map_func, 1714 cycle_length, 1715 block_length=1, 1716 num_parallel_calls=None): 1717 return DatasetV1Adapter(super(DatasetV1, self).interleave( 1718 map_func, cycle_length, block_length, num_parallel_calls)) 1719 1720 @functools.wraps(DatasetV2.filter) 1721 def filter(self, predicate): 1722 return DatasetV1Adapter(super(DatasetV1, self).filter(predicate)) 1723 1724 @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()") 1725 def filter_with_legacy_function(self, predicate): 1726 """Filters this dataset according to `predicate`. 1727 1728 NOTE: This is an escape hatch for existing uses of `filter` that do not work 1729 with V2 functions. New uses are strongly discouraged and existing uses 1730 should migrate to `filter` as this method will be removed in V2. 1731 1732 Args: 1733 predicate: A function mapping a nested structure of tensors (having shapes 1734 and types defined by `self.output_shapes` and `self.output_types`) to a 1735 scalar `tf.bool` tensor. 1736 1737 Returns: 1738 Dataset: The `Dataset` containing the elements of this dataset for which 1739 `predicate` is `True`. 
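
For example, a minimal sketch (values are illustrative only):

```python
dataset = tf.compat.v1.data.Dataset.from_tensor_slices([1, 2, 3])
# Keeps only the elements strictly greater than 1, using the legacy
# `Defun`-based function tracing path.
dataset = dataset.filter_with_legacy_function(lambda x: x > 1)
```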
1740 """ 1741 return FilterDataset(self, predicate, use_legacy_function=True) 1742 1743 @functools.wraps(DatasetV2.apply) 1744 def apply(self, transformation_func): 1745 return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func)) 1746 1747 @functools.wraps(DatasetV2.window) 1748 def window(self, size, shift=None, stride=1, drop_remainder=False): 1749 return DatasetV1Adapter(super(DatasetV1, self).window( 1750 size, shift, stride, drop_remainder)) 1751 1752 @functools.wraps(DatasetV2.with_options) 1753 def with_options(self, options): 1754 return DatasetV1Adapter(super(DatasetV1, self).with_options(options)) 1755 1756 1757 # TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep 1758 # this alias in place. 1759 Dataset = DatasetV1 1760 1761 1762 class DatasetV1Adapter(DatasetV1): 1763 """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API.""" 1764 1765 def __init__(self, dataset): 1766 self._dataset = dataset 1767 super(DatasetV1Adapter, self).__init__() 1768 1769 def _as_variant_tensor(self): 1770 return self._dataset._variant_tensor # pylint: disable=protected-access 1771 1772 def _has_captured_ref(self): 1773 return self._dataset._has_captured_ref() # pylint: disable=protected-access 1774 1775 def _inputs(self): 1776 return self._dataset._inputs() # pylint: disable=protected-access 1777 1778 def options(self): 1779 return self._dataset.options() 1780 1781 @property 1782 def _element_structure(self): 1783 return self._dataset._element_structure # pylint: disable=protected-access 1784 1785 def __iter__(self): 1786 return iter(self._dataset) 1787 1788 1789 def _ensure_same_dataset_graph(dataset): 1790 """Walks the dataset graph to ensure all datasets come from the same graph.""" 1791 current_graph = ops.get_default_graph() 1792 bfs_q = Queue.Queue() 1793 bfs_q.put(dataset) # pylint: disable=protected-access 1794 visited = [] 1795 while not bfs_q.empty(): 1796 ds = bfs_q.get() 1797 visited.append(ds) 1798 ds_graph = ds._graph # pylint: disable=protected-access 1799 if current_graph != ds_graph: 1800 logging.warning("The graph (" + str(current_graph) + ") of the iterator " 1801 "is different from the graph (" + str(ds_graph) + ") " 1802 "the dataset: " + str(ds._variant_tensor) + " was " # pylint: disable=protected-access 1803 "created in. If you are using the Estimator API, " 1804 "make sure that no part of the dataset returned by the " 1805 "`input_fn` function is defined outside the `input_fn` " 1806 "function. Please ensure that all datasets in the " 1807 "pipeline are created in the same graph as the iterator. " 1808 "NOTE: This warning will become an error in future " 1809 "versions of TensorFlow.") 1810 for input_ds in ds._inputs(): # pylint: disable=protected-access 1811 if input_ds not in visited: 1812 bfs_q.put(input_ds) 1813 1814 1815 @tf_export(v1=["data.make_one_shot_iterator"]) 1816 def make_one_shot_iterator(dataset): 1817 """Creates a `tf.data.Iterator` for enumerating the elements of a dataset. 1818 1819 Note: The returned iterator will be initialized automatically. 1820 A "one-shot" iterator does not support re-initialization. 1821 1822 Args: 1823 dataset: A `tf.data.Dataset`. 1824 1825 Returns: 1826 A `tf.data.Iterator` over the elements of this dataset. 1827 """ 1828 try: 1829 # Call the defined `_make_one_shot_iterator()` if there is one, because some 1830 # datasets (e.g. for prefetching) override its behavior. 
1831 return dataset._make_one_shot_iterator() # pylint: disable=protected-access
1832 except AttributeError:
1833 return DatasetV1Adapter(dataset)._make_one_shot_iterator() # pylint: disable=protected-access
1834
1835
1836 @tf_export(v1=["data.make_initializable_iterator"])
1837 def make_initializable_iterator(dataset, shared_name=None):
1838 """Creates a `tf.data.Iterator` for enumerating the elements of a dataset.
1839
1840 Note: The returned iterator will be in an uninitialized state,
1841 and you must run the `iterator.initializer` operation before using it:
1842
1843 ```python
1844 dataset = ...
1845 iterator = tf.data.make_initializable_iterator(dataset)
1846 # ...
1847 sess.run(iterator.initializer)
1848 ```
1849
1850 Args:
1851 dataset: A `tf.data.Dataset`.
1852 shared_name: (Optional.) If non-empty, the returned iterator will be
1853 shared under the given name across multiple sessions that share the
1854 same devices (e.g. when using a remote server).
1855
1856 Returns:
1857 A `tf.data.Iterator` over the elements of `dataset`.
1858
1859 Raises:
1860 RuntimeError: If eager execution is enabled.
1861 """
1862 try:
1863 # Call the defined `_make_initializable_iterator()` if there is one, because
1864 # some datasets (e.g. for prefetching) override its behavior.
1865 return dataset._make_initializable_iterator(shared_name) # pylint: disable=protected-access
1866 except AttributeError:
1867 return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access
1868
1869
1870 # TODO(b/110122868): Replace this method with a public API for reflecting on
1871 # dataset structure.
1872 def get_structure(dataset_or_iterator):
1873 """Returns the `tf.data.experimental.Structure` of a `Dataset` or `Iterator`.
1874
1875 Args:
1876 dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
1877 `EagerIterator`.
1878
1879 Returns:
1880 A `tf.data.experimental.Structure` representing the structure of the
1881 elements of `dataset_or_iterator`.
1882
1883 Raises:
1884 TypeError: If `dataset_or_iterator` is not a dataset or iterator object.
1885 """
1886 try:
1887 ret = dataset_or_iterator._element_structure # pylint: disable=protected-access
1888 if isinstance(ret, structure_lib.Structure):
1889 return ret
1890 except AttributeError:
1891 pass
1892 raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator object, "
1893 "but got %s." % type(dataset_or_iterator))
1894
1895
1896 # TODO(b/110122868): Remove all uses of this method.
1897 def get_legacy_output_shapes(dataset_or_iterator):
1898 """Returns the output shapes of a `Dataset` or `Iterator`.
1899
1900 This utility method replaces the deprecated-in-V2
1901 `tf.compat.v1.Dataset.output_shapes` property.
1902
1903 Args:
1904 dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
1905 `EagerIterator`.
1906
1907 Returns:
1908 A nested structure of `tf.TensorShape` objects corresponding to each
1909 component of an element of the given dataset or iterator.
1910 """
1911 return get_structure(dataset_or_iterator)._to_legacy_output_shapes() # pylint: disable=protected-access
1912
1913
1914 # TODO(b/110122868): Remove all uses of this method.
1915 def get_legacy_output_types(dataset_or_iterator):
1916 """Returns the output types of a `Dataset` or `Iterator`.
1917
1918 This utility method replaces the deprecated-in-V2
1919 `tf.compat.v1.Dataset.output_types` property.
1920
1921 Args:
1922 dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
1923 `EagerIterator`.
1924 1925 Returns: 1926 A nested structure of `tf.DType` objects corresponding to each component 1927 of an element of this dataset. 1928 """ 1929 return get_structure(dataset_or_iterator)._to_legacy_output_types() # pylint: disable=protected-access 1930 1931 1932 # TODO(b/110122868): Remove all uses of this method. 1933 def get_legacy_output_classes(dataset_or_iterator): 1934 """Returns the output classes of a `Dataset` or `Iterator`. 1935 1936 This utility method replaces the deprecated-in-V2 1937 `tf.compat.v1.Dataset.output_classes` property. 1938 1939 Args: 1940 dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or 1941 `EagerIterator`. 1942 1943 Returns: 1944 A nested structure of Python `type` or `tf.data.experimental.Structure` 1945 objects corresponding to each component of an element of this dataset. 1946 """ 1947 return get_structure(dataset_or_iterator)._to_legacy_output_classes() # pylint: disable=protected-access 1948 1949 1950 @tf_export("data.Options") 1951 class Options(options_lib.OptionsBase): 1952 """Represents options for tf.data.Dataset. 1953 1954 An `Options` object can be, for instance, used to control which static 1955 optimizations to apply or whether to use performance modeling to dynamically 1956 tune the parallelism of operations such as `tf.data.Dataset.map` or 1957 `tf.data.Dataset.interleave`. 1958 """ 1959 1960 experimental_deterministic = options_lib.create_option( 1961 name="experimental_deterministic", 1962 ty=bool, 1963 docstring= 1964 "Whether the outputs need to be produced in deterministic order. If None," 1965 " defaults to True.") 1966 1967 experimental_numa_aware = options_lib.create_option( 1968 name="experimental_numa_aware", 1969 ty=bool, 1970 docstring= 1971 "Whether to use NUMA-aware operations. If None, defaults to False.") 1972 1973 experimental_optimization = options_lib.create_option( 1974 name="experimental_optimization", 1975 ty=optimization_options.OptimizationOptions, 1976 docstring= 1977 "The optimization options associated with the dataset. See " 1978 "`tf.data.experimental.OptimizationOptions` for more details.", 1979 default_factory=optimization_options.OptimizationOptions) 1980 1981 experimental_stats = options_lib.create_option( 1982 name="experimental_stats", 1983 ty=stats_options.StatsOptions, 1984 docstring= 1985 "The statistics options associated with the dataset. See " 1986 "`tf.data.experimental.StatsOptions` for more details.", 1987 default_factory=stats_options.StatsOptions) 1988 1989 experimental_threading = options_lib.create_option( 1990 name="experimental_threading", 1991 ty=threading_options.ThreadingOptions, 1992 docstring= 1993 "The threading options associated with the dataset. See " 1994 "`tf.data.experimental.ThreadingOptions` for more details.", 1995 default_factory=threading_options.ThreadingOptions) 1996 1997 def _static_optimizations(self): 1998 """Produces the list of enabled static optimizations.""" 1999 2000 result = [] 2001 result.extend(self.experimental_optimization._static_optimizations()) # pylint: disable=protected-access 2002 2003 if self.experimental_numa_aware: 2004 result.append("make_numa_aware") 2005 if self.experimental_deterministic is False: 2006 result.append("make_sloppy") 2007 exp_stats_options = self.experimental_stats 2008 if exp_stats_options and exp_stats_options.latency_all_edges: 2009 result.append("latency_all_edges") 2010 return result 2011 2012 def merge(self, options): 2013 """Merges itself with the given `tf.data.Options`. 
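
For example, a minimal sketch (option values are illustrative only):

```python
options1 = tf.data.Options()
options1.experimental_deterministic = False
options2 = tf.data.Options()
options2.experimental_numa_aware = True
merged = options1.merge(options2)
```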
2014 2015 The given `tf.data.Options` can be merged as long as there does not exist an 2016 attribute that is set to different values in `self` and `options`. 2017 2018 Args: 2019 options: a `tf.data.Options` to merge with 2020 2021 Raises: 2022 ValueError: if the given `tf.data.Options` cannot be merged 2023 2024 Returns: 2025 New `tf.data.Options()` object which is the result of merging self with 2026 the input `tf.data.Options`. 2027 """ 2028 return options_lib.merge_options(self, options) 2029 2030 2031 class DatasetSource(DatasetV2): 2032 """Abstract class representing a dataset with no inputs.""" 2033 2034 def _inputs(self): 2035 return [] 2036 2037 2038 class UnaryDataset(DatasetV2): 2039 """Abstract class representing a dataset with one input.""" 2040 2041 def __init__(self, input_dataset, variant_tensor): 2042 self._input_dataset = input_dataset 2043 super(UnaryDataset, self).__init__(variant_tensor) 2044 2045 def _inputs(self): 2046 return [self._input_dataset] 2047 2048 2049 class UnaryUnchangedStructureDataset(UnaryDataset): 2050 """Represents a unary dataset with the same input and output structure.""" 2051 2052 def __init__(self, input_dataset, variant_tensor): 2053 self._input_dataset = input_dataset 2054 super(UnaryUnchangedStructureDataset, self).__init__( 2055 input_dataset, variant_tensor) 2056 2057 @property 2058 def _element_structure(self): 2059 return self._input_dataset._element_structure # pylint: disable=protected-access 2060 2061 2062 class TensorDataset(DatasetSource): 2063 """A `Dataset` with a single element, viz. a nested structure of tensors.""" 2064 2065 def __init__(self, tensors): 2066 """See `Dataset.from_tensors()` for details.""" 2067 with ops.name_scope("tensors"): 2068 tensors = nest.pack_sequence_as(tensors, [ 2069 sparse_tensor_lib.SparseTensor.from_value(t) 2070 if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor( 2071 t, name="component_%d" % i) 2072 for i, t in enumerate(nest.flatten(tensors)) 2073 ]) 2074 self._structure = structure_lib.Structure.from_value(tensors) 2075 self._tensors = self._structure._to_tensor_list(tensors) # pylint: disable=protected-access 2076 2077 variant_tensor = gen_dataset_ops.tensor_dataset( 2078 self._tensors, output_shapes=self._structure._flat_shapes) # pylint: disable=protected-access 2079 super(TensorDataset, self).__init__(variant_tensor) 2080 2081 @property 2082 def _element_structure(self): 2083 return self._structure 2084 2085 2086 class TensorSliceDataset(DatasetSource): 2087 """A `Dataset` of slices from a nested structure of tensors.""" 2088 2089 def __init__(self, tensors): 2090 """See `Dataset.from_tensor_slices()` for details.""" 2091 with ops.name_scope("tensors"): 2092 tensors = nest.pack_sequence_as(tensors, [ 2093 sparse_tensor_lib.SparseTensor.from_value(t) 2094 if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor( 2095 t, name="component_%d" % i) 2096 for i, t in enumerate(nest.flatten(tensors)) 2097 ]) 2098 2099 batched_structure = structure_lib.Structure.from_value(tensors) 2100 # pylint: disable=protected-access 2101 self._tensors = batched_structure._to_batched_tensor_list(tensors) 2102 self._structure = batched_structure._unbatch() 2103 # pylint: enable=protected-access 2104 2105 batch_dim = tensor_shape.Dimension(tensor_shape.dimension_value( 2106 self._tensors[0].get_shape()[0])) 2107 for t in self._tensors[1:]: 2108 batch_dim.assert_is_compatible_with(tensor_shape.Dimension( 2109 tensor_shape.dimension_value(t.get_shape()[0]))) 2110 2111 variant_tensor = 
gen_dataset_ops.tensor_slice_dataset( 2112 self._tensors, output_shapes=self._structure._flat_shapes) # pylint: disable=protected-access 2113 super(TensorSliceDataset, self).__init__(variant_tensor) 2114 2115 @property 2116 def _element_structure(self): 2117 return self._structure 2118 2119 2120 class SparseTensorSliceDataset(DatasetSource): 2121 """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows.""" 2122 2123 def __init__(self, sparse_tensor): 2124 """See `Dataset.from_sparse_tensor_slices()` for details.""" 2125 if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor): 2126 raise TypeError( 2127 "`sparse_tensor` must be a `tf.SparseTensor` object. Was {}.".format( 2128 sparse_tensor)) 2129 self._sparse_tensor = sparse_tensor 2130 2131 indices_shape = self._sparse_tensor.indices.get_shape() 2132 shape_shape = self._sparse_tensor.dense_shape.get_shape() 2133 rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1) 2134 self._structure = structure_lib.NestedStructure( 2135 (structure_lib.TensorStructure(dtypes.int64, [None, rank]), 2136 structure_lib.TensorStructure(self._sparse_tensor.dtype, [None]), 2137 structure_lib.TensorStructure(dtypes.int64, [rank]))) 2138 2139 variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset( 2140 self._sparse_tensor.indices, self._sparse_tensor.values, 2141 self._sparse_tensor.dense_shape) 2142 super(SparseTensorSliceDataset, self).__init__(variant_tensor) 2143 2144 @property 2145 def _element_structure(self): 2146 return self._structure 2147 2148 2149 class _VariantDataset(DatasetV2): 2150 """A Dataset wrapper around a `tf.variant`-typed function argument.""" 2151 2152 def __init__(self, dataset_variant, structure): 2153 self._structure = structure 2154 super(_VariantDataset, self).__init__(dataset_variant) 2155 2156 def _inputs(self): 2157 return [] 2158 2159 @property 2160 def _element_structure(self): 2161 return self._structure 2162 2163 2164 @tf_export("data.experimental.DatasetStructure") 2165 class DatasetStructure(structure_lib.Structure): 2166 """Represents a `Dataset` of structured values.""" 2167 2168 def __init__(self, element_structure): 2169 self._element_structure = element_structure 2170 2171 @property 2172 def _flat_shapes(self): 2173 return [tensor_shape.scalar()] 2174 2175 @property 2176 def _flat_types(self): 2177 return [dtypes.variant] 2178 2179 def is_compatible_with(self, other): 2180 # pylint: disable=protected-access 2181 return (isinstance(other, DatasetStructure) and 2182 self._element_structure.is_compatible_with( 2183 other._element_structure)) 2184 2185 def _to_tensor_list(self, value): 2186 return [value._variant_tensor] # pylint: disable=protected-access 2187 2188 def _to_batched_tensor_list(self, value): 2189 raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.") 2190 2191 def _from_tensor_list(self, flat_value): 2192 if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or 2193 not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())): 2194 raise ValueError( 2195 "DatasetStructure corresponds to a single tf.variant scalar.") 2196 return self._from_compatible_tensor_list(flat_value) 2197 2198 def _from_compatible_tensor_list(self, flat_value): 2199 # pylint: disable=protected-access 2200 return _VariantDataset(flat_value[0], self._element_structure) 2201 2202 @staticmethod 2203 def from_value(value): 2204 return DatasetStructure(value._element_structure) # pylint: disable=protected-access 2205 2206 def _to_legacy_output_types(self): 2207 
return self
2208
2209 def _to_legacy_output_shapes(self):
2210 return self
2211
2212 def _to_legacy_output_classes(self):
2213 return self
2214
2215 def _batch(self, batch_size):
2216 raise NotImplementedError("Batching for `tf.data.Dataset` objects.")
2217
2218 def _unbatch(self):
2219 raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")
2220
2221
2222 # pylint: disable=protected-access
2223 structure_lib.Structure._register_custom_converter(DatasetV2,
2224 DatasetStructure.from_value)
2225 # pylint: enable=protected-access
2226
2227
2228 class StructuredFunctionWrapper(object):
2229 """A function wrapper that supports structured arguments and return values."""
2230
2231 # pylint: disable=protected-access
2232 def __init__(self,
2233 func,
2234 transformation_name,
2235 dataset=None,
2236 input_classes=None,
2237 input_shapes=None,
2238 input_types=None,
2239 input_structure=None,
2240 add_to_graph=True,
2241 use_legacy_function=False,
2242 defun_kwargs=None):
2243 """Creates a new `StructuredFunctionWrapper` for the given function.
2244
2245 Args:
2246 func: A function from a nested structure to another nested structure.
2247 transformation_name: Human-readable name of the transformation in which
2248 this function is being instantiated, for error messages.
2249 dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
2250 dataset will be assumed as the structure for `func` arguments; otherwise
2251 `input_classes`, `input_shapes`, and `input_types` must be defined.
2252 input_classes: (Optional.) A nested structure of `type`. If given, this
2253 argument defines the Python types for `func` arguments.
2254 input_shapes: (Optional.) A nested structure of `tf.TensorShape`. If
2255 given, this argument defines the shapes and structure for `func`
2256 arguments.
2257 input_types: (Optional.) A nested structure of `tf.DType`. If given, this
2258 argument defines the element types and structure for `func` arguments.
2259 input_structure: (Optional.) A `Structure` object. If given, this argument
2260 defines the element types and structure for `func` arguments.
2261 add_to_graph: (Optional.) If `True`, the function will be added to the
2262 default graph.
2263 use_legacy_function: (Optional.) A boolean that determines whether the
2264 function will be created using `tensorflow.python.eager.function.defun`
2265 (default behavior) or `tensorflow.python.framework.function.Defun`
2266 (legacy behavior).
2267 defun_kwargs: (Optional.) A dictionary mapping string argument names to
2268 values. If supplied, will be passed to `function` as keyword arguments.
2269
2270 Raises:
2271 ValueError: If an invalid combination of `dataset`, `input_classes`,
2272 `input_shapes`, and `input_types` is passed.
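
A minimal usage sketch (assuming `ds` is an existing `tf.data.Dataset`
whose elements are single numeric tensors):

```python
wrapper = StructuredFunctionWrapper(
    lambda x: x + 1, "Dataset.map()", dataset=ds)
# `wrapper.function` is the traced function, and `wrapper.output_structure`
# describes the structure of its return value.
```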
2273 """ 2274 if input_structure is None: 2275 if dataset is None: 2276 if input_classes is None or input_shapes is None or input_types is None: 2277 raise ValueError("Either `dataset`, `input_structure` or all of " 2278 "`input_classes`, `input_shapes`, and `input_types` " 2279 "must be specified.") 2280 self._input_structure = structure_lib.convert_legacy_structure( 2281 input_types, input_shapes, input_classes) 2282 else: 2283 if not (input_classes is None and input_shapes is None and 2284 input_types is None): 2285 raise ValueError("Either `dataset`, `input_structure` or all of " 2286 "`input_classes`, `input_shapes`, and `input_types` " 2287 "must be specified.") 2288 self._input_structure = dataset._element_structure 2289 else: 2290 if not (dataset is None and input_classes is None and input_shapes is None 2291 and input_types is None): 2292 raise ValueError("Either `dataset`, `input_structure`, or all of " 2293 "`input_classes`, `input_shapes`, and `input_types` " 2294 "must be specified.") 2295 self._input_structure = input_structure 2296 2297 if defun_kwargs is None: 2298 defun_kwargs = {} 2299 2300 readable_transformation_name = transformation_name.replace( 2301 ".", "_")[:-2] if len(transformation_name) > 2 else "" 2302 2303 func_name = "_".join( 2304 [readable_transformation_name, 2305 function_utils.get_func_name(func)]) 2306 2307 def _warn_if_collections(transformation_name): 2308 """Prints a warning if the given graph uses common graph collections. 2309 2310 NOTE(mrry): Currently a warning is only generated for resources. Any 2311 variables created will be automatically hoisted out to the outermost scope 2312 using `init_scope()`. Some collections (such as for control-flow contexts) 2313 are benign and should not generate a warning. 2314 2315 Args: 2316 transformation_name: A human-readable name for the transformation. 2317 """ 2318 warnings.warn("Creating resources inside a function passed to %s " 2319 "is not supported. Create each resource outside the " 2320 "function, and capture it inside the function to use it." % 2321 transformation_name, stacklevel=5) 2322 2323 def _wrapper_helper(*args): 2324 """Wrapper for passing nested structures to and from tf.data functions.""" 2325 nested_args = self._input_structure._from_compatible_tensor_list(args) 2326 if not _should_unpack_args(nested_args): 2327 nested_args = (nested_args,) 2328 2329 ret = func(*nested_args) 2330 # If `func` returns a list of tensors, `nest.flatten()` and 2331 # `ops.convert_to_tensor()` would conspire to attempt to stack 2332 # those tensors into a single tensor, because the customized 2333 # version of `nest.flatten()` does not recurse into lists. Since 2334 # it is more likely that the list arose from returning the 2335 # result of an operation (such as `tf.py_func()`) that returns a 2336 # list of not-necessarily-stackable tensors, we treat the 2337 # returned value is a `tuple` instead. A user wishing to pack 2338 # the return value into a single tensor can use an explicit 2339 # `tf.stack()` before returning. 2340 if isinstance(ret, list): 2341 ret = tuple(ret) 2342 2343 try: 2344 self._output_structure = structure_lib.Structure.from_value(ret) 2345 except (ValueError, TypeError): 2346 raise TypeError("Unsupported return value from function passed to " 2347 "%s: %s." 
% (transformation_name, ret)) 2348 return ret 2349 2350 if use_legacy_function: 2351 func_name = func_name + "_" + str(ops.uid()) 2352 2353 @function.Defun( 2354 *self._input_structure._flat_types, 2355 func_name=func_name, 2356 **defun_kwargs) 2357 def wrapper_fn(*args): 2358 ret = _wrapper_helper(*args) 2359 # _warn_if_collections(transformation_name, ops.get_default_graph(), 0) 2360 return self._output_structure._to_tensor_list(ret) 2361 2362 self._function = wrapper_fn 2363 resource_tracker = tracking.ResourceTracker() 2364 with tracking.resource_tracker_scope(resource_tracker): 2365 if add_to_graph: 2366 self._function.add_to_graph(ops.get_default_graph()) 2367 else: 2368 # Use the private method that will execute `wrapper_fn` but delay 2369 # adding it to the graph in case (e.g.) we need to rerun the function. 2370 self._function._create_definition_if_needed() 2371 if resource_tracker.resources: 2372 _warn_if_collections(transformation_name) 2373 2374 else: 2375 defun_kwargs.update({"func_name": func_name}) 2376 2377 # TODO(b/124254153): Enable autograph once the overhead is low enough. 2378 # TODO(mdan): Make sure autograph recurses into _wrapper_helper when on. 2379 @eager_function.defun_with_attributes( 2380 input_signature=[ 2381 tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension 2382 for input_shape, input_type in zip( 2383 self._input_structure._flat_shapes, 2384 self._input_structure._flat_types) 2385 ], 2386 autograph=False, 2387 attributes=defun_kwargs) 2388 def wrapper_fn(*args): # pylint: disable=missing-docstring 2389 ret = _wrapper_helper(*args) 2390 ret = self._output_structure._to_tensor_list(ret) 2391 return [ops.convert_to_tensor(t) for t in ret] 2392 2393 resource_tracker = tracking.ResourceTracker() 2394 with tracking.resource_tracker_scope(resource_tracker): 2395 self._function = wrapper_fn._get_concrete_function_internal() 2396 if add_to_graph: 2397 self._function.add_to_graph(ops.get_default_graph()) 2398 if resource_tracker.resources: 2399 _warn_if_collections(transformation_name) 2400 2401 outer_graph_seed = ops.get_default_graph().seed 2402 if outer_graph_seed and self._function.graph.seed == outer_graph_seed: 2403 if self._function.graph._seed_used: 2404 warnings.warn( 2405 "Seed %s from outer graph might be getting used by function %s, " 2406 "if the random op has not been provided any seed. Explicitly set " 2407 "the seed in the function if this is not the intended behavior." 2408 %(outer_graph_seed, func_name), stacklevel=4) 2409 # pylint: enable=protected-access 2410 2411 @property 2412 def output_structure(self): 2413 return self._output_structure 2414 2415 @property 2416 def output_classes(self): 2417 return self._output_structure._to_legacy_output_classes() # pylint: disable=protected-access 2418 2419 @property 2420 def output_shapes(self): 2421 return self._output_structure._to_legacy_output_shapes() # pylint: disable=protected-access 2422 2423 @property 2424 def output_types(self): 2425 return self._output_structure._to_legacy_output_types() # pylint: disable=protected-access 2426 2427 @property 2428 def function(self): 2429 return self._function 2430 2431 2432 def flat_structure(dataset): 2433 """Helper for setting `output_shapes` and `output_types` attrs of Dataset ops. 2434 2435 Most Dataset op constructors expect `output_shapes` and `output_types` 2436 arguments that represent the flattened structure of an element. 
This helper 2437 function generates these attrs as a keyword argument dictionary, allowing 2438 `Dataset._variant_tensor` implementations to pass 2439 `**flat_structure(self)` to the op constructor. 2440 2441 Args: 2442 dataset: A `tf.data.Dataset`. 2443 2444 Returns: 2445 A dictionary of keyword arguments that can be passed to many Dataset op 2446 constructors. 2447 """ 2448 # pylint: disable=protected-access 2449 structure = dataset._element_structure 2450 return { 2451 "output_shapes": structure._flat_shapes, 2452 "output_types": structure._flat_types, 2453 } 2454 2455 2456 class _GeneratorDataset(DatasetSource): 2457 """A `Dataset` that generates elements by invoking a function.""" 2458 2459 def __init__(self, init_args, init_func, next_func, finalize_func): 2460 """Constructs a `_GeneratorDataset`. 2461 2462 Args: 2463 init_args: A nested structure representing the arguments to `init_func`. 2464 init_func: A TensorFlow function that will be called on `init_args` each 2465 time a C++ iterator over this dataset is constructed. Returns a nested 2466 structure representing the "state" of the dataset. 2467 next_func: A TensorFlow function that will be called on the result of 2468 `init_func` to produce each element, and that raises `OutOfRangeError` 2469 to terminate iteration. 2470 finalize_func: A TensorFlow function that will be called on the result of 2471 `init_func` immediately before a C++ iterator over this dataset is 2472 destroyed. The return value is ignored. 2473 """ 2474 self._init_args = init_args 2475 2476 self._init_structure = structure_lib.Structure.from_value(init_args) 2477 2478 self._init_func = StructuredFunctionWrapper( 2479 init_func, 2480 self._transformation_name(), 2481 input_structure=self._init_structure) 2482 2483 self._next_func = StructuredFunctionWrapper( 2484 next_func, 2485 self._transformation_name(), 2486 input_structure=self._init_func.output_structure) 2487 2488 self._finalize_func = StructuredFunctionWrapper( 2489 finalize_func, 2490 self._transformation_name(), 2491 input_structure=self._init_func.output_structure) 2492 variant_tensor = gen_dataset_ops.generator_dataset( 2493 self._init_structure._to_tensor_list(self._init_args) # pylint: disable=protected-access 2494 + self._init_func.function.captured_inputs, 2495 self._next_func.function.captured_inputs, 2496 self._finalize_func.function.captured_inputs, 2497 init_func=self._init_func.function, 2498 next_func=self._next_func.function, 2499 finalize_func=self._finalize_func.function, 2500 **flat_structure(self)) 2501 super(_GeneratorDataset, self).__init__(variant_tensor) 2502 2503 @property 2504 def _element_structure(self): 2505 return self._next_func.output_structure 2506 2507 def _transformation_name(self): 2508 return "Dataset.from_generator()" 2509 2510 2511 class ZipDataset(DatasetV2): 2512 """A `Dataset` that zips its inputs together.""" 2513 2514 def __init__(self, datasets): 2515 """See `Dataset.zip()` for details.""" 2516 for ds in nest.flatten(datasets): 2517 if not isinstance(ds, DatasetV2): 2518 if isinstance(ds, list): 2519 message = ("The argument to `Dataset.zip()` must be a nested " 2520 "structure of `Dataset` objects. 
Nested structures do not " 2521 "support Python lists; please use a tuple instead.") 2522 else: 2523 message = ("The argument to `Dataset.zip()` must be a nested " 2524 "structure of `Dataset` objects.") 2525 raise TypeError(message) 2526 self._datasets = datasets 2527 self._structure = structure_lib.NestedStructure( 2528 nest.pack_sequence_as( 2529 self._datasets, 2530 [ds._element_structure for ds in nest.flatten(self._datasets)])) # pylint: disable=protected-access 2531 2532 # pylint: disable=protected-access 2533 variant_tensor = gen_dataset_ops.zip_dataset( 2534 [ds._variant_tensor for ds in nest.flatten(self._datasets)], 2535 **flat_structure(self)) 2536 # pylint: enable=protected-access 2537 super(ZipDataset, self).__init__(variant_tensor) 2538 2539 def _inputs(self): 2540 return nest.flatten(self._datasets) 2541 2542 @property 2543 def _element_structure(self): 2544 return self._structure 2545 2546 2547 class ConcatenateDataset(DatasetV2): 2548 """A `Dataset` that concatenates its input with given dataset.""" 2549 2550 def __init__(self, input_dataset, dataset_to_concatenate): 2551 """See `Dataset.concatenate()` for details.""" 2552 self._input_dataset = input_dataset 2553 self._dataset_to_concatenate = dataset_to_concatenate 2554 2555 output_types = get_legacy_output_types(input_dataset) 2556 if output_types != get_legacy_output_types(dataset_to_concatenate): 2557 raise TypeError( 2558 "Two datasets to concatenate have different types %s and %s" % 2559 (output_types, get_legacy_output_types(dataset_to_concatenate))) 2560 2561 output_classes = get_legacy_output_classes(input_dataset) 2562 if output_classes != get_legacy_output_classes(dataset_to_concatenate): 2563 raise TypeError( 2564 "Two datasets to concatenate have different classes %s and %s" % 2565 (output_classes, get_legacy_output_classes(dataset_to_concatenate))) 2566 2567 input_shapes = get_legacy_output_shapes(self._input_dataset) 2568 output_shapes = nest.pack_sequence_as(input_shapes, [ 2569 ts1.most_specific_compatible_shape(ts2) 2570 for (ts1, ts2) in zip( 2571 nest.flatten(input_shapes), 2572 nest.flatten(get_legacy_output_shapes( 2573 self._dataset_to_concatenate))) 2574 ]) 2575 2576 self._structure = structure_lib.convert_legacy_structure( 2577 output_types, output_shapes, output_classes) 2578 2579 self._input_datasets = [input_dataset, dataset_to_concatenate] 2580 # pylint: disable=protected-access 2581 variant_tensor = gen_dataset_ops.concatenate_dataset( 2582 input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor, 2583 **flat_structure(self)) 2584 # pylint: enable=protected-access 2585 super(ConcatenateDataset, self).__init__(variant_tensor) 2586 2587 def _inputs(self): 2588 return self._input_datasets 2589 2590 @property 2591 def _element_structure(self): 2592 return self._structure 2593 2594 2595 class RepeatDataset(UnaryUnchangedStructureDataset): 2596 """A `Dataset` that repeats its input several times.""" 2597 2598 def __init__(self, input_dataset, count): 2599 """See `Dataset.repeat()` for details.""" 2600 self._input_dataset = input_dataset 2601 if count is None: 2602 self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") 2603 else: 2604 self._count = ops.convert_to_tensor( 2605 count, dtype=dtypes.int64, name="count") 2606 variant_tensor = gen_dataset_ops.repeat_dataset( 2607 input_dataset._variant_tensor, # pylint: disable=protected-access 2608 count=self._count, 2609 **flat_structure(self)) 2610 super(RepeatDataset, self).__init__(input_dataset, variant_tensor) 2611 
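# A hedged sketch of how these internal classes correspond to the public API
# (not part of the implementation): `Dataset.range(5).repeat(2)` builds a
# `RangeDataset` wrapped by a `RepeatDataset`, roughly
#
#   ds = RepeatDataset(RangeDataset(5), count=2)
#
# Users should go through the public `tf.data.Dataset` methods rather than
# instantiating these classes directly.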
2612 2613 class RangeDataset(DatasetSource): 2614 """A `Dataset` of a step separated range of values.""" 2615 2616 def __init__(self, *args): 2617 """See `Dataset.range()` for details.""" 2618 self._parse_args(*args) 2619 variant_tensor = gen_dataset_ops.range_dataset( 2620 start=self._start, 2621 stop=self._stop, 2622 step=self._step, 2623 **flat_structure(self)) 2624 super(RangeDataset, self).__init__(variant_tensor) 2625 2626 def _parse_args(self, *args): 2627 """Parse arguments according to the same rules as the `range()` builtin.""" 2628 if len(args) == 1: 2629 self._start = self._build_tensor(0, "start") 2630 self._stop = self._build_tensor(args[0], "stop") 2631 self._step = self._build_tensor(1, "step") 2632 elif len(args) == 2: 2633 self._start = self._build_tensor(args[0], "start") 2634 self._stop = self._build_tensor(args[1], "stop") 2635 self._step = self._build_tensor(1, "step") 2636 elif len(args) == 3: 2637 self._start = self._build_tensor(args[0], "start") 2638 self._stop = self._build_tensor(args[1], "stop") 2639 self._step = self._build_tensor(args[2], "step") 2640 else: 2641 raise ValueError("Invalid arguments to RangeDataset: %s" % str(args)) 2642 2643 def _build_tensor(self, int64_value, name): 2644 return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name) 2645 2646 @property 2647 def _element_structure(self): 2648 return structure_lib.TensorStructure(dtypes.int64, []) 2649 2650 2651 class CacheDataset(UnaryUnchangedStructureDataset): 2652 """A `Dataset` that caches elements of its input.""" 2653 2654 def __init__(self, input_dataset, filename): 2655 """See `Dataset.cache()` for details.""" 2656 self._input_dataset = input_dataset 2657 self._filename = ops.convert_to_tensor( 2658 filename, dtype=dtypes.string, name="filename") 2659 variant_tensor = gen_dataset_ops.cache_dataset( 2660 input_dataset._variant_tensor, # pylint: disable=protected-access 2661 filename=self._filename, 2662 **flat_structure(self)) 2663 super(CacheDataset, self).__init__(input_dataset, variant_tensor) 2664 2665 2666 class ShuffleDataset(UnaryUnchangedStructureDataset): 2667 """A `Dataset` that randomly shuffles the elements of its input.""" 2668 2669 def __init__(self, 2670 input_dataset, 2671 buffer_size, 2672 seed=None, 2673 reshuffle_each_iteration=None): 2674 """Randomly shuffles the elements of this dataset. 2675 2676 Args: 2677 input_dataset: The input dataset. 2678 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the 2679 number of elements from this dataset from which the new 2680 dataset will sample. 2681 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 2682 random seed that will be used to create the distribution. See 2683 `tf.set_random_seed` for behavior. 2684 reshuffle_each_iteration: (Optional.) A boolean, which if true indicates 2685 that the dataset should be pseudorandomly reshuffled each time it is 2686 iterated over. (Defaults to `True`.) 2687 2688 Returns: 2689 A `Dataset`. 2690 2691 Raises: 2692 ValueError: if invalid arguments are provided. 
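
For example, a minimal sketch of the public API that constructs this
dataset (buffer size and seed are illustrative only):

```python
dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(buffer_size=10, seed=42,
                          reshuffle_each_iteration=True)
```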
2693 """ 2694 self._input_dataset = input_dataset 2695 self._buffer_size = ops.convert_to_tensor( 2696 buffer_size, dtype=dtypes.int64, name="buffer_size") 2697 self._seed, self._seed2 = random_seed.get_seed(seed) 2698 2699 if reshuffle_each_iteration is None: 2700 self._reshuffle_each_iteration = True 2701 else: 2702 self._reshuffle_each_iteration = reshuffle_each_iteration 2703 variant_tensor = gen_dataset_ops.shuffle_dataset( 2704 input_dataset._variant_tensor, # pylint: disable=protected-access 2705 buffer_size=self._buffer_size, 2706 seed=self._seed, 2707 seed2=self._seed2, 2708 reshuffle_each_iteration=self._reshuffle_each_iteration, 2709 **flat_structure(self)) 2710 super(ShuffleDataset, self).__init__(input_dataset, variant_tensor) 2711 2712 2713 class TakeDataset(UnaryUnchangedStructureDataset): 2714 """A `Dataset` containing the first `count` elements from its input.""" 2715 2716 def __init__(self, input_dataset, count): 2717 """See `Dataset.take()` for details.""" 2718 self._input_dataset = input_dataset 2719 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 2720 variant_tensor = gen_dataset_ops.take_dataset( 2721 input_dataset._variant_tensor, # pylint: disable=protected-access 2722 count=self._count, 2723 **flat_structure(self)) 2724 super(TakeDataset, self).__init__(input_dataset, variant_tensor) 2725 2726 2727 class SkipDataset(UnaryUnchangedStructureDataset): 2728 """A `Dataset` skipping the first `count` elements from its input.""" 2729 2730 def __init__(self, input_dataset, count): 2731 """See `Dataset.skip()` for details.""" 2732 self._input_dataset = input_dataset 2733 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 2734 variant_tensor = gen_dataset_ops.skip_dataset( 2735 input_dataset._variant_tensor, # pylint: disable=protected-access 2736 count=self._count, 2737 **flat_structure(self)) 2738 super(SkipDataset, self).__init__(input_dataset, variant_tensor) 2739 2740 2741 class ShardDataset(UnaryUnchangedStructureDataset): 2742 """A `Dataset` for sharding its input.""" 2743 2744 def __init__(self, input_dataset, num_shards, index): 2745 """See `Dataset.shard()` for details.""" 2746 self._input_dataset = input_dataset 2747 self._num_shards = ops.convert_to_tensor( 2748 num_shards, dtype=dtypes.int64, name="num_shards") 2749 self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index") 2750 variant_tensor = gen_dataset_ops.shard_dataset( 2751 input_dataset._variant_tensor, # pylint: disable=protected-access 2752 num_shards=self._num_shards, 2753 index=self._index, 2754 **flat_structure(self)) 2755 super(ShardDataset, self).__init__(input_dataset, variant_tensor) 2756 2757 2758 class BatchDataset(UnaryDataset): 2759 """A `Dataset` that batches contiguous elements from its input.""" 2760 2761 def __init__(self, input_dataset, batch_size, drop_remainder): 2762 """See `Dataset.batch()` for details.""" 2763 self._input_dataset = input_dataset 2764 self._batch_size = ops.convert_to_tensor( 2765 batch_size, dtype=dtypes.int64, name="batch_size") 2766 self._drop_remainder = ops.convert_to_tensor( 2767 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 2768 2769 constant_drop_remainder = tensor_util.constant_value(self._drop_remainder) 2770 # pylint: disable=protected-access 2771 if constant_drop_remainder: 2772 # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically) 2773 # or `False` (explicitly retaining the remainder). 
2774 self._structure = input_dataset._element_structure._batch( 2775 tensor_util.constant_value(self._batch_size)) 2776 else: 2777 self._structure = input_dataset._element_structure._batch(None) 2778 variant_tensor = gen_dataset_ops.batch_dataset_v2( 2779 input_dataset._variant_tensor, # pylint: disable=protected-access 2780 batch_size=self._batch_size, 2781 drop_remainder=self._drop_remainder, 2782 **flat_structure(self)) 2783 super(BatchDataset, self).__init__(input_dataset, variant_tensor) 2784 2785 @property 2786 def _element_structure(self): 2787 return self._structure 2788 2789 2790 def _is_padded_shape_compatible_with(padded_shape, input_component_shape): 2791 """Returns `True` if `input_component_shape` can be padded to `padded_shape`. 2792 2793 Args: 2794 padded_shape: A `tf.TensorShape`. 2795 input_component_shape: A `tf.TensorShape`. 2796 2797 Returns: 2798 `True` if `input_component_shape` can be padded to `padded_shape`, otherwise 2799 `False`. 2800 """ 2801 2802 if padded_shape.dims is None or input_component_shape.dims is None: 2803 return True 2804 if len(padded_shape.dims) != len(input_component_shape.dims): 2805 return False 2806 for padded_dim, input_dim in zip( 2807 padded_shape.dims, input_component_shape.dims): 2808 if (padded_dim.value is not None and input_dim.value is not None 2809 and padded_dim.value < input_dim.value): 2810 return False 2811 return True 2812 2813 2814 def _padded_shape_to_tensor(padded_shape, input_component_shape): 2815 """Converts `padded_shape` to a `tf.Tensor` representing that shape. 2816 2817 Args: 2818 padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python 2819 sequence, or a 1-D `tf.Tensor` of `tf.int64` elements. 2820 input_component_shape: A `tf.TensorShape`, with which `padded_shape` must 2821 be compatible. 2822 2823 Returns: 2824 A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`. 2825 2826 Raises: 2827 ValueError: If `padded_shape` is not a shape or not compatible with 2828 `input_component_shape`. 2829 TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor. 2830 """ 2831 try: 2832 # Try to convert the `padded_shape` to a `tf.TensorShape` 2833 padded_shape_as_shape = tensor_shape.as_shape(padded_shape) 2834 # We will return the "canonical" tensor representation, which uses 2835 # `-1` in place of `None`. 2836 ret = ops.convert_to_tensor( 2837 [dim if dim is not None else -1 2838 for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64) 2839 except (TypeError, ValueError): 2840 # The argument was not trivially convertible to a 2841 # `tf.TensorShape`, so fall back on the conversion to tensor 2842 # machinery. 2843 ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64) 2844 if ret.shape.dims is not None and len(ret.shape.dims) != 1: 2845 raise ValueError( 2846 "Padded shape %s must be a 1-D tensor of tf.int64 values, but its " 2847 "shape was %s." % (padded_shape, ret.shape)) 2848 if ret.dtype != dtypes.int64: 2849 raise TypeError( 2850 "Padded shape %s must be a 1-D tensor of tf.int64 values, but its " 2851 "element type was %s." % (padded_shape, ret.dtype.name)) 2852 padded_shape_as_shape = tensor_util.constant_value_as_shape(ret) 2853 2854 if not _is_padded_shape_compatible_with(padded_shape_as_shape, 2855 input_component_shape): 2856 raise ValueError("The padded shape %s is not compatible with the " 2857 "corresponding input component shape %s." 
2858 % (padded_shape_as_shape, input_component_shape)) 2859 2860 return ret 2861 2862 2863 def _padding_value_to_tensor(value, output_type): 2864 """Converts the padding value to a tensor. 2865 2866 Args: 2867 value: The padding value. 2868 output_type: Its expected dtype. 2869 2870 Returns: 2871 A scalar `Tensor`. 2872 2873 Raises: 2874 ValueError: if the padding value is not a scalar. 2875 TypeError: if the padding value's type does not match `output_type`. 2876 """ 2877 value = ops.convert_to_tensor(value, name="padding_value") 2878 if not value.shape.is_compatible_with(tensor_shape.scalar()): 2879 raise ValueError("Padding value should be a scalar, but is not: %s" % value) 2880 if value.dtype != output_type: 2881 raise TypeError("Padding value tensor (%s) does not match output type: %s" % 2882 (value, output_type)) 2883 return value 2884 2885 2886 def _default_padding(input_dataset): 2887 """Returns default padding tensors in a structure matching `input_dataset`.""" 2888 def make_zero(t): 2889 if t.base_dtype == dtypes.string: 2890 return "" 2891 elif t.base_dtype == dtypes.variant: 2892 error_msg = ("Unable to create padding for field of type 'variant' " 2893 "because t.base_type == dtypes.variant == " 2894 "{}.".format( 2895 t.base_dtype)) 2896 raise TypeError(error_msg) 2897 else: 2898 return np.zeros_like(t.as_numpy_dtype()) 2899 2900 return nest.map_structure( 2901 make_zero, get_legacy_output_types(input_dataset)) 2902 2903 2904 class PaddedBatchDataset(UnaryDataset): 2905 """A `Dataset` that batches and pads contiguous elements from its input.""" 2906 2907 def __init__(self, input_dataset, batch_size, padded_shapes, padding_values, 2908 drop_remainder): 2909 """See `Dataset.batch()` for details.""" 2910 self._input_dataset = input_dataset 2911 if sparse.any_sparse(get_legacy_output_classes(input_dataset)): 2912 # TODO(b/63669786): support batching of sparse tensors 2913 raise TypeError( 2914 "Batching of padded sparse tensors is not currently supported") 2915 self._input_dataset = input_dataset 2916 self._batch_size = ops.convert_to_tensor( 2917 batch_size, dtype=dtypes.int64, name="batch_size") 2918 padding_values = ( 2919 padding_values 2920 if padding_values is not None else _default_padding(input_dataset)) 2921 2922 input_shapes = get_legacy_output_shapes(input_dataset) 2923 flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes) 2924 2925 flat_padded_shapes_as_tensors = [] 2926 2927 for input_component_shape, padded_shape in zip( 2928 nest.flatten(input_shapes), flat_padded_shapes): 2929 flat_padded_shapes_as_tensors.append( 2930 _padded_shape_to_tensor(padded_shape, input_component_shape)) 2931 2932 self._padded_shapes = nest.pack_sequence_as(input_shapes, 2933 flat_padded_shapes_as_tensors) 2934 2935 self._padding_values = nest.map_structure_up_to( 2936 input_shapes, _padding_value_to_tensor, padding_values, 2937 get_legacy_output_types(input_dataset)) 2938 self._drop_remainder = ops.convert_to_tensor( 2939 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 2940 2941 def _padded_shape_to_batch_shape(s): 2942 return tensor_shape.vector( 2943 tensor_util.constant_value(self._batch_size) if smart_cond. 


class PaddedBatchDataset(UnaryDataset):
  """A `Dataset` that batches and pads contiguous elements from its input."""

  def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
               drop_remainder):
    """See `Dataset.padded_batch()` for details."""
    self._input_dataset = input_dataset
    if sparse.any_sparse(get_legacy_output_classes(input_dataset)):
      # TODO(b/63669786): support batching of sparse tensors
      raise TypeError(
          "Batching of padded sparse tensors is not currently supported")
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    padding_values = (
        padding_values
        if padding_values is not None else _default_padding(input_dataset))

    input_shapes = get_legacy_output_shapes(input_dataset)
    flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes)

    flat_padded_shapes_as_tensors = []

    for input_component_shape, padded_shape in zip(
        nest.flatten(input_shapes), flat_padded_shapes):
      flat_padded_shapes_as_tensors.append(
          _padded_shape_to_tensor(padded_shape, input_component_shape))

    self._padded_shapes = nest.pack_sequence_as(input_shapes,
                                                flat_padded_shapes_as_tensors)

    self._padding_values = nest.map_structure_up_to(
        input_shapes, _padding_value_to_tensor, padding_values,
        get_legacy_output_types(input_dataset))
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    def _padded_shape_to_batch_shape(s):
      return tensor_shape.vector(
          tensor_util.constant_value(self._batch_size)
          if smart_cond.smart_constant_value(self._drop_remainder)
          else None).concatenate(tensor_util.constant_value_as_shape(s))

    output_shapes = nest.map_structure(
        _padded_shape_to_batch_shape, self._padded_shapes)
    self._structure = structure_lib.convert_legacy_structure(
        get_legacy_output_types(self._input_dataset), output_shapes,
        get_legacy_output_classes(self._input_dataset))

    # pylint: disable=protected-access
    # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
    if smart_cond.smart_constant_value(self._drop_remainder) is False:
      variant_tensor = gen_dataset_ops.padded_batch_dataset(
          input_dataset._variant_tensor,
          batch_size=self._batch_size,
          padded_shapes=[
              ops.convert_to_tensor(s, dtype=dtypes.int64)
              for s in nest.flatten(self._padded_shapes)
          ],
          padding_values=nest.flatten(self._padding_values),
          output_shapes=self._structure._flat_shapes)
    else:
      variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
          input_dataset._variant_tensor,
          batch_size=self._batch_size,
          padded_shapes=[
              ops.convert_to_tensor(s, dtype=dtypes.int64)
              for s in nest.flatten(self._padded_shapes)
          ],
          padding_values=nest.flatten(self._padding_values),
          drop_remainder=self._drop_remainder,
          output_shapes=self._structure._flat_shapes)
    super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def _element_structure(self):
    return self._structure


def _should_unpack_args(args):
  """Returns `True` if `args` should be `*args` when passed to a callable."""
  return type(args) is tuple  # pylint: disable=unidiomatic-typecheck


class MapDataset(UnaryDataset):
  """A `Dataset` that maps a function over elements in its input."""

  def __init__(self,
               input_dataset,
               map_func,
               use_inter_op_parallelism=True,
               preserve_cardinality=False,
               use_legacy_function=False):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism
    self._preserve_cardinality = preserve_cardinality
    self._map_func = StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    variant_tensor = gen_dataset_ops.map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        preserve_cardinality=self._preserve_cardinality,
        **flat_structure(self))
    super(MapDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._map_func.output_structure

  def _transformation_name(self):
    return "Dataset.map()"
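

# Illustrative sketch, not called by library code: the public pipeline that
# `PaddedBatchDataset` and `MapDataset` implement. It mirrors the
# `padded_batch()` documentation example; the element values are arbitrary.
def _padded_batch_and_map_sketch():  # pragma: no cover
  """Maps elements to variable-length rows, then pads them into batches."""
  dataset = DatasetV2.range(1, 5)
  # Each element x becomes a vector of x copies of x, e.g. [3, 3, 3] for 3.
  dataset = dataset.map(
      lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
  # Rows within each batch of two are padded with zeros to a common length.
  return dataset.padded_batch(2, padded_shapes=[None])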


class ParallelMapDataset(UnaryDataset):
  """A `Dataset` that maps a function over elements in its input in parallel."""

  def __init__(self,
               input_dataset,
               map_func,
               num_parallel_calls,
               use_inter_op_parallelism=True,
               preserve_cardinality=False,
               use_legacy_function=False):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism
    self._map_func = StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
    self._preserve_cardinality = preserve_cardinality
    variant_tensor = gen_dataset_ops.parallel_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        num_parallel_calls=self._num_parallel_calls,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        preserve_cardinality=self._preserve_cardinality,
        **flat_structure(self))
    super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._map_func.output_structure

  def _transformation_name(self):
    return "Dataset.map()"


class FlatMapDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self, input_dataset, map_func):
    """See `Dataset.flat_map()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetStructure):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
    variant_tensor = gen_dataset_ops.flat_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        **flat_structure(self))
    super(FlatMapDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.flat_map()"
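

# Illustrative sketch, not called by library code: shows the `flat_map()`
# behavior that `FlatMapDataset` implements. The values are arbitrary examples.
def _flat_map_sketch():  # pragma: no cover
  """Maps each element to a small dataset and flattens the results."""
  dataset = DatasetV2.range(3)
  # Each element x expands into the two elements x and x * 10, so the
  # flattened dataset yields 0, 0, 1, 10, 2, 20 in order.
  return dataset.flat_map(
      lambda x: DatasetV2.from_tensor_slices(array_ops.stack([x, x * 10])))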


class InterleaveDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and interleaves the result.
  """

  def __init__(self, input_dataset, map_func, cycle_length, block_length):
    """See `Dataset.interleave()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetStructure):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")

    variant_tensor = gen_dataset_ops.interleave_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
        self._cycle_length,
        self._block_length,
        f=self._map_func.function,
        **flat_structure(self))
    super(InterleaveDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.interleave()"


class ParallelInterleaveDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and interleaves the result."""

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               num_parallel_calls):
    """See `Dataset.interleave()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetStructure):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
        self._cycle_length,
        self._block_length,
        self._num_parallel_calls,
        f=self._map_func.function,
        **flat_structure(self))
    super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                    variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.interleave()"
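

# Illustrative sketch, not called by library code: shows how `interleave()`
# chooses between `InterleaveDataset` and `ParallelInterleaveDataset`. The
# values are arbitrary examples.
def _interleave_sketch():  # pragma: no cover
  """Interleaves blocks drawn from per-element datasets."""
  dataset = DatasetV2.range(1, 4)
  # With cycle_length=2 and block_length=2, blocks of two repeated values are
  # drawn alternately from two input elements at a time. Passing
  # num_parallel_calls selects the parallel implementation above.
  return dataset.interleave(
      lambda x: DatasetV2.from_tensors(x).repeat(4),
      cycle_length=2, block_length=2, num_parallel_calls=2)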


class FilterDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that filters its input according to a predicate function."""

  def __init__(self, input_dataset, predicate, use_legacy_function=False):
    """See `Dataset.filter()` for details."""
    self._input_dataset = input_dataset
    wrapped_func = StructuredFunctionWrapper(
        predicate,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    if not wrapped_func.output_structure.is_compatible_with(
        structure_lib.TensorStructure(dtypes.bool, [])):
      error_msg = ("`predicate` return type must be convertible to a scalar "
                   "boolean tensor. Was {}.").format(
                       wrapped_func.output_structure)
      raise ValueError(error_msg)
    self._predicate = wrapped_func
    variant_tensor = gen_dataset_ops.filter_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        other_arguments=self._predicate.function.captured_inputs,
        predicate=self._predicate.function,
        **flat_structure(self))
    super(FilterDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._predicate]

  def _transformation_name(self):
    return "Dataset.filter()"


class PrefetchDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that asynchronously prefetches its input."""

  def __init__(self, input_dataset, buffer_size):
    """See `Dataset.prefetch()` for details."""
    self._input_dataset = input_dataset
    if buffer_size is None:
      buffer_size = -1  # This is the sentinel for auto-tuning.
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    variant_tensor = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        buffer_size=self._buffer_size,
        **flat_structure(self))
    super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)


class WindowDataset(UnaryDataset):
  """A dataset that creates window datasets from the input elements."""

  def __init__(self, input_dataset, size, shift, stride, drop_remainder):
    """See `Dataset.window()` for more details."""
    self._input_dataset = input_dataset
    self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
    self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
    self._stride = ops.convert_to_tensor(
        stride, dtype=dtypes.int64, name="stride")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    nest_of_structures = nest.pack_sequence_as(
        get_legacy_output_classes(input_dataset),
        [
            DatasetStructure(structure_lib.convert_legacy_structure(
                output_type, output_shape, output_class))
            for output_class, output_shape, output_type in zip(
                nest.flatten(get_legacy_output_classes(input_dataset)),
                nest.flatten(get_legacy_output_shapes(input_dataset)),
                nest.flatten(get_legacy_output_types(input_dataset)))
        ])
    self._structure = structure_lib.NestedStructure(nest_of_structures)
    variant_tensor = gen_dataset_ops.window_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._size,
        self._shift,
        self._stride,
        self._drop_remainder,
        **flat_structure(self))
    super(WindowDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def _element_structure(self):
    return self._structure


class _OptionsDataset(UnaryUnchangedStructureDataset):
  """An identity `Dataset` that stores options."""

  def __init__(self, input_dataset, options):
    self._input_dataset = input_dataset
    self._options = input_dataset.options()
    if self._options:
      self._options = self._options.merge(options)
    else:
      self._options = options
    variant_tensor = input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)

  def options(self):
    return self._options
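

# Illustrative sketch, not called by library code: a pipeline exercising the
# filter, window, prefetch, and options transformations defined above. The
# option values are arbitrary examples.
def _filter_window_options_sketch():  # pragma: no cover
  """Filters and windows a dataset, then attaches options and prefetching."""
  dataset = DatasetV2.range(10).filter(lambda x: math_ops.equal(x % 2, 0))
  # Each window is itself a dataset of up to `size` consecutive elements.
  dataset = dataset.window(size=3, shift=1, drop_remainder=True)
  # with_options() wraps the input in the identity `_OptionsDataset`.
  options = Options()
  options.experimental_deterministic = True
  return dataset.with_options(options).prefetch(1)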


class _ModelDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and models performance."""

  def __init__(self, input_dataset, cpu_budget):
    self._input_dataset = input_dataset
    variant_tensor = gen_dataset_ops.model_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        cpu_budget=cpu_budget,
        **flat_structure(self))
    super(_ModelDataset, self).__init__(input_dataset, variant_tensor)


class _OptimizeDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and applies optimizations."""

  def __init__(self, input_dataset, optimizations):
    self._input_dataset = input_dataset
    if optimizations is None:
      optimizations = []
    self._optimizations = ops.convert_to_tensor(
        optimizations, dtype=dtypes.string, name="optimizations")
    variant_tensor = gen_dataset_ops.optimize_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._optimizations,
        **flat_structure(self))
    super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)


class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and sets a stats aggregator."""

  def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
    self._input_dataset = input_dataset
    self._stats_aggregator = aggregator
    self._prefix = prefix
    self._counter_prefix = counter_prefix
    variant_tensor = ged_ops.experimental_set_stats_aggregator_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._stats_aggregator._resource,  # pylint: disable=protected-access
        self._prefix,
        self._counter_prefix,
        **flat_structure(self))
    super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
                                                     variant_tensor)


class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, overriding intra-op parallelism."""

  def __init__(self, input_dataset, max_intra_op_parallelism):
    self._input_dataset = input_dataset
    self._max_intra_op_parallelism = ops.convert_to_tensor(
        max_intra_op_parallelism,
        dtype=dtypes.int64,
        name="max_intra_op_parallelism")
    variant_tensor = ged_ops.experimental_max_intra_op_parallelism_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._max_intra_op_parallelism,
        **flat_structure(self))
    super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
                                                        variant_tensor)


class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, setting a private threadpool."""

  def __init__(self, input_dataset, num_threads):
    self._input_dataset = input_dataset
    self._num_threads = ops.convert_to_tensor(
        num_threads, dtype=dtypes.int64, name="num_threads")
    variant_tensor = ged_ops.experimental_private_thread_pool_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._num_threads,
        **flat_structure(self))
    super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                    variant_tensor)
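

# Illustrative sketch, not called by library code: the threading options whose
# application is implemented by `_MaxIntraOpParallelismDataset` and
# `_PrivateThreadPoolDataset` above. The numeric values are arbitrary examples.
def _threading_options_sketch():  # pragma: no cover
  """Attaches threading options that the identity datasets above implement."""
  options = Options()
  options.experimental_threading.max_intra_op_parallelism = 1
  options.experimental_threading.private_threadpool_size = 4
  # The settings take effect when an iterator is created over the dataset.
  return DatasetV2.range(5).with_options(options)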