      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Python wrappers for Datasets."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import abc
     21 import functools
     22 import threading
     23 import warnings
     24 
     25 import numpy as np
     26 import six
     27 from six.moves import queue as Queue  # pylint: disable=redefined-builtin
     28 
     29 
     30 from tensorflow.python.compat import compat
     31 from tensorflow.python.data.experimental.ops import optimization_options
     32 from tensorflow.python.data.experimental.ops import stats_options
     33 from tensorflow.python.data.experimental.ops import threading_options
     34 from tensorflow.python.data.ops import iterator_ops
     35 from tensorflow.python.data.util import nest
     36 from tensorflow.python.data.util import options as options_lib
     37 from tensorflow.python.data.util import random_seed
     38 from tensorflow.python.data.util import sparse
     39 from tensorflow.python.data.util import structure as structure_lib
     40 from tensorflow.python.data.util import traverse
     41 from tensorflow.python.eager import context
     42 from tensorflow.python.eager import function as eager_function
     43 from tensorflow.python.framework import constant_op
     44 from tensorflow.python.framework import dtypes
     45 from tensorflow.python.framework import function
     46 from tensorflow.python.framework import ops
     47 from tensorflow.python.framework import random_seed as core_random_seed
     48 from tensorflow.python.framework import smart_cond
     49 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
     50 from tensorflow.python.framework import tensor_shape
     51 from tensorflow.python.framework import tensor_spec
     52 from tensorflow.python.framework import tensor_util
     53 from tensorflow.python.ops import array_ops
     54 from tensorflow.python.ops import control_flow_ops
     55 from tensorflow.python.ops import gen_dataset_ops
     56 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
     57 from tensorflow.python.ops import gen_io_ops
     58 from tensorflow.python.ops import math_ops
     59 from tensorflow.python.ops import script_ops
     60 from tensorflow.python.ops import string_ops
     61 from tensorflow.python.platform import tf_logging as logging
     62 from tensorflow.python.training.tracking import tracking
     63 from tensorflow.python.util import deprecation
     64 from tensorflow.python.util import function_utils
     65 from tensorflow.python.util.tf_export import tf_export
     66 
     67 
     68 ops.NotDifferentiable("ReduceDataset")
     69 
     70 
     71 @tf_export("data.Dataset", v1=[])
     72 @six.add_metaclass(abc.ABCMeta)
     73 class DatasetV2(object):
     74   """Represents a potentially large set of elements.
     75 
     76   A `Dataset` can be used to represent an input pipeline as a
     77   collection of elements (nested structures of tensors) and a "logical
     78   plan" of transformations that act on those elements.
     79   """
     80 
     81   def __init__(self, variant_tensor):
     82     """Creates a DatasetV2 object.
     83 
     84     This is a key difference between DatasetV1 and DatasetV2: DatasetV1 takes
     85     no arguments in its constructor, whereas DatasetV2 expects subclasses to
     86     create a `variant_tensor` and pass it to the `super()` call.
     87 
     88     Args:
     89       variant_tensor: A DT_VARIANT tensor that represents the dataset.
     90     """
     91     self._variant_tensor_attr = variant_tensor
     92     self._graph_attr = ops.get_default_graph()
     93 
     94   @property
     95   def _variant_tensor(self):
     96     return self._variant_tensor_attr
     97 
     98   @_variant_tensor.setter
     99   def _variant_tensor(self, _):
    100     raise ValueError("The _variant_tensor property is read-only")
    101 
    102   def _as_serialized_graph(self):
    103     """Produces serialized graph representation of the dataset.
    104 
    105     Returns:
    106       A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
    107       serialized graph.
    108     """
    109     return gen_dataset_ops.dataset_to_graph(self._variant_tensor)
    110 
    111   @abc.abstractmethod
    112   def _inputs(self):
    113     """Returns a list of the input datasets of the dataset."""
    114 
    115     raise NotImplementedError("Dataset._inputs")
    116 
    117   @property
    118   def _graph(self):
    119     return self._graph_attr
    120 
    121   @_graph.setter
    122   def _graph(self, _):
    123     raise ValueError("The _graph property is read-only")
    124 
    125   def _has_captured_ref(self):
    126     """Whether this dataset uses a function that captures ref variables.
    127 
    128     Returns:
    129       A boolean, which if true indicates that the dataset or one of its inputs
    130       uses a function that captures ref variables.
    131     """
    132     if context.executing_eagerly():
    133       # RefVariables are not supported in eager mode
    134       return False
    135 
    136     def is_tensor_or_parent_ref(tensor):
    137       if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
    138         return True
    139       return any([is_tensor_or_parent_ref(x) for x in tensor.op.inputs])
    140 
    141     for fn in self._functions():
    142       if any([is_tensor_or_parent_ref(t) for t in fn.function.captured_inputs]):
    143         return True
    144 
    145     return any(
    146         [input_dataset._has_captured_ref() for input_dataset in self._inputs()])  # pylint: disable=protected-access
    147 
    148   # TODO(jsimsa): Change this to be the transitive closure of functions used
    149   # by this dataset and its inputs.
    150   def _functions(self):
    151     """Returns a list of functions associated with this dataset.
    152 
    153     Returns:
    154       A list of `StructuredFunctionWrapper` objects.
    155     """
    156     return []
    157 
    158   def options(self):
    159     """Returns the options for this dataset and its inputs.
    160 
    161     Returns:
    162       A `tf.data.Options` object representing the dataset options.
    163     """
    164     options = Options()
    165     for input_dataset in self._inputs():
    166       input_options = input_dataset.options()
    167       if input_options is not None:
    168         options = options.merge(input_options)
    169     return options
    170 
    171   def _apply_options(self):
    172     """Apply options, such as optimization configuration, to the dataset."""
    173 
    174     dataset = self
    175     options = self.options()
    176     if options.experimental_threading is not None:
    177       t_options = options.experimental_threading
    178       if t_options.max_intra_op_parallelism is not None:
    179         dataset = _MaxIntraOpParallelismDataset(
    180             dataset, t_options.max_intra_op_parallelism)
    181       if t_options.private_threadpool_size is not None:
    182         dataset = _PrivateThreadPoolDataset(dataset,
    183                                             t_options.private_threadpool_size)
    184     static_optimizations = options._static_optimizations()  # pylint: disable=protected-access
    185     if static_optimizations:
    186       if self._has_captured_ref():
    187         warnings.warn(
    188             "tf.data static optimizations are not compatible with tf.Variable. "
    189             "The following optimizations will be disabled: %s. To enable "
    190             "optimizations, use resource variables instead by calling "
    191             "`tf.enable_resource_variables()` at the start of the program." %
    192             ", ".join(static_optimizations))
    193       else:
    194         dataset = _OptimizeDataset(dataset, static_optimizations)
    195 
    196     autotune = True
    197     cpu_budget = 0  # Indicates that all CPU cores should be used.
    198     if options.experimental_optimization is not None:
    199       if options.experimental_optimization.autotune is False:  # pylint: disable=g-bool-id-comparison
    200         autotune = False
    201       if options.experimental_optimization.autotune_cpu_budget is not None:
    202         cpu_budget = options.experimental_optimization.autotune_cpu_budget
    203 
    204     if autotune:
    205       dataset = _ModelDataset(dataset, cpu_budget)
    206 
    207     if options.experimental_stats and options.experimental_stats.aggregator:  # pylint: disable=line-too-long
    208       dataset = _SetStatsAggregatorDataset(  # pylint: disable=protected-access
    209           dataset, options.experimental_stats.aggregator,
    210           options.experimental_stats.prefix,
    211           options.experimental_stats.counter_prefix)
    212     return dataset
    213 
    214   def __iter__(self):
    215     """Creates an `Iterator` for enumerating the elements of this dataset.
    216 
    217     The returned iterator implements the Python iterator protocol and therefore
    218     can only be used in eager mode.
    219 
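            For example (a minimal sketch, assuming eager execution has been enabled
            with `tf.enable_eager_execution()`):

            ```python
            dataset = tf.data.Dataset.range(3)
            for element in dataset:  # Invokes Dataset.__iter__ under eager execution.
              print(element)  # Prints the tensors 0, 1, 2.
            ```
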
    220     Returns:
    221       An `Iterator` over the elements of this dataset.
    222 
    223     Raises:
    224       RuntimeError: If eager execution is not enabled.
    225     """
    226     if context.executing_eagerly():
    227       return iterator_ops.EagerIterator(self)
    228     else:
    229       raise RuntimeError("dataset.__iter__() is only supported when eager "
    230                          "execution is enabled.")
    231 
    232   @abc.abstractproperty
    233   def _element_structure(self):
    234     """The structure of an element of this dataset.
    235 
    236     Returns:
    237       A `Structure` object representing the structure of an element of this
    238       dataset.
    239     """
    240     raise NotImplementedError("Dataset._element_structure")
    241 
    242   def __repr__(self):
    243     output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
    244     output_shapes = str(output_shapes).replace("'", "")
    245     output_types = nest.map_structure(repr, get_legacy_output_types(self))
    246     output_types = str(output_types).replace("'", "")
    247     return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes,
    248                                             output_types))
    249 
    250   @staticmethod
    251   def from_tensors(tensors):
    252     """Creates a `Dataset` with a single element, comprising the given tensors.
    253 
    254     Note that if `tensors` contains a NumPy array, and eager execution is not
    255     enabled, the values will be embedded in the graph as one or more
    256     `tf.constant` operations. For large datasets (> 1 GB), this can waste
    257     memory and run into byte limits of graph serialization. If `tensors`
    258     contains one or more large NumPy arrays, consider the alternative described
    259     in [this
    260     guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
    261 
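            For example (a minimal sketch, assuming TensorFlow is imported as `tf`):

            ```python
            # A dataset containing exactly one element: the tuple (42, [1, 2, 3]).
            dataset = tf.data.Dataset.from_tensors((42, [1, 2, 3]))
            ```
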
    262     Args:
    263       tensors: A nested structure of tensors.
    264 
    265     Returns:
    266       Dataset: A `Dataset`.
    267     """
    268     return TensorDataset(tensors)
    269 
    270   @staticmethod
    271   def from_tensor_slices(tensors):
    272     """Creates a `Dataset` whose elements are slices of the given tensors.
    273 
    274     Note that if `tensors` contains a NumPy array, and eager execution is not
    275     enabled, the values will be embedded in the graph as one or more
    276     `tf.constant` operations. For large datasets (> 1 GB), this can waste
    277     memory and run into byte limits of graph serialization. If `tensors`
    278     contains one or more large NumPy arrays, consider the alternative described
    279     in [this guide](
    280     https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
    281 
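            For example (a minimal sketch; each slice along the 0th dimension becomes
            one element):

            ```python
            # Produces three elements: [1, 2], [3, 4], and [5, 6].
            dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6]])
            ```
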
    282     Args:
    283       tensors: A nested structure of tensors, each having the same size in the
    284         0th dimension.
    285 
    286     Returns:
    287       Dataset: A `Dataset`.
    288     """
    289     return TensorSliceDataset(tensors)
    290 
    291   class _GeneratorState(object):
    292     """Stores outstanding iterators created from a Python generator.
    293 
    294     This class keeps track of potentially multiple iterators that may have
    295     been created from a generator, e.g. in the case that the dataset is
    296     repeated, or nested within a parallel computation.
    297     """
    298 
    299     def __init__(self, generator):
    300       self._generator = generator
    301       self._lock = threading.Lock()
    302       self._next_id = 0  # GUARDED_BY(self._lock)
    303       self._args = {}
    304       self._iterators = {}
    305 
    306     def get_next_id(self, *args):
    307       with self._lock:
    308         ret = self._next_id
    309         self._next_id += 1
    310       self._args[ret] = args
    311       # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
    312       # casting in `py_func()` will create an array of `np.int32` on Windows,
    313       # leading to a runtime error.
    314       return np.array(ret, dtype=np.int64)
    315 
    316     def get_iterator(self, iterator_id):
    317       try:
    318         return self._iterators[iterator_id]
    319       except KeyError:
    320         iterator = iter(self._generator(*self._args.pop(iterator_id)))
    321         self._iterators[iterator_id] = iterator
    322         return iterator
    323 
    324     def iterator_completed(self, iterator_id):
    325       del self._iterators[iterator_id]
    326 
    327   @staticmethod
    328   def from_generator(generator, output_types, output_shapes=None, args=None):
    329     """Creates a `Dataset` whose elements are generated by `generator`.
    330 
    331     The `generator` argument must be a callable object that returns
    332     an object that supports the `iter()` protocol (e.g. a generator function).
    333     The elements generated by `generator` must be compatible with the given
    334     `output_types` and (optional) `output_shapes` arguments.
    335 
    336     For example:
    337 
    338     ```python
    339     import itertools
    340     tf.enable_eager_execution()
    341 
    342     def gen():
    343       for i in itertools.count(1):
    344         yield (i, [1] * i)
    345 
    346     ds = tf.data.Dataset.from_generator(
    347         gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
    348 
    349     for value in ds.take(2):
    350       print(value)
    351     # (1, array([1]))
    352     # (2, array([1, 1]))
    353     ```
    354 
    355     NOTE: The current implementation of `Dataset.from_generator()` uses
    356     `tf.py_func` and inherits the same constraints. In particular, it
    357     requires the `Dataset`- and `Iterator`-related operations to be placed
    358     on a device in the same process as the Python program that called
    359     `Dataset.from_generator()`. The body of `generator` will not be
    360     serialized in a `GraphDef`, and you should not use this method if you
    361     need to serialize your model and restore it in a different environment.
    362 
    363     NOTE: If `generator` depends on mutable global variables or other external
    364     state, be aware that the runtime may invoke `generator` multiple times
    365     (in order to support repeating the `Dataset`) and at any time
    366     between the call to `Dataset.from_generator()` and the production of the
    367     first element from the generator. Mutating global variables or external
    368     state can cause undefined behavior, and we recommend that you explicitly
    369     cache any external state in `generator` before calling
    370     `Dataset.from_generator()`.
    371 
    372     Args:
    373       generator: A callable object that returns an object that supports the
    374         `iter()` protocol. If `args` is not specified, `generator` must take
    375         no arguments; otherwise it must take as many arguments as there are
    376         values in `args`.
    377       output_types: A nested structure of `tf.DType` objects corresponding to
    378         each component of an element yielded by `generator`.
    379       output_shapes: (Optional.) A nested structure of `tf.TensorShape`
    380         objects corresponding to each component of an element yielded by
    381         `generator`.
    382       args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
    383         and passed to `generator` as NumPy-array arguments.
    384 
    385     Returns:
    386       Dataset: A `Dataset`.
    387     """
    388     if not callable(generator):
    389       raise TypeError("`generator` must be callable.")
    390     if output_shapes is None:
    391       output_shapes = nest.map_structure(
    392           lambda _: tensor_shape.TensorShape(None), output_types)
    393     else:
    394       output_shapes = nest.map_structure_up_to(
    395           output_types, tensor_shape.as_shape, output_shapes)
    396     if args is None:
    397       args = ()
    398     else:
    399       args = tuple(ops.convert_n_to_tensor(args, name="args"))
    400 
    401     flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)]
    402     flattened_shapes = nest.flatten(output_shapes)
    403 
    404     generator_state = DatasetV2._GeneratorState(generator)
    405 
    406     def get_iterator_id_fn(unused_dummy):
    407       """Creates a unique `iterator_id` for each pass over the dataset.
    408 
    409       The returned `iterator_id` disambiguates between multiple concurrently
    410       existing iterators.
    411 
    412       Args:
    413         unused_dummy: Ignored value.
    414 
    415       Returns:
    416         A `tf.int64` tensor whose value uniquely identifies an iterator in
    417         `generator_state`.
    418       """
    419       return script_ops.py_func(
    420           generator_state.get_next_id, args, dtypes.int64, stateful=True)
    421 
    422     def generator_next_fn(iterator_id_t):
    423       """Generates the next element from iterator with ID `iterator_id_t`.
    424 
    425       We map this function across an infinite repetition of the
    426       `iterator_id_t`, and raise `StopIteration` to terminate the iteration.
    427 
    428       Args:
    429         iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
    430           the iterator in `generator_state` from which to generate an element.
    431 
    432       Returns:
    433         A nested structure of tensors representing an element from the iterator.
    434       """
    435 
    436       def generator_py_func(iterator_id):
    437         """A `py_func` that will be called to invoke the iterator."""
    438         # `next()` raises `StopIteration` when there are no more
    439         # elements remaining to be generated.
    440         values = next(generator_state.get_iterator(iterator_id))
    441 
    442         # Use the same _convert function from the py_func() implementation to
    443         # convert the returned values to arrays early, so that we can inspect
    444         # their values.
    445         try:
    446           flattened_values = nest.flatten_up_to(output_types, values)
    447         except (TypeError, ValueError):
    448           raise TypeError(
    449               "`generator` yielded an element that did not match the expected "
    450               "structure. The expected structure was %s, but the yielded "
    451               "element was %s." % (output_types, values))
    452         ret_arrays = []
    453         for ret, dtype in zip(flattened_values, flattened_types):
    454           try:
    455             ret_arrays.append(script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
    456                 ret, dtype=dtype.as_numpy_dtype))
    457           except (TypeError, ValueError):
    458             raise TypeError(
    459                 "`generator` yielded an element that could not be converted to "
    460                 "the expected type. The expected type was %s, but the yielded "
    461                 "element was %s." % (dtype.name, ret))
    462 
    463         # Additional type and shape checking to ensure that the components
    464         # of the generated element match the `output_types` and `output_shapes`
    465         # arguments.
    466         for (ret_array, expected_dtype, expected_shape) in zip(
    467             ret_arrays, flattened_types, flattened_shapes):
    468           if ret_array.dtype != expected_dtype.as_numpy_dtype:
    469             raise TypeError(
    470                 "`generator` yielded an element of type %s where an element "
    471                 "of type %s was expected." % (ret_array.dtype,
    472                                               expected_dtype.as_numpy_dtype))
    473           if not expected_shape.is_compatible_with(ret_array.shape):
    474             raise ValueError(
    475                 "`generator` yielded an element of shape %s where an element "
    476                 "of shape %s was expected." % (ret_array.shape, expected_shape))
    477 
    478         return ret_arrays
    479 
    480       flat_values = script_ops.py_func(
    481           generator_py_func, [iterator_id_t], flattened_types, stateful=True)
    482 
    483       # The `py_func()` op drops the inferred shapes, so we add them back in
    484       # here.
    485       if output_shapes is not None:
    486         for ret_t, shape in zip(flat_values, flattened_shapes):
    487           ret_t.set_shape(shape)
    488 
    489       return nest.pack_sequence_as(output_types, flat_values)
    490 
    491     def finalize_fn(iterator_id_t):
    492       """Releases host-side state for the iterator with ID `iterator_id_t`."""
    493 
    494       def finalize_py_func(iterator_id):
    495         generator_state.iterator_completed(iterator_id)
    496         # We return a dummy value so that the `finalize_fn` has a valid
    497         # signature.
    498         # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
    499         # casting in `py_func()` will create an array of `np.int32` on Windows,
    500         # leading to a runtime error.
    501         return np.array(0, dtype=np.int64)
    502 
    503       return script_ops.py_func(
    504           finalize_py_func, [iterator_id_t], dtypes.int64, stateful=True)
    505 
    506     # This function associates each traversal of `generator` with a unique
    507     # iterator ID.
    508     def flat_map_fn(dummy_arg):
    509       # The `get_iterator_id_fn` gets a unique ID for the current instance
    510       # of the generator.
    511       # The `generator_next_fn` gets the next element from the iterator with the
    512       # given ID, and raises StopIteration when that iterator contains no
    513       # more elements.
    514       return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
    515                                finalize_fn)
    516 
    517     # A single-element dataset that, each time it is evaluated, contains a
    518     # freshly-generated and unique (for the returned dataset) int64
    519     # ID that will be used to identify the appropriate Python state, which
    520     # is encapsulated in `generator_state`, and captured in
    521     # `get_iterator_id_fn`.
    522     dummy = 0
    523     id_dataset = Dataset.from_tensors(dummy)
    524 
    525     # A dataset that contains all of the elements generated by a
    526     # single iterator created from `generator`, identified by the
    527     # iterator ID contained in `id_dataset`. Lifting the iteration
    528     # into a flat_map here enables multiple repetitions and/or nested
    529     # versions of the returned dataset to be created, because it forces
    530     # the generation of a new ID for each version.
    531     return id_dataset.flat_map(flat_map_fn)
    532 
    533   @staticmethod
    534   def range(*args):
    535     """Creates a `Dataset` of a step-separated range of values.
    536 
    537     For example:
    538 
    539     ```python
    540     Dataset.range(5) == [0, 1, 2, 3, 4]
    541     Dataset.range(2, 5) == [2, 3, 4]
    542     Dataset.range(1, 5, 2) == [1, 3]
    543     Dataset.range(1, 5, -2) == []
    544     Dataset.range(5, 1) == []
    545     Dataset.range(5, 1, -2) == [5, 3]
    546     ```
    547 
    548     Args:
    549       *args: follows the same semantics as Python's `xrange`.
    550         len(args) == 1 -> start = 0, stop = args[0], step = 1
    551         len(args) == 2 -> start = args[0], stop = args[1], step = 1
    552         len(args) == 3 -> start = args[0], stop = args[1], step = args[2]
    553 
    554     Returns:
    555       Dataset: A `RangeDataset`.
    556 
    557     Raises:
    558       ValueError: if len(args) == 0.
    559     """
    560     return RangeDataset(*args)
    561 
    562   @staticmethod
    563   def zip(datasets):
    564     """Creates a `Dataset` by zipping together the given datasets.
    565 
    566     This method has similar semantics to the built-in `zip()` function
    567     in Python, with the main difference being that the `datasets`
    568     argument can be an arbitrary nested structure of `Dataset` objects.
    569     For example:
    570 
    571     ```python
    572     # NOTE: The following examples use `{ ... }` to represent the
    573     # contents of a dataset.
    574     a = { 1, 2, 3 }
    575     b = { 4, 5, 6 }
    576     c = { (7, 8), (9, 10), (11, 12) }
    577     d = { 13, 14 }
    578 
    579     # The nested structure of the `datasets` argument determines the
    580     # structure of elements in the resulting dataset.
    581     Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
    582     Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }
    583 
    584     # The `datasets` argument may contain an arbitrary number of
    585     # datasets.
    586     Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
    587                                 (2, 5, (9, 10)),
    588                                 (3, 6, (11, 12)) }
    589 
    590     # The number of elements in the resulting dataset is the same as
    591     # the size of the smallest dataset in `datasets`.
    592     Dataset.zip((a, d)) == { (1, 13), (2, 14) }
    593     ```
    594 
    595     Args:
    596       datasets: A nested structure of datasets.
    597 
    598     Returns:
    599       Dataset: A `Dataset`.
    600     """
    601     return ZipDataset(datasets)
    602 
    603   def concatenate(self, dataset):
    604     """Creates a `Dataset` by concatenating the given dataset with this dataset.
    605 
    606     ```python
    607     # NOTE: The following examples use `{ ... }` to represent the
    608     # contents of a dataset.
    609     a = { 1, 2, 3 }
    610     b = { 4, 5, 6, 7 }
    611 
    612     # Input dataset and dataset to be concatenated should have same
    613     # nested structures and output types.
    614     # c = { (8, 9), (10, 11), (12, 13) }
    615     # d = { 14.0, 15.0, 16.0 }
    616     # a.concatenate(c) and a.concatenate(d) would result in error.
    617 
    618     a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 }
    619     ```
    620 
    621     Args:
    622       dataset: `Dataset` to be concatenated.
    623 
    624     Returns:
    625       Dataset: A `Dataset`.
    626     """
    627     return ConcatenateDataset(self, dataset)
    628 
    629   def prefetch(self, buffer_size):
    630     """Creates a `Dataset` that prefetches elements from this dataset.
    631 
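            For example (a minimal sketch; prefetching overlaps the production of
            elements with their consumption downstream):

            ```python
            dataset = tf.data.Dataset.range(100)
            # Keep up to one element buffered ahead of the consumer.
            dataset = dataset.prefetch(buffer_size=1)
            ```
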
    632     Args:
    633       buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
    634         maximum number of elements that will be buffered when prefetching.
    635 
    636     Returns:
    637       Dataset: A `Dataset`.
    638     """
    639     return PrefetchDataset(self, buffer_size)
    640 
    641   @staticmethod
    642   def list_files(file_pattern, shuffle=None, seed=None):
    643     """A dataset of all files matching one or more glob patterns.
    644 
    645     NOTE: The default behavior of this method is to return filenames in
    646     a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
    647     to get results in a deterministic order.
    648 
    649     Example:
    650       If we had the following files on our filesystem:
    651         - /path/to/dir/a.txt
    652         - /path/to/dir/b.py
    653         - /path/to/dir/c.py
    654       If we pass "/path/to/dir/*.py" as the `file_pattern`, the dataset would
    655       produce:
    656         - /path/to/dir/b.py
    657         - /path/to/dir/c.py
    658 
    659     Args:
    660       file_pattern: A string, a list of strings, or a `tf.Tensor` of string type
    661         (scalar or vector), representing the filename glob (i.e. shell wildcard)
    662         pattern(s) that will be matched.
    663       shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
    664         Defaults to `True`.
    665       seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
    666         seed that will be used to create the distribution. See
    667         `tf.set_random_seed` for behavior.
    668 
    669     Returns:
    670      Dataset: A `Dataset` of strings corresponding to file names.
    671     """
    672     with ops.name_scope("list_files"):
    673       if shuffle is None:
    674         shuffle = True
    675       file_pattern = ops.convert_to_tensor(
    676           file_pattern, dtype=dtypes.string, name="file_pattern")
    677       matching_files = gen_io_ops.matching_files(file_pattern)
    678 
    679       # Raise an exception if `file_pattern` does not match any files.
    680       condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
    681                                    name="match_not_empty")
    682 
    683       message = math_ops.add(
    684           "No files matched pattern: ",
    685           string_ops.reduce_join(file_pattern, separator=", "), name="message")
    686 
    687       assert_not_empty = control_flow_ops.Assert(
    688           condition, [message], summarize=1, name="assert_not_empty")
    689       with ops.control_dependencies([assert_not_empty]):
    690         matching_files = array_ops.identity(matching_files)
    691 
    692       dataset = Dataset.from_tensor_slices(matching_files)
    693       if shuffle:
    694         # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
    695         # list of files might be empty.
    696         buffer_size = math_ops.maximum(
    697             array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
    698         dataset = dataset.shuffle(buffer_size, seed=seed)
    699       return dataset
    700 
    701   def repeat(self, count=None):
    702     """Repeats this dataset `count` times.
    703 
    704     NOTE: If this dataset is a function of global state (e.g. a random number
    705     generator), then different repetitions may produce different elements.
    706 
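            For example (a minimal sketch):

            ```python
            # Yields 0, 1, 2, 0, 1, 2 and then stops.
            dataset = tf.data.Dataset.range(3).repeat(2)
            ```
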
    707     Args:
    708       count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
    709         number of times the dataset should be repeated. The default behavior
    710         (if `count` is `None` or `-1`) is for the dataset be repeated
    711         indefinitely.
    712 
    713     Returns:
    714       Dataset: A `Dataset`.
    715     """
    716     return RepeatDataset(self, count)
    717 
    718   def _enumerate(self, start=0):
    719 
    720     max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
    721     return Dataset.zip((Dataset.range(start, max_value), self))
    722 
    723   def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
    724     """Randomly shuffles the elements of this dataset.
    725 
    726     This dataset fills a buffer with `buffer_size` elements, then randomly
    727     samples elements from this buffer, replacing the selected elements with new
    728     elements. For perfect shuffling, a buffer size greater than or equal to the
    729     full size of the dataset is required.
    730 
    731     For instance, if your dataset contains 10,000 elements but `buffer_size` is
    732     set to 1,000, then `shuffle` will initially select a random element from
    733     only the first 1,000 elements in the buffer. Once an element is selected,
    734     its space in the buffer is replaced by the next (i.e. 1,001-st) element,
    735     maintaining the 1,000 element buffer.
    736 
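            For example (a minimal sketch; a buffer as large as the dataset gives a
            uniform shuffle over all elements):

            ```python
            dataset = tf.data.Dataset.range(10)
            # A buffer of 10 covers the whole dataset, so this is a full shuffle.
            dataset = dataset.shuffle(buffer_size=10, seed=42)
            ```
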
    737     Args:
    738       buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
    739         number of elements from this dataset from which the new
    740         dataset will sample.
    741       seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
    742         random seed that will be used to create the distribution. See
    743         `tf.set_random_seed` for behavior.
    744       reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
    745         that the dataset should be pseudorandomly reshuffled each time it is
    746         iterated over. (Defaults to `True`.)
    747 
    748     Returns:
    749       Dataset: A `Dataset`.
    750     """
    751     return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)
    752 
    753   def cache(self, filename=""):
    754     """Caches the elements in this dataset.
    755 
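            For example (a minimal sketch; with no filename the cache is held in
            memory):

            ```python
            dataset = tf.data.Dataset.range(5)
            # Elements are cached during the first iteration and reused afterwards.
            dataset = dataset.cache()
            ```
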
    756     Args:
    757       filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
    758         directory on the filesystem to use for caching tensors in this Dataset.
    759         If a filename is not provided, the dataset will be cached in memory.
    760 
    761     Returns:
    762       Dataset: A `Dataset`.
    763     """
    764     return CacheDataset(self, filename)
    765 
    766   def take(self, count):
    767     """Creates a `Dataset` with at most `count` elements from this dataset.
    768 
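            For example (a minimal sketch):

            ```python
            # Yields 0, 1, 2.
            dataset = tf.data.Dataset.range(10).take(3)
            ```
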
    769     Args:
    770       count: A `tf.int64` scalar `tf.Tensor`, representing the number of
    771         elements of this dataset that should be taken to form the new dataset.
    772         If `count` is -1, or if `count` is greater than the size of this
    773         dataset, the new dataset will contain all elements of this dataset.
    774 
    775     Returns:
    776       Dataset: A `Dataset`.
    777     """
    778     return TakeDataset(self, count)
    779 
    780   def skip(self, count):
    781     """Creates a `Dataset` that skips `count` elements from this dataset.
    782 
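            For example (a minimal sketch):

            ```python
            # Yields 7, 8, 9.
            dataset = tf.data.Dataset.range(10).skip(7)
            ```
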
    783     Args:
    784       count: A `tf.int64` scalar `tf.Tensor`, representing the number
    785         of elements of this dataset that should be skipped to form the
    786         new dataset.  If `count` is greater than the size of this
    787         dataset, the new dataset will contain no elements.  If `count`
    788         is -1, skips the entire dataset.
    789 
    790     Returns:
    791       Dataset: A `Dataset`.
    792     """
    793     return SkipDataset(self, count)
    794 
    795   def shard(self, num_shards, index):
    796     """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
    797 
    798     This dataset operator is very useful when running distributed training, as
    799     it allows each worker to read a unique subset.
    800 
    801     When reading a single input file, you can skip elements as follows:
    802 
    803     ```python
    804     d = tf.data.TFRecordDataset(input_file)
    805     d = d.shard(num_workers, worker_index)
    806     d = d.repeat(num_epochs)
    807     d = d.shuffle(shuffle_buffer_size)
    808     d = d.map(parser_fn, num_parallel_calls=num_map_threads)
    809     ```
    810 
    811     Important caveats:
    812 
    813     - Be sure to shard before you use any randomizing operator (such as
    814       shuffle).
    815     - Generally it is best if the shard operator is used early in the dataset
    816       pipeline. For example, when reading from a set of TFRecord files, shard
    817       before converting the dataset to input samples. This avoids reading every
    818       file on every worker. The following is an example of an efficient
    819       sharding strategy within a complete pipeline:
    820 
    821     ```python
    822     d = Dataset.list_files(pattern)
    823     d = d.shard(num_workers, worker_index)
    824     d = d.repeat(num_epochs)
    825     d = d.shuffle(shuffle_buffer_size)
    826     d = d.interleave(tf.data.TFRecordDataset,
    827                      cycle_length=num_readers, block_length=1)
    828     d = d.map(parser_fn, num_parallel_calls=num_map_threads)
    829     ```
    830 
    831     Args:
    832       num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
    833         shards operating in parallel.
    834       index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
    835 
    836     Returns:
    837       Dataset: A `Dataset`.
    838 
    839     Raises:
    840       InvalidArgumentError: if `num_shards` or `index` are illegal values.
    841         Note: error checking is done on a best-effort basis, and errors aren't
    842         guaranteed to be caught upon dataset creation. (e.g. providing a
    843         placeholder tensor bypasses the early checking, and will instead result
    844         in an error during a session.run call.)
    845     """
    846     return ShardDataset(self, num_shards, index)
    847 
    848   def batch(self, batch_size, drop_remainder=False):
    849     """Combines consecutive elements of this dataset into batches.
    850 
    851     The tensors in the resulting element will have an additional outer
    852     dimension, which will be `batch_size` (or `N % batch_size` for the last
    853     element if `batch_size` does not divide the number of input elements `N`
    854     evenly and `drop_remainder` is `False`). If your program depends on the
    855     batches having the same outer dimension, you should set the `drop_remainder`
    856     argument to `True` to prevent the smaller batch from being produced.
    857 
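            For example (a minimal sketch; the final, smaller batch is kept because
            `drop_remainder` defaults to `False`):

            ```python
            dataset = tf.data.Dataset.range(8).batch(3)
            # Yields [0, 1, 2], [3, 4, 5], [6, 7].
            ```
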
    858     Args:
    859       batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
    860         consecutive elements of this dataset to combine in a single batch.
    861       drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
    862         whether the last batch should be dropped in the case it has fewer than
    863         `batch_size` elements; the default behavior is not to drop the smaller
    864         batch.
    865 
    866     Returns:
    867       Dataset: A `Dataset`.
    868     """
    869     return BatchDataset(self, batch_size, drop_remainder)
    870 
    871   def padded_batch(self,
    872                    batch_size,
    873                    padded_shapes,
    874                    padding_values=None,
    875                    drop_remainder=False):
    876     """Combines consecutive elements of this dataset into padded batches.
    877 
    878     This transformation combines multiple consecutive elements of the input
    879     dataset into a single element.
    880 
    881     Like `tf.data.Dataset.batch`, the tensors in the resulting element will
    882     have an additional outer dimension, which will be `batch_size` (or
    883     `N % batch_size` for the last element if `batch_size` does not divide the
    884     number of input elements `N` evenly and `drop_remainder` is `False`). If
    885     your program depends on the batches having the same outer dimension, you
    886     should set the `drop_remainder` argument to `True` to prevent the smaller
    887     batch from being produced.
    888 
    889     Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
    890     different shapes, and this transformation will pad each component to the
    891     respective shape in `padded_shapes`. The `padded_shapes` argument
    892     determines the resulting shape for each dimension of each component in an
    893     output element:
    894 
    895     * If the dimension is a constant (e.g. `tf.Dimension(37)`), the component
    896       will be padded out to that length in that dimension.
    897     * If the dimension is unknown (e.g. `tf.Dimension(None)`), the component
    898       will be padded out to the maximum length of all elements in that
    899       dimension.
    900 
    901     See also `tf.data.experimental.dense_to_sparse_batch`, which combines
    902     elements that may have different shapes into a `tf.SparseTensor`.
    903 
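            For example (a minimal sketch; elements are vectors of different lengths,
            and each batch is padded to its longest element):

            ```python
            dataset = tf.data.Dataset.range(4)
            dataset = dataset.map(lambda x: tf.fill([x], x))  # [], [1], [2, 2], [3, 3, 3]
            dataset = dataset.padded_batch(2, padded_shapes=tf.TensorShape([None]))
            # Yields [[0], [1]] and [[2, 2, 0], [3, 3, 3]].
            ```
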
    904     Args:
    905       batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
    906         consecutive elements of this dataset to combine in a single batch.
    907       padded_shapes: A nested structure of `tf.TensorShape` or
    908         `tf.int64` vector tensor-like objects representing the shape
    909         to which the respective component of each input element should
    910         be padded prior to batching. Any unknown dimensions
    911         (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a
    912         tensor-like object) will be padded to the maximum size of that
    913         dimension in each batch.
    914       padding_values: (Optional.) A nested structure of scalar-shaped
    915         `tf.Tensor`, representing the padding values to use for the
    916         respective components.  Defaults are `0` for numeric types and
    917         the empty string for string types.
    918       drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
    919         whether the last batch should be dropped in the case it has fewer than
    920         `batch_size` elements; the default behavior is not to drop the smaller
    921         batch.
    922 
    923     Returns:
    924       Dataset: A `Dataset`.
    925     """
    926     return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values,
    927                               drop_remainder)
    928 
    929   def map(self, map_func, num_parallel_calls=None):
    930     """Maps `map_func` across the elements of this dataset.
    931 
    932     This transformation applies `map_func` to each element of this dataset, and
    933     returns a new dataset containing the transformed elements, in the same
    934     order as they appeared in the input.
    935 
    936     For example:
    937 
    938     ```python
    939     # NOTE: The following examples use `{ ... }` to represent the
    940     # contents of a dataset.
    941     a = { 1, 2, 3, 4, 5 }
    942 
    943     a.map(lambda x: x + 1) == { 2, 3, 4, 5, 6 }
    944     ```
    945 
    946     The input signature of `map_func` is determined by the structure of each
    947     element in this dataset. For example:
    948 
    949     ```python
    950     # Each element is a `tf.Tensor` object.
    951     a = { 1, 2, 3, 4, 5 }
    952     # `map_func` takes a single argument of type `tf.Tensor` with the same
    953     # shape and dtype.
    954     result = a.map(lambda x: ...)
    955 
    956     # Each element is a tuple containing two `tf.Tensor` objects.
    957     b = { (1, "foo"), (2, "bar"), (3, "baz") }
    958     # `map_func` takes two arguments of type `tf.Tensor`.
    959     result = b.map(lambda x_int, y_str: ...)
    960 
    961     # Each element is a dictionary mapping strings to `tf.Tensor` objects.
    962     c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} }
    963     # `map_func` takes a single argument of type `dict` with the same keys as
    964     # the elements.
    965     result = c.map(lambda d: ...)
    966     ```
    967 
    968     The value or values returned by `map_func` determine the structure of each
    969     element in the returned dataset.
    970 
    971     ```python
    972     # `map_func` returns a scalar `tf.Tensor` of type `tf.float32`.
    973     def f(...):
    974       return tf.constant(37.0)
    975     result = dataset.map(f)
    976     result.output_classes == tf.Tensor
    977     result.output_types == tf.float32
    978     result.output_shapes == []  # scalar
    979 
    980     # `map_func` returns two `tf.Tensor` objects.
    981     def g(...):
    982       return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
    983     result = dataset.map(g)
    984     result.output_classes == (tf.Tensor, tf.Tensor)
    985     result.output_types == (tf.float32, tf.string)
    986     result.output_shapes == ([], [3])
    987 
    988     # Python primitives, lists, and NumPy arrays are implicitly converted to
    989     # `tf.Tensor`.
    990     def h(...):
    991       return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0], dtype=np.float64)
    992     result = dataset.map(h)
    993     result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor)
    994     result.output_types == (tf.float32, tf.string, tf.float64)
    995     result.output_shapes == ([], [3], [2])
    996 
    997     # `map_func` can return nested structures.
    998     def i(...):
    999       return {"a": 37.0, "b": [42, 16]}, "foo"
            result = dataset.map(i)
   1000     result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor)
   1001     result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string)
   1002     result.output_shapes == ({"a": [], "b": [2]}, [])
   1003     ```
   1004 
   1005     In addition to `tf.Tensor` objects, `map_func` can accept as arguments and
   1006     return `tf.SparseTensor` objects.
   1007 
   1008     Args:
   1009       map_func: A function mapping a nested structure of tensors (having
   1010         shapes and types defined by `self.output_shapes` and
   1011        `self.output_types`) to another nested structure of tensors.
   1012       num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
   1013         representing the number of elements to process asynchronously in parallel.
   1014         If not specified, elements will be processed sequentially. If the value
   1015         `tf.data.experimental.AUTOTUNE` is used, then the number of parallel
   1016         calls is set dynamically based on available CPU.
   1017 
   1018     Returns:
   1019       Dataset: A `Dataset`.
   1020     """
   1021     if num_parallel_calls is None:
   1022       return MapDataset(self, map_func, preserve_cardinality=True)
   1023     else:
   1024       return ParallelMapDataset(
   1025           self, map_func, num_parallel_calls, preserve_cardinality=True)
   1026 
   1027   def flat_map(self, map_func):
   1028     """Maps `map_func` across this dataset and flattens the result.
   1029 
   1030     Use `flat_map` if you want to make sure that the order of your dataset
   1031     stays the same. For example, to flatten a dataset of batches into a
   1032     dataset of their elements:
   1033 
   1034     ```python
   1035     # NOTE: The following examples use `{ ... }` to represent the
   1036     # contents of a dataset. '[...]' represents a tensor.
   1037     a = {[1,2,3,4,5], [6,7,8,9], [10]}
   1038 
   1039     a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
   1040       {1,2,3,4,5,6,7,8,9,10}
   1041     ```
   1042 
   1043     `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
   1044     `flat_map` produces the same output as
   1045     `tf.data.Dataset.interleave(cycle_length=1)`
   1046 
   1047     Args:
   1048       map_func: A function mapping a nested structure of tensors (having shapes
   1049         and types defined by `self.output_shapes` and `self.output_types`) to a
   1050         `Dataset`.
   1051 
   1052     Returns:
   1053       Dataset: A `Dataset`.
   1054     """
   1055     return FlatMapDataset(self, map_func)
   1056 
   1057   def interleave(self,
   1058                  map_func,
   1059                  cycle_length,
   1060                  block_length=1,
   1061                  num_parallel_calls=None):
   1062     """Maps `map_func` across this dataset, and interleaves the results.
   1063 
   1064     For example, you can use `Dataset.interleave()` to process many input files
   1065     concurrently:
   1066 
   1067     ```python
   1068     # Preprocess 4 files concurrently, and interleave blocks of 16 records from
   1069     # each file.
   1070     filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
   1071     dataset = (Dataset.from_tensor_slices(filenames)
   1072                .interleave(lambda x:
   1073                    TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
   1074                    cycle_length=4, block_length=16))
   1075     ```
   1076 
   1077     The `cycle_length` and `block_length` arguments control the order in which
   1078     elements are produced. `cycle_length` controls the number of input elements
   1079     that are processed concurrently. If you set `cycle_length` to 1, this
   1080     transformation will handle one input element at a time, and will produce
   1081     identical results to `tf.data.Dataset.flat_map`. In general,
   1082     this transformation will apply `map_func` to `cycle_length` input elements,
   1083     open iterators on the returned `Dataset` objects, and cycle through them
   1084     producing `block_length` consecutive elements from each iterator, and
   1085     consuming the next input element each time it reaches the end of an
   1086     iterator.
   1087 
   1088     For example:
   1089 
   1090     ```python
   1091     # NOTE: The following examples use `{ ... }` to represent the
   1092     # contents of a dataset.
   1093     a = { 1, 2, 3, 4, 5 }
   1094 
   1095     # NOTE: New lines indicate "block" boundaries.
   1096     a.interleave(lambda x: Dataset.from_tensors(x).repeat(6),
   1097                  cycle_length=2, block_length=4) == {
   1098         1, 1, 1, 1,
   1099         2, 2, 2, 2,
   1100         1, 1,
   1101         2, 2,
   1102         3, 3, 3, 3,
   1103         4, 4, 4, 4,
   1104         3, 3,
   1105         4, 4,
   1106         5, 5, 5, 5,
   1107         5, 5,
   1108     }
   1109     ```
   1110 
   1111     NOTE: The order of elements yielded by this transformation is
   1112     deterministic, as long as `map_func` is a pure function. If
   1113     `map_func` contains any stateful operations, the order in which
   1114     that state is accessed is undefined.
   1115 
   1116     Args:
   1117       map_func: A function mapping a nested structure of tensors (having shapes
   1118         and types defined by `self.output_shapes` and `self.output_types`) to a
   1119         `Dataset`.
   1120       cycle_length: The number of elements from this dataset that will be
   1121         processed concurrently.
   1122       block_length: The number of consecutive elements to produce from each
   1123         input element before cycling to another input element.
   1124       num_parallel_calls: (Optional.) If specified, the implementation creates
   1125         a threadpool, which is used to fetch inputs from cycle elements
   1126         asynchronously and in parallel. The default behavior is to fetch inputs
   1127         from cycle elements synchronously with no parallelism. If the value
   1128         `tf.data.experimental.AUTOTUNE` is used, then the number of parallel
   1129         calls is set dynamically based on available CPU.
   1130 
   1131     Returns:
   1132       Dataset: A `Dataset`.
   1133     """
   1134     if num_parallel_calls is None:
   1135       return InterleaveDataset(self, map_func, cycle_length, block_length)
   1136     else:
   1137       return ParallelInterleaveDataset(self, map_func, cycle_length,
   1138                                        block_length, num_parallel_calls)
   1139 
   1140   def filter(self, predicate):
   1141     """Filters this dataset according to `predicate`.
   1142 
   1143     ```python
   1144     d = tf.data.Dataset.from_tensor_slices([1, 2, 3])
   1145 
   1146     d = d.filter(lambda x: x < 3) # [1, 2]
   1147 
   1148     # `tf.math.equal(x, y)` is required for equality comparison
   1149     def filter_fn(x):
   1150       return tf.math.equal(x, 1)
   1151 
   1152     d = d.filter(filter_fn) # [1]
   1153     ```
   1154 
   1155     Args:
   1156       predicate: A function mapping a nested structure of tensors (having shapes
   1157         and types defined by `self.output_shapes` and `self.output_types`) to a
   1158         scalar `tf.bool` tensor.
   1159 
   1160     Returns:
   1161       Dataset: The `Dataset` containing the elements of this dataset for which
   1162           `predicate` is `True`.
   1163     """
   1164     return FilterDataset(self, predicate)
   1165 
   1166   def apply(self, transformation_func):
   1167     """Applies a transformation function to this dataset.
   1168 
   1169     `apply` enables chaining of custom `Dataset` transformations, which are
   1170     represented as functions that take one `Dataset` argument and return a
   1171     transformed `Dataset`.
   1172 
   1173     For example:
   1174 
   1175     ```
   1176     dataset = (dataset.map(lambda x: x ** 2)
   1177                .apply(group_by_window(key_func, reduce_func, window_size))
   1178                .map(lambda x: x ** 3))
   1179     ```
   1180 
   1181     Args:
   1182       transformation_func: A function that takes one `Dataset` argument and
   1183         returns a `Dataset`.
   1184 
   1185     Returns:
   1186       Dataset: The `Dataset` returned by applying `transformation_func` to this
   1187           dataset.
   1188     """
   1189     dataset = transformation_func(self)
   1190     if not isinstance(dataset, DatasetV2):
   1191       raise TypeError(
   1192           "`transformation_func` must return a Dataset. Got {}.".format(
   1193               dataset))
   1194     dataset._input_datasets = [self]  # pylint: disable=protected-access
   1195     return dataset
   1196 
   1197   def window(self, size, shift=None, stride=1, drop_remainder=False):
   1198     """Combines input elements into a dataset of windows.
   1199 
   1200     Each window is a dataset itself and contains `size` elements (or
   1201     possibly fewer if there are not enough input elements to fill the window
   1202     and `drop_remainder` evaluates to false).
   1203 
   1204     The `shift` argument determines how far the window advances on each iteration,
   1205     and the `stride` argument determines the spacing between elements in a window.
   1206 
   1207     For example:
   1208     - `tf.data.Dataset.range(7).window(2)` produces
   1209       `{{0, 1}, {2, 3}, {4, 5}, {6}}`
   1210     - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces
   1211       `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}`
   1212     - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces
   1213       `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}`
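
    Each window is itself a `Dataset`, so it has to be consumed as a nested
    dataset. For instance, a minimal sketch that flattens every window back
    into a single batched element:

    ```python
    dataset = tf.data.Dataset.range(7).window(3, drop_remainder=True)
    # Each window of 3 elements becomes one dense batch.
    dataset = dataset.flat_map(lambda window: window.batch(3))
    # ==> [0, 1, 2], [3, 4, 5]
    ```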
   1214 
   1215     Args:
   1216       size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
   1217         of the input dataset to combine into a window.
   1218       shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
   1219         forward shift of the sliding window in each iteration. Defaults to
   1220         `size`.
   1221       stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
   1222         stride of the input elements in the sliding window.
   1223       drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
   1224         whether the last window should be dropped if its size is smaller than
   1225         `size`.
   1226 
   1227     Returns:
   1228       Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with
   1229         the same structure as this dataset, but a finite subsequence of its
   1230         elements.
   1231     """
   1232     if shift is None:
   1233       shift = size
   1234     return WindowDataset(self, size, shift, stride, drop_remainder)
   1235 
   1236   def reduce(self, initial_state, reduce_func):
   1237     """Reduces the input dataset to a single element.
   1238 
   1239     The transformation calls `reduce_func` successively on every element of
   1240     the input dataset until the dataset is exhausted, aggregating information in
   1241     its internal state. The `initial_state` argument is used for the initial
   1242     state and the final state is returned as the result.
   1243 
   1244     For example:
   1245     - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1)`
   1246       produces `5`
   1247     - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)`
   1248       produces `10`
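
    The state may also be a nested structure. For instance, a minimal sketch
    that carries a `(count, total)` tuple through the reduction:

    ```python
    ds = tf.data.Dataset.range(5)
    count, total = ds.reduce(
        (np.int64(0), np.int64(0)),
        lambda state, value: (state[0] + 1, state[1] + value))
    # ==> count: 5, total: 10
    ```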
   1249 
   1250     Args:
   1251       initial_state: A nested structure of tensors, representing the initial
   1252         state of the transformation.
   1253       reduce_func: A function that maps `(old_state, input_element)` to
   1254         `new_state`. It must take two arguments and return a nested structure
   1255         of tensors. The structure of `new_state` must match the structure of
   1256         `initial_state`.
   1257 
   1258     Returns:
   1259       A nested structure of `tf.Tensor` objects, corresponding to the final
   1260       state of the transformation.
   1261 
   1262     """
   1263 
   1264     with ops.name_scope("initial_state"):
   1265       # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
   1266       # values to tensors.
   1267       initial_state = nest.pack_sequence_as(initial_state, [
   1268           sparse_tensor_lib.SparseTensor.from_value(t)
   1269           if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
   1270               t, name="component_%d" % i)
   1271           for i, t in enumerate(nest.flatten(initial_state))
   1272       ])
   1273 
   1274     # Compute initial values for the state classes, shapes and types based on
   1275     # the initial state.
   1276     state_structure = structure_lib.Structure.from_value(initial_state)
   1277 
   1278     # Iteratively rerun the reduce function until reaching a fixed point on
   1279     # `state_structure`.
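    # For example (a hypothetical illustration): if the initial state has
    # shape [1] but `reduce_func` returns a state of shape [2], the shapes are
    # relaxed below to their most specific compatible shape, [None], and the
    # function is traced again with the weakened shape.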
   1280     need_to_rerun = True
   1281     while need_to_rerun:
   1282 
   1283       wrapped_func = StructuredFunctionWrapper(
   1284           reduce_func,
   1285           "reduce()",
   1286           input_structure=structure_lib.NestedStructure(
   1287               (state_structure, self._element_structure)),
   1288           add_to_graph=False)
   1289 
   1290       # Extract and validate class information from the returned values.
   1291       output_classes = wrapped_func.output_classes
   1292       state_classes = state_structure._to_legacy_output_classes()  # pylint: disable=protected-access
   1293       for new_state_class, state_class in zip(
   1294           nest.flatten(output_classes), nest.flatten(state_classes)):
   1295         if not issubclass(new_state_class, state_class):
   1296           raise TypeError(
   1297               "The element classes for the new state must match the initial "
   1298               "state. Expected %s; got %s." % (state_classes,
   1299                                                wrapped_func.output_classes))
   1300 
   1301       # Extract and validate type information from the returned values.
   1302       output_types = wrapped_func.output_types
   1303       state_types = state_structure._to_legacy_output_types()  # pylint: disable=protected-access
   1304       for new_state_type, state_type in zip(
   1305           nest.flatten(output_types), nest.flatten(state_types)):
   1306         if new_state_type != state_type:
   1307           raise TypeError(
   1308               "The element types for the new state must match the initial "
   1309               "state. Expected %s; got %s." % (state_types,
   1310                                                wrapped_func.output_types))
   1311 
   1312       # Extract shape information from the returned values.
   1313       output_shapes = wrapped_func.output_shapes
   1314       state_shapes = state_structure._to_legacy_output_shapes()  # pylint: disable=protected-access
   1315       flat_state_shapes = nest.flatten(state_shapes)
   1316       flat_new_state_shapes = nest.flatten(output_shapes)
   1317       weakened_state_shapes = [
   1318           original.most_specific_compatible_shape(new)
   1319           for original, new in zip(flat_state_shapes, flat_new_state_shapes)
   1320       ]
   1321 
   1322       need_to_rerun = False
   1323       for original_shape, weakened_shape in zip(flat_state_shapes,
   1324                                                 weakened_state_shapes):
   1325         if original_shape.ndims is not None and (
   1326             weakened_shape.ndims is None or
   1327             original_shape.as_list() != weakened_shape.as_list()):
   1328           need_to_rerun = True
   1329           break
   1330 
   1331       if need_to_rerun:
   1332         # TODO(b/110122868): Support a "most specific compatible structure"
   1333         # method for combining structures, to avoid using legacy structures
   1334         # here.
   1335         state_structure = structure_lib.convert_legacy_structure(
   1336             state_types,
   1337             nest.pack_sequence_as(state_shapes, weakened_state_shapes),
   1338             state_classes)
   1339 
   1340     reduce_func = wrapped_func.function
   1341     reduce_func.add_to_graph(ops.get_default_graph())
   1342 
   1343     # pylint: disable=protected-access
   1344     return state_structure._from_compatible_tensor_list(
   1345         gen_dataset_ops.reduce_dataset(
   1346             self._variant_tensor,
   1347             state_structure._to_tensor_list(initial_state),
   1348             reduce_func.captured_inputs,
   1349             f=reduce_func,
   1350             output_shapes=state_structure._flat_shapes,
   1351             output_types=state_structure._flat_types))
   1352 
   1353   def with_options(self, options):
   1354     """Returns a new `tf.data.Dataset` with the given options set.
   1355 
   1356     The options are "global" in the sense that they apply to the entire
   1357     dataset. If options are set multiple times, they are merged as long as
   1358     the same option is not set to different non-default values.
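
    For example, a minimal sketch that requests non-deterministic element
    order (assuming reproducibility is not required):

    ```python
    options = tf.data.Options()
    options.experimental_deterministic = False
    dataset = dataset.with_options(options)
    ```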
   1359 
   1360     Args:
   1361       options: A `tf.data.Options` that identifies the options to use.
   1362 
   1363     Returns:
   1364       Dataset: A `Dataset` with the given options.
   1365 
   1366     Raises:
   1367       ValueError: when an option is set more than once to conflicting values
   1368     """
   1369     return _OptionsDataset(self, options)
   1370 
   1371 
   1372 @tf_export(v1=["data.Dataset"])
   1373 class DatasetV1(DatasetV2):
   1374   """Represents a potentially large set of elements.
   1375 
   1376   A `Dataset` can be used to represent an input pipeline as a
   1377   collection of elements (nested structures of tensors) and a "logical
   1378   plan" of transformations that act on those elements.
   1379   """
   1380 
   1381   def __init__(self):
   1382     try:
   1383       variant_tensor = self._as_variant_tensor()
   1384     except AttributeError as e:
   1385       if "_as_variant_tensor" in str(e):
   1386         raise AttributeError("Please use _variant_tensor instead of "
   1387                              "_as_variant_tensor() to obtain the variant "
   1388                              "associated with a dataset")
   1389       raise AttributeError("A likely cause of this error is that the super "
   1390                            "call for this dataset is not the last line of the "
   1391                            "__init__ method. The base class invokes "
   1392                            "_as_variant_tensor() in its constructor, and if "
   1393                            "that method uses attributes defined in the "
   1394                            "__init__ method, those attributes need to be "
   1395                            "defined before the super call.")
   1396     super(DatasetV1, self).__init__(variant_tensor)
   1397 
   1398   @abc.abstractmethod
   1399   def _as_variant_tensor(self):
   1400     """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
   1401 
   1402     Returns:
   1403       A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
   1404     """
   1405     raise NotImplementedError("Dataset._as_variant_tensor")
   1406 
   1407   @deprecation.deprecated(
   1408       None, "Use `for ... in dataset:` to iterate over a dataset. If using "
   1409       "`tf.estimator`, return the `Dataset` object directly from your input "
   1410       "function. As a last resort, you can use "
   1411       "`tf.compat.v1.data.make_one_shot_iterator(dataset)`.")
   1412   def make_one_shot_iterator(self):
   1413     """Creates an `Iterator` for enumerating the elements of this dataset.
   1414 
   1415     Note: The returned iterator will be initialized automatically.
   1416     A "one-shot" iterator does not currently support re-initialization.
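
    For example, a minimal graph-mode sketch:

    ```python
    dataset = tf.data.Dataset.range(3)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with tf.Session() as sess:
      print(sess.run(next_element))  # ==> 0
      print(sess.run(next_element))  # ==> 1
    ```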
   1417 
   1418     Returns:
   1419       An `Iterator` over the elements of this dataset.
   1420     """
   1421     return self._make_one_shot_iterator()
   1422 
   1423   def _make_one_shot_iterator(self):  # pylint: disable=missing-docstring
   1424     if context.executing_eagerly():
   1425       return iterator_ops.EagerIterator(self)
   1426 
   1427     _ensure_same_dataset_graph(self)
   1428     # Now that we create datasets at Python object creation time, the
   1429     # capture-by-value _make_dataset() function would try to capture the
   1430     # variant-tensor dataset inputs, which are marked as stateful ops and
   1431     # would raise an error if captured. We therefore traverse the graph
   1432     # to find all such ops and whitelist them, so that the capturing
   1433     # logic recreates these ops (which is what happened before) instead
   1434     # of raising an error.
   1435     all_ds_ops = traverse.obtain_all_variant_tensor_ops(self)
   1436     graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
   1437 
   1438     # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
   1439     # a 0-argument function.
   1440     @function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops)
   1441     def _make_dataset():
   1442       """Factory function for a dataset."""
   1443       # NOTE(mrry): `Defun` does not capture the graph-level seed from the
   1444       # enclosing graph, so if a graph-level seed is present we set the local
   1445       # graph seed based on a combination of the graph- and op-level seeds.
   1446       if graph_level_seed is not None:
   1447         assert op_level_seed is not None
   1448         core_random_seed.set_random_seed(
   1449             (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
   1450 
   1451       dataset = self._apply_options()
   1452       return dataset._variant_tensor  # pylint: disable=protected-access
   1453 
   1454     try:
   1455       _make_dataset.add_to_graph(ops.get_default_graph())
   1456     except ValueError as err:
   1457       if "Cannot capture a stateful node" in str(err):
   1458         raise ValueError(
   1459             "Failed to create a one-shot iterator for a dataset. "
   1460             "`Dataset.make_one_shot_iterator()` does not support datasets that "
   1461             "capture stateful objects, such as a `Variable` or `LookupTable`. "
   1462             "In these cases, use `Dataset.make_initializable_iterator()`. "
   1463             "(Original error: %s)" % err)
   1464       else:
   1465         six.reraise(ValueError, err)
   1466 
   1467     # pylint: disable=protected-access
   1468     return iterator_ops.Iterator(
   1469         gen_dataset_ops.one_shot_iterator(
   1470             dataset_factory=_make_dataset, **flat_structure(self)),
   1471         None, get_legacy_output_types(self), get_legacy_output_shapes(self),
   1472         get_legacy_output_classes(self))
   1473 
   1474   @deprecation.deprecated(
   1475       None, "Use `for ... in dataset:` to iterate over a dataset. If using "
   1476       "`tf.estimator`, return the `Dataset` object directly from your input "
   1477       "function. As a last resort, you can use "
   1478       "`tf.compat.v1.data.make_initializable_iterator(dataset)`.")
   1479   def make_initializable_iterator(self, shared_name=None):
   1480     """Creates an `Iterator` for enumerating the elements of this dataset.
   1481 
   1482     Note: The returned iterator will be in an uninitialized state,
   1483     and you must run the `iterator.initializer` operation before using it:
   1484 
   1485     ```python
   1486     dataset = ...
   1487     iterator = dataset.make_initializable_iterator()
   1488     # ...
   1489     sess.run(iterator.initializer)
   1490     ```
   1491 
   1492     Args:
   1493       shared_name: (Optional.) If non-empty, the returned iterator will be
   1494         shared under the given name across multiple sessions that share the
   1495         same devices (e.g. when using a remote server).
   1496 
   1497     Returns:
   1498       An `Iterator` over the elements of this dataset.
   1499 
   1500     Raises:
   1501       RuntimeError: If eager execution is enabled.
   1502     """
   1503 
   1504     return self._make_initializable_iterator(shared_name)
   1505 
   1506   def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=missing-docstring
   1507     if context.executing_eagerly():
   1508       raise RuntimeError(
   1509           "dataset.make_initializable_iterator is not supported when eager "
   1510           "execution is enabled.")
   1511     _ensure_same_dataset_graph(self)
   1512     dataset = self._apply_options()
   1513     if shared_name is None:
   1514       shared_name = ""
   1515     if compat.forward_compatible(2018, 8, 3):
   1516       iterator_resource = gen_dataset_ops.iterator_v2(
   1517           container="", shared_name=shared_name, **flat_structure(self))
   1518     else:
   1519       iterator_resource = gen_dataset_ops.iterator(
   1520           container="", shared_name=shared_name, **flat_structure(self))
   1521     with ops.colocate_with(iterator_resource):
   1522       initializer = gen_dataset_ops.make_iterator(
   1523           dataset._variant_tensor,  # pylint: disable=protected-access
   1524           iterator_resource)
   1525     # pylint: disable=protected-access
   1526     return iterator_ops.Iterator(
   1527         iterator_resource, initializer, get_legacy_output_types(dataset),
   1528         get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset))
   1529 
   1530   @property
   1531   def output_classes(self):
   1532     """Returns the class of each component of an element of this dataset.
   1533 
   1534     The expected values are `tf.Tensor` and `tf.SparseTensor`.
   1535 
   1536     Returns:
   1537       A nested structure of Python `type` objects corresponding to each
   1538       component of an element of this dataset.
   1539     """
   1540     return self._element_structure._to_legacy_output_classes()  # pylint: disable=protected-access
   1541 
   1542   @property
   1543   def output_shapes(self):
   1544     """Returns the shape of each component of an element of this dataset.
   1545 
   1546     Returns:
   1547       A nested structure of `tf.TensorShape` objects corresponding to each
   1548       component of an element of this dataset.
   1549     """
   1550     return self._element_structure._to_legacy_output_shapes()  # pylint: disable=protected-access
   1551 
   1552   @property
   1553   def output_types(self):
   1554     """Returns the type of each component of an element of this dataset.
   1555 
   1556     Returns:
   1557       A nested structure of `tf.DType` objects corresponding to each component
   1558       of an element of this dataset.
   1559     """
   1560     return self._element_structure._to_legacy_output_types()  # pylint: disable=protected-access
   1561 
   1562   @property
   1563   def _element_structure(self):
   1564     # TODO(b/110122868): Remove this override once all `Dataset` instances
   1565     # implement `element_structure`.
   1566     return structure_lib.convert_legacy_structure(
   1567         self.output_types, self.output_shapes, self.output_classes)
   1568 
   1569   @staticmethod
   1570   @functools.wraps(DatasetV2.from_tensors)
   1571   def from_tensors(tensors):
   1572     return DatasetV1Adapter(DatasetV2.from_tensors(tensors))
   1573 
   1574   @staticmethod
   1575   @functools.wraps(DatasetV2.from_tensor_slices)
   1576   def from_tensor_slices(tensors):
   1577     return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors))
   1578 
   1579   @staticmethod
   1580   @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
   1581   def from_sparse_tensor_slices(sparse_tensor):
   1582     """Splits a rank-N `tf.SparseTensor` row-wise into a dataset of its rows.
   1583 
   1584     Args:
   1585       sparse_tensor: A `tf.SparseTensor`.
   1586 
   1587     Returns:
   1588       Dataset: A `Dataset` of rank-(N-1) sparse tensors.
   1589     """
   1590     return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor))
   1591 
   1592   @staticmethod
   1593   @functools.wraps(DatasetV2.from_generator)
   1594   def from_generator(generator, output_types, output_shapes=None, args=None):
   1595     return DatasetV1Adapter(DatasetV2.from_generator(
   1596         generator, output_types, output_shapes, args))
   1597 
   1598   @staticmethod
   1599   @functools.wraps(DatasetV2.range)
   1600   def range(*args):
   1601     return DatasetV1Adapter(DatasetV2.range(*args))
   1602 
   1603   @staticmethod
   1604   @functools.wraps(DatasetV2.zip)
   1605   def zip(datasets):
   1606     return DatasetV1Adapter(DatasetV2.zip(datasets))
   1607 
   1608   @functools.wraps(DatasetV2.concatenate)
   1609   def concatenate(self, dataset):
   1610     return DatasetV1Adapter(super(DatasetV1, self).concatenate(dataset))
   1611 
   1612   @functools.wraps(DatasetV2.prefetch)
   1613   def prefetch(self, buffer_size):
   1614     return DatasetV1Adapter(super(DatasetV1, self).prefetch(buffer_size))
   1615 
   1616   @staticmethod
   1617   @functools.wraps(DatasetV2.list_files)
   1618   def list_files(file_pattern, shuffle=None, seed=None):
   1619     return DatasetV1Adapter(DatasetV2.list_files(file_pattern, shuffle, seed))
   1620 
   1621   @functools.wraps(DatasetV2.repeat)
   1622   def repeat(self, count=None):
   1623     return DatasetV1Adapter(super(DatasetV1, self).repeat(count))
   1624 
   1625   @functools.wraps(DatasetV2.shuffle)
   1626   def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
   1627     return DatasetV1Adapter(super(DatasetV1, self).shuffle(
   1628         buffer_size, seed, reshuffle_each_iteration))
   1629 
   1630   @functools.wraps(DatasetV2.cache)
   1631   def cache(self, filename=""):
   1632     return DatasetV1Adapter(super(DatasetV1, self).cache(filename))
   1633 
   1634   @functools.wraps(DatasetV2.take)
   1635   def take(self, count):
   1636     return DatasetV1Adapter(super(DatasetV1, self).take(count))
   1637 
   1638   @functools.wraps(DatasetV2.skip)
   1639   def skip(self, count):
   1640     return DatasetV1Adapter(super(DatasetV1, self).skip(count))
   1641 
   1642   @functools.wraps(DatasetV2.shard)
   1643   def shard(self, num_shards, index):
   1644     return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index))
   1645 
   1646   @functools.wraps(DatasetV2.batch)
   1647   def batch(self, batch_size, drop_remainder=False):
   1648     return DatasetV1Adapter(super(DatasetV1, self).batch(
   1649         batch_size, drop_remainder))
   1650 
   1651   @functools.wraps(DatasetV2.padded_batch)
   1652   def padded_batch(self,
   1653                    batch_size,
   1654                    padded_shapes,
   1655                    padding_values=None,
   1656                    drop_remainder=False):
   1657     return DatasetV1Adapter(super(DatasetV1, self).padded_batch(
   1658         batch_size, padded_shapes, padding_values, drop_remainder))
   1659 
   1660   @functools.wraps(DatasetV2.map)
   1661   def map(self, map_func, num_parallel_calls=None):
   1662     if num_parallel_calls is None:
   1663       return DatasetV1Adapter(
   1664           MapDataset(self, map_func, preserve_cardinality=False))
   1665     else:
   1666       return DatasetV1Adapter(
   1667           ParallelMapDataset(
   1668               self, map_func, num_parallel_calls, preserve_cardinality=False))
   1669 
   1670   @deprecation.deprecated(None, "Use `tf.data.Dataset.map()`.")
   1671   def map_with_legacy_function(self, map_func, num_parallel_calls=None):
   1672     """Maps `map_func` across the elements of this dataset.
   1673 
   1674     NOTE: This is an escape hatch for existing uses of `map` that do not work
   1675     with V2 functions. New uses are strongly discouraged and existing uses
   1676     should migrate to `map` as this method will be removed in V2.
   1677 
   1678     Args:
   1679       map_func: A function mapping a nested structure of tensors (having shapes
   1680         and types defined by `self.output_shapes` and `self.output_types`) to
   1681         another nested structure of tensors.
   1682       num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
   1683         representing the number of elements to process asynchronously in parallel.
   1684         If not specified, elements will be processed sequentially. If the value
   1685         `tf.data.experimental.AUTOTUNE` is used, then the number of parallel
   1686         calls is set dynamically based on available CPU.
   1687 
   1688     Returns:
   1689       Dataset: A `Dataset`.
   1690     """
   1691     if num_parallel_calls is None:
   1692       return DatasetV1Adapter(
   1693           MapDataset(
   1694               self,
   1695               map_func,
   1696               preserve_cardinality=False,
   1697               use_legacy_function=True))
   1698     else:
   1699       return DatasetV1Adapter(
   1700           ParallelMapDataset(
   1701               self,
   1702               map_func,
   1703               num_parallel_calls,
   1704               preserve_cardinality=False,
   1705               use_legacy_function=True))
   1706 
   1707   @functools.wraps(DatasetV2.flat_map)
   1708   def flat_map(self, map_func):
   1709     return DatasetV1Adapter(super(DatasetV1, self).flat_map(map_func))
   1710 
   1711   @functools.wraps(DatasetV2.interleave)
   1712   def interleave(self,
   1713                  map_func,
   1714                  cycle_length,
   1715                  block_length=1,
   1716                  num_parallel_calls=None):
   1717     return DatasetV1Adapter(super(DatasetV1, self).interleave(
   1718         map_func, cycle_length, block_length, num_parallel_calls))
   1719 
   1720   @functools.wraps(DatasetV2.filter)
   1721   def filter(self, predicate):
   1722     return DatasetV1Adapter(super(DatasetV1, self).filter(predicate))
   1723 
   1724   @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()`.")
   1725   def filter_with_legacy_function(self, predicate):
   1726     """Filters this dataset according to `predicate`.
   1727 
   1728     NOTE: This is an escape hatch for existing uses of `filter` that do not work
   1729     with V2 functions. New uses are strongly discouraged and existing uses
   1730     should migrate to `filter` as this method will be removed in V2.
   1731 
   1732     Args:
   1733       predicate: A function mapping a nested structure of tensors (having shapes
   1734         and types defined by `self.output_shapes` and `self.output_types`) to a
   1735         scalar `tf.bool` tensor.
   1736 
   1737     Returns:
   1738       Dataset: The `Dataset` containing the elements of this dataset for which
   1739           `predicate` is `True`.
   1740     """
   1741     return FilterDataset(self, predicate, use_legacy_function=True)
   1742 
   1743   @functools.wraps(DatasetV2.apply)
   1744   def apply(self, transformation_func):
   1745     return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func))
   1746 
   1747   @functools.wraps(DatasetV2.window)
   1748   def window(self, size, shift=None, stride=1, drop_remainder=False):
   1749     return DatasetV1Adapter(super(DatasetV1, self).window(
   1750         size, shift, stride, drop_remainder))
   1751 
   1752   @functools.wraps(DatasetV2.with_options)
   1753   def with_options(self, options):
   1754     return DatasetV1Adapter(super(DatasetV1, self).with_options(options))
   1755 
   1756 
   1757 # TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
   1758 # this alias in place.
   1759 Dataset = DatasetV1
   1760 
   1761 
   1762 class DatasetV1Adapter(DatasetV1):
   1763   """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API."""
   1764 
   1765   def __init__(self, dataset):
   1766     self._dataset = dataset
   1767     super(DatasetV1Adapter, self).__init__()
   1768 
   1769   def _as_variant_tensor(self):
   1770     return self._dataset._variant_tensor  # pylint: disable=protected-access
   1771 
   1772   def _has_captured_ref(self):
   1773     return self._dataset._has_captured_ref()  # pylint: disable=protected-access
   1774 
   1775   def _inputs(self):
   1776     return self._dataset._inputs()  # pylint: disable=protected-access
   1777 
   1778   def options(self):
   1779     return self._dataset.options()
   1780 
   1781   @property
   1782   def _element_structure(self):
   1783     return self._dataset._element_structure  # pylint: disable=protected-access
   1784 
   1785   def __iter__(self):
   1786     return iter(self._dataset)
   1787 
   1788 
   1789 def _ensure_same_dataset_graph(dataset):
   1790   """Walks the dataset graph to ensure all datasets come from the same graph."""
   1791   current_graph = ops.get_default_graph()
   1792   bfs_q = Queue.Queue()
   1793   bfs_q.put(dataset)  # pylint: disable=protected-access
   1794   visited = []
   1795   while not bfs_q.empty():
   1796     ds = bfs_q.get()
   1797     visited.append(ds)
   1798     ds_graph = ds._graph  # pylint: disable=protected-access
   1799     if current_graph != ds_graph:
   1800       logging.warning("The graph (" + str(current_graph) + ") of the iterator "
   1801                       "is different from the graph (" + str(ds_graph) + ") in "
   1802                       "which the dataset " + str(ds._variant_tensor) + " was "  # pylint: disable=protected-access
   1803                       "created. If you are using the Estimator API, "
   1804                       "make sure that no part of the dataset returned by the "
   1805                       "`input_fn` function is defined outside the `input_fn` "
   1806                       "function. Please ensure that all datasets in the "
   1807                       "pipeline are created in the same graph as the iterator. "
   1808                       "NOTE: This warning will become an error in future "
   1809                       "versions of TensorFlow.")
   1810     for input_ds in ds._inputs():  # pylint: disable=protected-access
   1811       if input_ds not in visited:
   1812         bfs_q.put(input_ds)
   1813 
   1814 
   1815 @tf_export(v1=["data.make_one_shot_iterator"])
   1816 def make_one_shot_iterator(dataset):
   1817   """Creates a `tf.data.Iterator` for enumerating the elements of a dataset.
   1818 
   1819   Note: The returned iterator will be initialized automatically.
   1820   A "one-shot" iterator does not support re-initialization.
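
  For example, a minimal graph-mode sketch using the V1 endpoint:

  ```python
  dataset = tf.data.Dataset.range(2)
  iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
  next_element = iterator.get_next()
  ```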
   1821 
   1822   Args:
   1823     dataset: A `tf.data.Dataset`.
   1824 
   1825   Returns:
   1826     A `tf.data.Iterator` over the elements of this dataset.
   1827   """
   1828   try:
   1829     # Call the defined `_make_one_shot_iterator()` if there is one, because some
   1830     # datasets (e.g. for prefetching) override its behavior.
   1831     return dataset._make_one_shot_iterator()  # pylint: disable=protected-access
   1832   except AttributeError:
   1833     return DatasetV1Adapter(dataset)._make_one_shot_iterator()  # pylint: disable=protected-access
   1834 
   1835 
   1836 @tf_export(v1=["data.make_initializable_iterator"])
   1837 def make_initializable_iterator(dataset, shared_name=None):
   1838   """Creates a `tf.data.Iterator` for enumerating the elements of a dataset.
   1839 
   1840   Note: The returned iterator will be in an uninitialized state,
   1841   and you must run the `iterator.initializer` operation before using it:
   1842 
   1843   ```python
   1844   dataset = ...
   1845   iterator = tf.data.make_initializable_iterator(dataset)
   1846   # ...
   1847   sess.run(iterator.initializer)
   1848   ```
   1849 
   1850   Args:
   1851     dataset: A `tf.data.Dataset`.
   1852     shared_name: (Optional.) If non-empty, the returned iterator will be
   1853       shared under the given name across multiple sessions that share the
   1854       same devices (e.g. when using a remote server).
   1855 
   1856   Returns:
   1857     A `tf.data.Iterator` over the elements of `dataset`.
   1858 
   1859   Raises:
   1860     RuntimeError: If eager execution is enabled.
   1861   """
   1862   try:
   1863     # Call the defined `_make_initializable_iterator()` if there is one, because
   1864     # some datasets (e.g. for prefetching) override its behavior.
   1865     return dataset._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
   1866   except AttributeError:
   1867     return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
   1868 
   1869 
   1870 # TODO(b/110122868): Replace this method with a public API for reflecting on
   1871 # dataset structure.
   1872 def get_structure(dataset_or_iterator):
   1873   """Returns the `tf.data.experimental.Structure` of a `Dataset` or `Iterator`.
   1874 
   1875   Args:
   1876     dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
   1877       `EagerIterator`.
   1878 
   1879   Returns:
   1880     A `tf.data.experimental.Structure` representing the structure of the
   1881     elements of `dataset_or_iterator`.
   1882 
   1883   Raises:
   1884     TypeError: If `dataset_or_iterator` is not a dataset or iterator object.
   1885   """
   1886   try:
   1887     ret = dataset_or_iterator._element_structure  # pylint: disable=protected-access
   1888     if isinstance(ret, structure_lib.Structure):
   1889       return ret
   1890   except AttributeError:
   1891     pass
   1892   raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator object, "
   1893                   "but got %s." % type(dataset_or_iterator))
   1894 
   1895 
   1896 # TODO(b/110122868): Remove all uses of this method.
   1897 def get_legacy_output_shapes(dataset_or_iterator):
   1898   """Returns the output shapes of a `Dataset` or `Iterator`.
   1899 
   1900   This utility method replaces the deprecated-in-V2
   1901   `tf.compat.v1.Dataset.output_shapes` property.
   1902 
   1903   Args:
   1904     dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
   1905       `EagerIterator`.
   1906 
   1907   Returns:
   1908     A nested structure of `tf.TensorShape` objects corresponding to each
   1909     component of an element of the given dataset or iterator.
   1910   """
   1911   return get_structure(dataset_or_iterator)._to_legacy_output_shapes()  # pylint: disable=protected-access
   1912 
   1913 
   1914 # TODO(b/110122868): Remove all uses of this method.
   1915 def get_legacy_output_types(dataset_or_iterator):
   1916   """Returns the output types of a `Dataset` or `Iterator`.
   1917 
   1918   This utility method replaces the deprecated-in-V2
   1919   `tf.compat.v1.Dataset.output_types` property.
   1920 
   1921   Args:
   1922     dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
   1923       `EagerIterator`.
   1924 
   1925   Returns:
   1926     A nested structure of `tf.DType` objects corresponding to each component
   1927     of an element of this dataset.
   1928   """
   1929   return get_structure(dataset_or_iterator)._to_legacy_output_types()  # pylint: disable=protected-access
   1930 
   1931 
   1932 # TODO(b/110122868): Remove all uses of this method.
   1933 def get_legacy_output_classes(dataset_or_iterator):
   1934   """Returns the output classes of a `Dataset` or `Iterator`.
   1935 
   1936   This utility method replaces the deprecated-in-V2
   1937   `tf.compat.v1.Dataset.output_classes` property.
   1938 
   1939   Args:
   1940     dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
   1941       `EagerIterator`.
   1942 
   1943   Returns:
   1944     A nested structure of Python `type` or `tf.data.experimental.Structure`
   1945     objects corresponding to each component of an element of this dataset.
   1946   """
   1947   return get_structure(dataset_or_iterator)._to_legacy_output_classes()  # pylint: disable=protected-access
   1948 
   1949 
   1950 @tf_export("data.Options")
   1951 class Options(options_lib.OptionsBase):
   1952   """Represents options for tf.data.Dataset.
   1953 
   1954   An `Options` object can be used, for instance, to control which static
   1955   optimizations to apply or whether to use performance modeling to dynamically
   1956   tune the parallelism of operations such as `tf.data.Dataset.map` or
   1957   `tf.data.Dataset.interleave`.
   1958   """
   1959 
   1960   experimental_deterministic = options_lib.create_option(
   1961       name="experimental_deterministic",
   1962       ty=bool,
   1963       docstring=
   1964       "Whether the outputs need to be produced in deterministic order. If None,"
   1965       " defaults to True.")
   1966 
   1967   experimental_numa_aware = options_lib.create_option(
   1968       name="experimental_numa_aware",
   1969       ty=bool,
   1970       docstring=
   1971       "Whether to use NUMA-aware operations. If None, defaults to False.")
   1972 
   1973   experimental_optimization = options_lib.create_option(
   1974       name="experimental_optimization",
   1975       ty=optimization_options.OptimizationOptions,
   1976       docstring=
   1977       "The optimization options associated with the dataset. See "
   1978       "`tf.data.experimental.OptimizationOptions` for more details.",
   1979       default_factory=optimization_options.OptimizationOptions)
   1980 
   1981   experimental_stats = options_lib.create_option(
   1982       name="experimental_stats",
   1983       ty=stats_options.StatsOptions,
   1984       docstring=
   1985       "The statistics options associated with the dataset. See "
   1986       "`tf.data.experimental.StatsOptions` for more details.",
   1987       default_factory=stats_options.StatsOptions)
   1988 
   1989   experimental_threading = options_lib.create_option(
   1990       name="experimental_threading",
   1991       ty=threading_options.ThreadingOptions,
   1992       docstring=
   1993       "The threading options associated with the dataset. See "
   1994       "`tf.data.experimental.ThreadingOptions` for more details.",
   1995       default_factory=threading_options.ThreadingOptions)
   1996 
   1997   def _static_optimizations(self):
   1998     """Produces the list of enabled static optimizations."""
   1999 
   2000     result = []
   2001     result.extend(self.experimental_optimization._static_optimizations())  # pylint: disable=protected-access
   2002 
   2003     if self.experimental_numa_aware:
   2004       result.append("make_numa_aware")
   2005     if self.experimental_deterministic is False:
   2006       result.append("make_sloppy")
   2007     exp_stats_options = self.experimental_stats
   2008     if exp_stats_options and exp_stats_options.latency_all_edges:
   2009       result.append("latency_all_edges")
   2010     return result
   2011 
   2012   def merge(self, options):
   2013     """Merges itself with the given `tf.data.Options`.
   2014 
   2015     The given `tf.data.Options` can be merged as long as no attribute is set
   2016     to different values in `self` and `options`.
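
    For example, a minimal sketch merging two compatible `Options` objects
    (both attributes below are defined on this class):

    ```python
    options1 = tf.data.Options()
    options1.experimental_deterministic = False
    options2 = tf.data.Options()
    options2.experimental_numa_aware = True
    merged = options1.merge(options2)
    # `merged` carries both non-default settings.
    ```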
   2017 
   2018     Args:
   2019       options: A `tf.data.Options` to merge with.
   2020 
   2021     Raises:
   2022       ValueError: if the given `tf.data.Options` cannot be merged
   2023 
   2024     Returns:
   2025       New `tf.data.Options()` object which is the result of merging self with
   2026       the input `tf.data.Options`.
   2027     """
   2028     return options_lib.merge_options(self, options)
   2029 
   2030 
   2031 class DatasetSource(DatasetV2):
   2032   """Abstract class representing a dataset with no inputs."""
   2033 
   2034   def _inputs(self):
   2035     return []
   2036 
   2037 
   2038 class UnaryDataset(DatasetV2):
   2039   """Abstract class representing a dataset with one input."""
   2040 
   2041   def __init__(self, input_dataset, variant_tensor):
   2042     self._input_dataset = input_dataset
   2043     super(UnaryDataset, self).__init__(variant_tensor)
   2044 
   2045   def _inputs(self):
   2046     return [self._input_dataset]
   2047 
   2048 
   2049 class UnaryUnchangedStructureDataset(UnaryDataset):
   2050   """Represents a unary dataset with the same input and output structure."""
   2051 
   2052   def __init__(self, input_dataset, variant_tensor):
   2053     self._input_dataset = input_dataset
   2054     super(UnaryUnchangedStructureDataset, self).__init__(
   2055         input_dataset, variant_tensor)
   2056 
   2057   @property
   2058   def _element_structure(self):
   2059     return self._input_dataset._element_structure  # pylint: disable=protected-access
   2060 
   2061 
   2062 class TensorDataset(DatasetSource):
   2063   """A `Dataset` with a single element, viz. a nested structure of tensors."""
   2064 
   2065   def __init__(self, tensors):
   2066     """See `Dataset.from_tensors()` for details."""
   2067     with ops.name_scope("tensors"):
   2068       tensors = nest.pack_sequence_as(tensors, [
   2069           sparse_tensor_lib.SparseTensor.from_value(t)
   2070           if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
   2071               t, name="component_%d" % i)
   2072           for i, t in enumerate(nest.flatten(tensors))
   2073       ])
   2074     self._structure = structure_lib.Structure.from_value(tensors)
   2075     self._tensors = self._structure._to_tensor_list(tensors)  # pylint: disable=protected-access
   2076 
   2077     variant_tensor = gen_dataset_ops.tensor_dataset(
   2078         self._tensors, output_shapes=self._structure._flat_shapes)  # pylint: disable=protected-access
   2079     super(TensorDataset, self).__init__(variant_tensor)
   2080 
   2081   @property
   2082   def _element_structure(self):
   2083     return self._structure
   2084 
   2085 
   2086 class TensorSliceDataset(DatasetSource):
   2087   """A `Dataset` of slices from a nested structure of tensors."""
   2088 
   2089   def __init__(self, tensors):
   2090     """See `Dataset.from_tensor_slices()` for details."""
   2091     with ops.name_scope("tensors"):
   2092       tensors = nest.pack_sequence_as(tensors, [
   2093           sparse_tensor_lib.SparseTensor.from_value(t)
   2094           if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
   2095               t, name="component_%d" % i)
   2096           for i, t in enumerate(nest.flatten(tensors))
   2097       ])
   2098 
   2099     batched_structure = structure_lib.Structure.from_value(tensors)
   2100     # pylint: disable=protected-access
   2101     self._tensors = batched_structure._to_batched_tensor_list(tensors)
   2102     self._structure = batched_structure._unbatch()
   2103     # pylint: enable=protected-access
   2104 
   2105     batch_dim = tensor_shape.Dimension(tensor_shape.dimension_value(
   2106         self._tensors[0].get_shape()[0]))
   2107     for t in self._tensors[1:]:
   2108       batch_dim.assert_is_compatible_with(tensor_shape.Dimension(
   2109           tensor_shape.dimension_value(t.get_shape()[0])))
   2110 
   2111     variant_tensor = gen_dataset_ops.tensor_slice_dataset(
   2112         self._tensors, output_shapes=self._structure._flat_shapes)  # pylint: disable=protected-access
   2113     super(TensorSliceDataset, self).__init__(variant_tensor)
   2114 
   2115   @property
   2116   def _element_structure(self):
   2117     return self._structure
   2118 
   2119 
   2120 class SparseTensorSliceDataset(DatasetSource):
   2121   """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""
   2122 
   2123   def __init__(self, sparse_tensor):
   2124     """See `Dataset.from_sparse_tensor_slices()` for details."""
   2125     if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
   2126       raise TypeError(
   2127           "`sparse_tensor` must be a `tf.SparseTensor` object. Was {}.".format(
   2128               sparse_tensor))
   2129     self._sparse_tensor = sparse_tensor
   2130 
   2131     indices_shape = self._sparse_tensor.indices.get_shape()
   2132     shape_shape = self._sparse_tensor.dense_shape.get_shape()
   2133     rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1)
   2134     self._structure = structure_lib.NestedStructure(
   2135         (structure_lib.TensorStructure(dtypes.int64, [None, rank]),
   2136          structure_lib.TensorStructure(self._sparse_tensor.dtype, [None]),
   2137          structure_lib.TensorStructure(dtypes.int64, [rank])))
   2138 
   2139     variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset(
   2140         self._sparse_tensor.indices, self._sparse_tensor.values,
   2141         self._sparse_tensor.dense_shape)
   2142     super(SparseTensorSliceDataset, self).__init__(variant_tensor)
   2143 
   2144   @property
   2145   def _element_structure(self):
   2146     return self._structure
   2147 
   2148 
   2149 class _VariantDataset(DatasetV2):
   2150   """A Dataset wrapper around a `tf.variant`-typed function argument."""
   2151 
   2152   def __init__(self, dataset_variant, structure):
   2153     self._structure = structure
   2154     super(_VariantDataset, self).__init__(dataset_variant)
   2155 
   2156   def _inputs(self):
   2157     return []
   2158 
   2159   @property
   2160   def _element_structure(self):
   2161     return self._structure
   2162 
   2163 
   2164 @tf_export("data.experimental.DatasetStructure")
   2165 class DatasetStructure(structure_lib.Structure):
   2166   """Represents a `Dataset` of structured values."""
   2167 
   2168   def __init__(self, element_structure):
   2169     self._element_structure = element_structure
   2170 
   2171   @property
   2172   def _flat_shapes(self):
   2173     return [tensor_shape.scalar()]
   2174 
   2175   @property
   2176   def _flat_types(self):
   2177     return [dtypes.variant]
   2178 
   2179   def is_compatible_with(self, other):
   2180     # pylint: disable=protected-access
   2181     return (isinstance(other, DatasetStructure) and
   2182             self._element_structure.is_compatible_with(
   2183                 other._element_structure))
   2184 
   2185   def _to_tensor_list(self, value):
   2186     return [value._variant_tensor]  # pylint: disable=protected-access
   2187 
   2188   def _to_batched_tensor_list(self, value):
   2189     raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")
   2190 
   2191   def _from_tensor_list(self, flat_value):
   2192     if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
   2193         not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
   2194       raise ValueError(
   2195           "DatasetStructure corresponds to a single tf.variant scalar.")
   2196     return self._from_compatible_tensor_list(flat_value)
   2197 
   2198   def _from_compatible_tensor_list(self, flat_value):
   2199     # pylint: disable=protected-access
   2200     return _VariantDataset(flat_value[0], self._element_structure)
   2201 
   2202   @staticmethod
   2203   def from_value(value):
   2204     return DatasetStructure(value._element_structure)  # pylint: disable=protected-access
   2205 
   2206   def _to_legacy_output_types(self):
   2207     return self
   2208 
   2209   def _to_legacy_output_shapes(self):
   2210     return self
   2211 
   2212   def _to_legacy_output_classes(self):
   2213     return self
   2214 
   2215   def _batch(self, batch_size):
   2216     raise NotImplementedError("Batching for `tf.data.Dataset` objects.")
   2217 
   2218   def _unbatch(self):
   2219     raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")
   2220 
   2221 
   2222 # pylint: disable=protected-access
   2223 structure_lib.Structure._register_custom_converter(DatasetV2,
   2224                                                    DatasetStructure.from_value)
   2225 # pylint: enable=protected-access
   2226 
   2227 
   2228 class StructuredFunctionWrapper(object):
   2229   """A function wrapper that supports structured arguments and return values."""
   2230 
   2231   # pylint: disable=protected-access
   2232   def __init__(self,
   2233                func,
   2234                transformation_name,
   2235                dataset=None,
   2236                input_classes=None,
   2237                input_shapes=None,
   2238                input_types=None,
   2239                input_structure=None,
   2240                add_to_graph=True,
   2241                use_legacy_function=False,
   2242                defun_kwargs=None):
   2243     """Creates a new `StructuredFunctionWrapper` for the given function.
   2244 
   2245     Args:
   2246       func: A function from a nested structure to another nested structure.
   2247       transformation_name: Human-readable name of the transformation in which
   2248         this function is being instantiated, for error messages.
   2249       dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
   2250         dataset will be assumed as the structure for `func` arguments; otherwise
   2251         `input_classes`, `input_shapes`, and `input_types` must be defined.
   2252       input_classes: (Optional.) A nested structure of `type`. If given, this
   2253         argument defines the Python types for `func` arguments.
   2254       input_shapes: (Optional.) A nested structure of `tf.TensorShape`. If
   2255         given, this argument defines the shapes and structure for `func`
   2256         arguments.
   2257       input_types: (Optional.) A nested structure of `tf.DType`. If given, this
   2258         argument defines the element types and structure for `func` arguments.
   2259       input_structure: (Optional.) A `Structure` object. If given, this argument
   2260         defines the element types and structure for `func` arguments.
   2261       add_to_graph: (Optional.) If `True`, the function will be added to the
   2262         default graph.
   2263       use_legacy_function: (Optional.) A boolean that determines whether the
   2264         function is created using `tensorflow.python.eager.function.defun`
   2265         (default behavior) or `tensorflow.python.framework.function.Defun`
   2266         (legacy behavior).
   2267       defun_kwargs: (Optional.) A dictionary mapping string argument names to
   2268         values. If supplied, will be passed to `function` as keyword arguments.
   2269 
   2270     Raises:
   2271       ValueError: If an invalid combination of `dataset`, `input_classes`,
   2272         `input_shapes`, and `input_types` is passed.
   2273     """
   2274     if input_structure is None:
   2275       if dataset is None:
   2276         if input_classes is None or input_shapes is None or input_types is None:
   2277           raise ValueError("Either `dataset`, `input_structure` or all of "
   2278                            "`input_classes`, `input_shapes`, and `input_types` "
   2279                            "must be specified.")
   2280         self._input_structure = structure_lib.convert_legacy_structure(
   2281             input_types, input_shapes, input_classes)
   2282       else:
   2283         if not (input_classes is None and input_shapes is None and
   2284                 input_types is None):
   2285           raise ValueError("Either `dataset`, `input_structure` or all of "
   2286                            "`input_classes`, `input_shapes`, and `input_types` "
   2287                            "must be specified.")
   2288         self._input_structure = dataset._element_structure
   2289     else:
   2290       if not (dataset is None and input_classes is None and input_shapes is None
   2291               and input_types is None):
   2292         raise ValueError("Either `dataset`, `input_structure`, or all of "
   2293                          "`input_classes`, `input_shapes`, and `input_types` "
   2294                          "must be specified.")
   2295       self._input_structure = input_structure
   2296 
   2297     if defun_kwargs is None:
   2298       defun_kwargs = {}
   2299 
   2300     readable_transformation_name = transformation_name.replace(
   2301         ".", "_")[:-2] if len(transformation_name) > 2 else ""
   2302 
   2303     func_name = "_".join(
   2304         [readable_transformation_name,
   2305          function_utils.get_func_name(func)])
   2306 
   2307     def _warn_if_collections(transformation_name):
   2308       """Prints a warning if the given graph uses common graph collections.
   2309 
   2310       NOTE(mrry): Currently a warning is only generated for resources. Any
   2311       variables created will be automatically hoisted out to the outermost scope
   2312       using `init_scope()`. Some collections (such as for control-flow contexts)
   2313       are benign and should not generate a warning.
   2314 
   2315       Args:
   2316         transformation_name: A human-readable name for the transformation.
   2317       """
   2318       warnings.warn("Creating resources inside a function passed to %s "
   2319                     "is not supported. Create each resource outside the "
   2320                     "function, and capture it inside the function to use it." %
   2321                     transformation_name, stacklevel=5)
   2322 
   2323     def _wrapper_helper(*args):
   2324       """Wrapper for passing nested structures to and from tf.data functions."""
   2325       nested_args = self._input_structure._from_compatible_tensor_list(args)
   2326       if not _should_unpack_args(nested_args):
   2327         nested_args = (nested_args,)
   2328 
   2329       ret = func(*nested_args)
   2330       # If `func` returns a list of tensors, `nest.flatten()` and
   2331       # `ops.convert_to_tensor()` would conspire to attempt to stack
   2332       # those tensors into a single tensor, because the customized
   2333       # version of `nest.flatten()` does not recurse into lists. Since
   2334       # it is more likely that the list arose from returning the
   2335       # result of an operation (such as `tf.py_func()`) that returns a
   2336       # list of not-necessarily-stackable tensors, we treat the
   2337       # returned value is a `tuple` instead. A user wishing to pack
   2338       # returned value as a `tuple` instead. A user wishing to pack
   2339       # `tf.stack()` before returning.
   2340       if isinstance(ret, list):
   2341         ret = tuple(ret)
   2342 
   2343       try:
   2344         self._output_structure = structure_lib.Structure.from_value(ret)
   2345       except (ValueError, TypeError):
   2346         raise TypeError("Unsupported return value from function passed to "
   2347                         "%s: %s." % (transformation_name, ret))
   2348       return ret
   2349 
   2350     if use_legacy_function:
   2351       func_name = func_name + "_" + str(ops.uid())
   2352 
   2353       @function.Defun(
   2354           *self._input_structure._flat_types,
   2355           func_name=func_name,
   2356           **defun_kwargs)
   2357       def wrapper_fn(*args):
   2358         ret = _wrapper_helper(*args)
   2359         # _warn_if_collections(transformation_name, ops.get_default_graph(), 0)
   2360         return self._output_structure._to_tensor_list(ret)
   2361 
   2362       self._function = wrapper_fn
   2363       resource_tracker = tracking.ResourceTracker()
   2364       with tracking.resource_tracker_scope(resource_tracker):
   2365         if add_to_graph:
   2366           self._function.add_to_graph(ops.get_default_graph())
   2367         else:
   2368           # Use the private method that will execute `wrapper_fn` but delay
   2369           # adding it to the graph in case (e.g.) we need to rerun the function.
   2370           self._function._create_definition_if_needed()
   2371       if resource_tracker.resources:
   2372         _warn_if_collections(transformation_name)
   2373 
   2374     else:
   2375       defun_kwargs.update({"func_name": func_name})
   2376 
   2377       # TODO(b/124254153): Enable autograph once the overhead is low enough.
   2378       # TODO(mdan): Make sure autograph recurses into _wrapper_helper when on.
   2379       @eager_function.defun_with_attributes(
   2380           input_signature=[
   2381               tensor_spec.TensorSpec(input_shape, input_type)  # pylint: disable=g-complex-comprehension
   2382               for input_shape, input_type in zip(
   2383                   self._input_structure._flat_shapes,
   2384                   self._input_structure._flat_types)
   2385           ],
   2386           autograph=False,
   2387           attributes=defun_kwargs)
   2388       def wrapper_fn(*args):  # pylint: disable=missing-docstring
   2389         ret = _wrapper_helper(*args)
   2390         ret = self._output_structure._to_tensor_list(ret)
   2391         return [ops.convert_to_tensor(t) for t in ret]
   2392 
   2393       resource_tracker = tracking.ResourceTracker()
   2394       with tracking.resource_tracker_scope(resource_tracker):
   2395         self._function = wrapper_fn._get_concrete_function_internal()
   2396         if add_to_graph:
   2397           self._function.add_to_graph(ops.get_default_graph())
   2398       if resource_tracker.resources:
   2399         _warn_if_collections(transformation_name)
   2400 
   2401       outer_graph_seed = ops.get_default_graph().seed
   2402       if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
   2403         if self._function.graph._seed_used:
    2404           warnings.warn(
    2405               "Seed %s from the outer graph might be used by function %s if a "
    2406               "random op inside it has not been given an explicit seed. Set the "
    2407               "seed in the function if this is not the intended behavior."
    2408               % (outer_graph_seed, func_name), stacklevel=4)
   2409   # pylint: enable=protected-access
   2410 
   2411   @property
   2412   def output_structure(self):
   2413     return self._output_structure
   2414 
   2415   @property
   2416   def output_classes(self):
   2417     return self._output_structure._to_legacy_output_classes()  # pylint: disable=protected-access
   2418 
   2419   @property
   2420   def output_shapes(self):
   2421     return self._output_structure._to_legacy_output_shapes()  # pylint: disable=protected-access
   2422 
   2423   @property
   2424   def output_types(self):
   2425     return self._output_structure._to_legacy_output_types()  # pylint: disable=protected-access
   2426 
   2427   @property
   2428   def function(self):
   2429     return self._function
   2430 
   2431 
   2432 def flat_structure(dataset):
   2433   """Helper for setting `output_shapes` and `output_types` attrs of Dataset ops.
   2434 
   2435   Most Dataset op constructors expect `output_shapes` and `output_types`
   2436   arguments that represent the flattened structure of an element. This helper
   2437   function generates these attrs as a keyword argument dictionary, allowing
   2438   `Dataset._variant_tensor` implementations to pass
   2439   `**flat_structure(self)` to the op constructor.
   2440 
   2441   Args:
   2442     dataset: A `tf.data.Dataset`.
   2443 
   2444   Returns:
   2445     A dictionary of keyword arguments that can be passed to many Dataset op
   2446     constructors.
   2447   """
   2448   # pylint: disable=protected-access
   2449   structure = dataset._element_structure
   2450   return {
   2451       "output_shapes": structure._flat_shapes,
   2452       "output_types": structure._flat_types,
   2453   }
   2454 
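         # Illustrative usage note (a sketch, not an API guarantee): a dataset op
         # wrapper typically splices these attrs into the generated op call, as
         # `TakeDataset` below does:
         #
         #   variant_tensor = gen_dataset_ops.take_dataset(
         #       input_dataset._variant_tensor,  # pylint: disable=protected-access
         #       count=self._count,
         #       **flat_structure(self))
         #
         # which expands to the `output_shapes=...` and `output_types=...` keyword
         # attrs expected by the op constructor.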
   2455 
   2456 class _GeneratorDataset(DatasetSource):
   2457   """A `Dataset` that generates elements by invoking a function."""
   2458 
   2459   def __init__(self, init_args, init_func, next_func, finalize_func):
   2460     """Constructs a `_GeneratorDataset`.
   2461 
   2462     Args:
   2463       init_args: A nested structure representing the arguments to `init_func`.
   2464       init_func: A TensorFlow function that will be called on `init_args` each
   2465         time a C++ iterator over this dataset is constructed. Returns a nested
   2466         structure representing the "state" of the dataset.
   2467       next_func: A TensorFlow function that will be called on the result of
   2468         `init_func` to produce each element, and that raises `OutOfRangeError`
   2469         to terminate iteration.
   2470       finalize_func: A TensorFlow function that will be called on the result of
   2471         `init_func` immediately before a C++ iterator over this dataset is
   2472         destroyed. The return value is ignored.
   2473     """
   2474     self._init_args = init_args
   2475 
   2476     self._init_structure = structure_lib.Structure.from_value(init_args)
   2477 
   2478     self._init_func = StructuredFunctionWrapper(
   2479         init_func,
   2480         self._transformation_name(),
   2481         input_structure=self._init_structure)
   2482 
   2483     self._next_func = StructuredFunctionWrapper(
   2484         next_func,
   2485         self._transformation_name(),
   2486         input_structure=self._init_func.output_structure)
   2487 
   2488     self._finalize_func = StructuredFunctionWrapper(
   2489         finalize_func,
   2490         self._transformation_name(),
   2491         input_structure=self._init_func.output_structure)
   2492     variant_tensor = gen_dataset_ops.generator_dataset(
   2493         self._init_structure._to_tensor_list(self._init_args)  # pylint: disable=protected-access
   2494         + self._init_func.function.captured_inputs,
   2495         self._next_func.function.captured_inputs,
   2496         self._finalize_func.function.captured_inputs,
   2497         init_func=self._init_func.function,
   2498         next_func=self._next_func.function,
   2499         finalize_func=self._finalize_func.function,
   2500         **flat_structure(self))
   2501     super(_GeneratorDataset, self).__init__(variant_tensor)
   2502 
   2503   @property
   2504   def _element_structure(self):
   2505     return self._next_func.output_structure
   2506 
   2507   def _transformation_name(self):
   2508     return "Dataset.from_generator()"
   2509 
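         # Rough sketch of the iterator lifecycle wired up above (illustrative
         # pseudo-code only; the real control flow lives in the C++ kernel):
         #
         #   state = init_func(init_args)   # on C++ iterator construction
         #   try:
         #     while True:
         #       yield next_func(state)     # until `OutOfRangeError` is raised
         #   finally:
         #     finalize_func(state)         # just before the iterator is destroyed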
   2510 
   2511 class ZipDataset(DatasetV2):
   2512   """A `Dataset` that zips its inputs together."""
   2513 
   2514   def __init__(self, datasets):
   2515     """See `Dataset.zip()` for details."""
   2516     for ds in nest.flatten(datasets):
   2517       if not isinstance(ds, DatasetV2):
   2518         if isinstance(ds, list):
   2519           message = ("The argument to `Dataset.zip()` must be a nested "
   2520                      "structure of `Dataset` objects. Nested structures do not "
   2521                      "support Python lists; please use a tuple instead.")
   2522         else:
   2523           message = ("The argument to `Dataset.zip()` must be a nested "
   2524                      "structure of `Dataset` objects.")
   2525         raise TypeError(message)
   2526     self._datasets = datasets
   2527     self._structure = structure_lib.NestedStructure(
   2528         nest.pack_sequence_as(
   2529             self._datasets,
   2530             [ds._element_structure for ds in nest.flatten(self._datasets)]))  # pylint: disable=protected-access
   2531 
   2532     # pylint: disable=protected-access
   2533     variant_tensor = gen_dataset_ops.zip_dataset(
   2534         [ds._variant_tensor for ds in nest.flatten(self._datasets)],
   2535         **flat_structure(self))
   2536     # pylint: enable=protected-access
   2537     super(ZipDataset, self).__init__(variant_tensor)
   2538 
   2539   def _inputs(self):
   2540     return nest.flatten(self._datasets)
   2541 
   2542   @property
   2543   def _element_structure(self):
   2544     return self._structure
   2545 
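         # Example of the nesting rule enforced above (public-API sketch):
         #
         #   a = tf.data.Dataset.range(3)
         #   b = tf.data.Dataset.range(3, 6)
         #   tf.data.Dataset.zip((a, b))   # OK: a tuple of datasets
         #   tf.data.Dataset.zip([a, b])   # TypeError: lists are not supported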
   2546 
   2547 class ConcatenateDataset(DatasetV2):
   2548   """A `Dataset` that concatenates its input with given dataset."""
   2549 
   2550   def __init__(self, input_dataset, dataset_to_concatenate):
   2551     """See `Dataset.concatenate()` for details."""
   2552     self._input_dataset = input_dataset
   2553     self._dataset_to_concatenate = dataset_to_concatenate
   2554 
   2555     output_types = get_legacy_output_types(input_dataset)
   2556     if output_types != get_legacy_output_types(dataset_to_concatenate):
   2557       raise TypeError(
   2558           "Two datasets to concatenate have different types %s and %s" %
   2559           (output_types, get_legacy_output_types(dataset_to_concatenate)))
   2560 
   2561     output_classes = get_legacy_output_classes(input_dataset)
   2562     if output_classes != get_legacy_output_classes(dataset_to_concatenate):
   2563       raise TypeError(
   2564           "Two datasets to concatenate have different classes %s and %s" %
   2565           (output_classes, get_legacy_output_classes(dataset_to_concatenate)))
   2566 
   2567     input_shapes = get_legacy_output_shapes(self._input_dataset)
   2568     output_shapes = nest.pack_sequence_as(input_shapes, [
   2569         ts1.most_specific_compatible_shape(ts2)
   2570         for (ts1, ts2) in zip(
   2571             nest.flatten(input_shapes),
   2572             nest.flatten(get_legacy_output_shapes(
   2573                 self._dataset_to_concatenate)))
   2574     ])
   2575 
   2576     self._structure = structure_lib.convert_legacy_structure(
   2577         output_types, output_shapes, output_classes)
   2578 
   2579     self._input_datasets = [input_dataset, dataset_to_concatenate]
   2580     # pylint: disable=protected-access
   2581     variant_tensor = gen_dataset_ops.concatenate_dataset(
   2582         input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor,
   2583         **flat_structure(self))
   2584     # pylint: enable=protected-access
   2585     super(ConcatenateDataset, self).__init__(variant_tensor)
   2586 
   2587   def _inputs(self):
   2588     return self._input_datasets
   2589 
   2590   @property
   2591   def _element_structure(self):
   2592     return self._structure
   2593 
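         # Illustrative consequence of the checks above: element types and classes
         # must match exactly, while shapes are relaxed to the most specific
         # compatible shape. A public-API sketch:
         #
         #   a = tf.data.Dataset.from_tensor_slices([[1, 2, 3]])     # shape [3]
         #   b = tf.data.Dataset.from_tensor_slices([[1, 2, 3, 4]])  # shape [4]
         #   a.concatenate(b)   # elements have static shape [None], i.e. (?,)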
   2594 
   2595 class RepeatDataset(UnaryUnchangedStructureDataset):
   2596   """A `Dataset` that repeats its input several times."""
   2597 
   2598   def __init__(self, input_dataset, count):
   2599     """See `Dataset.repeat()` for details."""
   2600     self._input_dataset = input_dataset
   2601     if count is None:
   2602       self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
   2603     else:
   2604       self._count = ops.convert_to_tensor(
   2605           count, dtype=dtypes.int64, name="count")
   2606     variant_tensor = gen_dataset_ops.repeat_dataset(
   2607         input_dataset._variant_tensor,  # pylint: disable=protected-access
   2608         count=self._count,
   2609         **flat_structure(self))
   2610     super(RepeatDataset, self).__init__(input_dataset, variant_tensor)
   2611 
   2612 
   2613 class RangeDataset(DatasetSource):
   2614   """A `Dataset` of a step separated range of values."""
   2615 
   2616   def __init__(self, *args):
   2617     """See `Dataset.range()` for details."""
   2618     self._parse_args(*args)
   2619     variant_tensor = gen_dataset_ops.range_dataset(
   2620         start=self._start,
   2621         stop=self._stop,
   2622         step=self._step,
   2623         **flat_structure(self))
   2624     super(RangeDataset, self).__init__(variant_tensor)
   2625 
   2626   def _parse_args(self, *args):
   2627     """Parse arguments according to the same rules as the `range()` builtin."""
   2628     if len(args) == 1:
   2629       self._start = self._build_tensor(0, "start")
   2630       self._stop = self._build_tensor(args[0], "stop")
   2631       self._step = self._build_tensor(1, "step")
   2632     elif len(args) == 2:
   2633       self._start = self._build_tensor(args[0], "start")
   2634       self._stop = self._build_tensor(args[1], "stop")
   2635       self._step = self._build_tensor(1, "step")
   2636     elif len(args) == 3:
   2637       self._start = self._build_tensor(args[0], "start")
   2638       self._stop = self._build_tensor(args[1], "stop")
   2639       self._step = self._build_tensor(args[2], "step")
   2640     else:
   2641       raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))
   2642 
   2643   def _build_tensor(self, int64_value, name):
   2644     return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
   2645 
   2646   @property
   2647   def _element_structure(self):
   2648     return structure_lib.TensorStructure(dtypes.int64, [])
   2649 
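         # For reference, `_parse_args` above mirrors the builtin `range()`:
         #
         #   Dataset.range(5)         -> start=0, stop=5,  step=1
         #   Dataset.range(2, 5)      -> start=2, stop=5,  step=1
         #   Dataset.range(1, 10, 2)  -> start=1, stop=10, step=2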
   2650 
   2651 class CacheDataset(UnaryUnchangedStructureDataset):
   2652   """A `Dataset` that caches elements of its input."""
   2653 
   2654   def __init__(self, input_dataset, filename):
   2655     """See `Dataset.cache()` for details."""
   2656     self._input_dataset = input_dataset
   2657     self._filename = ops.convert_to_tensor(
   2658         filename, dtype=dtypes.string, name="filename")
   2659     variant_tensor = gen_dataset_ops.cache_dataset(
   2660         input_dataset._variant_tensor,  # pylint: disable=protected-access
   2661         filename=self._filename,
   2662         **flat_structure(self))
   2663     super(CacheDataset, self).__init__(input_dataset, variant_tensor)
   2664 
   2665 
   2666 class ShuffleDataset(UnaryUnchangedStructureDataset):
   2667   """A `Dataset` that randomly shuffles the elements of its input."""
   2668 
   2669   def __init__(self,
   2670                input_dataset,
   2671                buffer_size,
   2672                seed=None,
   2673                reshuffle_each_iteration=None):
   2674     """Randomly shuffles the elements of this dataset.
   2675 
   2676     Args:
   2677       input_dataset: The input dataset.
   2678       buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
   2679         number of elements from this dataset from which the new
   2680         dataset will sample.
   2681       seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
   2682         random seed that will be used to create the distribution. See
   2683         `tf.set_random_seed` for behavior.
   2684       reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
   2685         that the dataset should be pseudorandomly reshuffled each time it is
   2686         iterated over. (Defaults to `True`.)
   2687 
   2688     Returns:
   2689       A `Dataset`.
   2690 
   2691     Raises:
   2692       ValueError: if invalid arguments are provided.
   2693     """
   2694     self._input_dataset = input_dataset
   2695     self._buffer_size = ops.convert_to_tensor(
   2696         buffer_size, dtype=dtypes.int64, name="buffer_size")
   2697     self._seed, self._seed2 = random_seed.get_seed(seed)
   2698 
   2699     if reshuffle_each_iteration is None:
   2700       self._reshuffle_each_iteration = True
   2701     else:
   2702       self._reshuffle_each_iteration = reshuffle_each_iteration
   2703     variant_tensor = gen_dataset_ops.shuffle_dataset(
   2704         input_dataset._variant_tensor,  # pylint: disable=protected-access
   2705         buffer_size=self._buffer_size,
   2706         seed=self._seed,
   2707         seed2=self._seed2,
   2708         reshuffle_each_iteration=self._reshuffle_each_iteration,
   2709         **flat_structure(self))
   2710     super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)
   2711 
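         # Illustrative note: a shuffle is only uniform over the whole dataset when
         # `buffer_size` is at least the dataset size. A public-API sketch:
         #
         #   ds = tf.data.Dataset.range(100)
         #   ds = ds.shuffle(buffer_size=100, seed=42)  # buffer covers every element
         #   ds = ds.repeat(3)   # with reshuffle_each_iteration=True (the default),
         #                       # each of the three epochs is reshuffled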
   2712 
   2713 class TakeDataset(UnaryUnchangedStructureDataset):
   2714   """A `Dataset` containing the first `count` elements from its input."""
   2715 
   2716   def __init__(self, input_dataset, count):
   2717     """See `Dataset.take()` for details."""
   2718     self._input_dataset = input_dataset
   2719     self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
   2720     variant_tensor = gen_dataset_ops.take_dataset(
   2721         input_dataset._variant_tensor,  # pylint: disable=protected-access
   2722         count=self._count,
   2723         **flat_structure(self))
   2724     super(TakeDataset, self).__init__(input_dataset, variant_tensor)
   2725 
   2726 
   2727 class SkipDataset(UnaryUnchangedStructureDataset):
   2728   """A `Dataset` skipping the first `count` elements from its input."""
   2729 
   2730   def __init__(self, input_dataset, count):
   2731     """See `Dataset.skip()` for details."""
   2732     self._input_dataset = input_dataset
   2733     self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
   2734     variant_tensor = gen_dataset_ops.skip_dataset(
   2735         input_dataset._variant_tensor,  # pylint: disable=protected-access
   2736         count=self._count,
   2737         **flat_structure(self))
   2738     super(SkipDataset, self).__init__(input_dataset, variant_tensor)
   2739 
   2740 
   2741 class ShardDataset(UnaryUnchangedStructureDataset):
   2742   """A `Dataset` for sharding its input."""
   2743 
   2744   def __init__(self, input_dataset, num_shards, index):
   2745     """See `Dataset.shard()` for details."""
   2746     self._input_dataset = input_dataset
   2747     self._num_shards = ops.convert_to_tensor(
   2748         num_shards, dtype=dtypes.int64, name="num_shards")
   2749     self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index")
   2750     variant_tensor = gen_dataset_ops.shard_dataset(
   2751         input_dataset._variant_tensor,  # pylint: disable=protected-access
   2752         num_shards=self._num_shards,
   2753         index=self._index,
   2754         **flat_structure(self))
   2755     super(ShardDataset, self).__init__(input_dataset, variant_tensor)
   2756 
   2757 
   2758 class BatchDataset(UnaryDataset):
   2759   """A `Dataset` that batches contiguous elements from its input."""
   2760 
   2761   def __init__(self, input_dataset, batch_size, drop_remainder):
   2762     """See `Dataset.batch()` for details."""
   2763     self._input_dataset = input_dataset
   2764     self._batch_size = ops.convert_to_tensor(
   2765         batch_size, dtype=dtypes.int64, name="batch_size")
   2766     self._drop_remainder = ops.convert_to_tensor(
   2767         drop_remainder, dtype=dtypes.bool, name="drop_remainder")
   2768 
   2769     constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
   2770     # pylint: disable=protected-access
    2771     if constant_drop_remainder:
    2772       self._structure = input_dataset._element_structure._batch(
    2773           tensor_util.constant_value(self._batch_size))
    2774     else:
    2775       # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
    2776       # or `False` (explicitly retaining the remainder); in either case the
    2777       # batch dimension is statically unknown.
             self._structure = input_dataset._element_structure._batch(None)
   2778     variant_tensor = gen_dataset_ops.batch_dataset_v2(
   2779         input_dataset._variant_tensor,  # pylint: disable=protected-access
   2780         batch_size=self._batch_size,
   2781         drop_remainder=self._drop_remainder,
   2782         **flat_structure(self))
   2783     super(BatchDataset, self).__init__(input_dataset, variant_tensor)
   2784 
   2785   @property
   2786   def _element_structure(self):
   2787     return self._structure
   2788 
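         # Illustrative note on the structure logic above: the batch dimension is
         # statically known only when `drop_remainder` is statically `True`:
         #
         #   tf.data.Dataset.range(10).batch(4)                       # shapes (?,)
         #   tf.data.Dataset.range(10).batch(4, drop_remainder=True)  # shapes (4,)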
   2789 
   2790 def _is_padded_shape_compatible_with(padded_shape, input_component_shape):
   2791   """Returns `True` if `input_component_shape` can be padded to `padded_shape`.
   2792 
   2793   Args:
   2794     padded_shape: A `tf.TensorShape`.
   2795     input_component_shape: A `tf.TensorShape`.
   2796 
   2797   Returns:
   2798     `True` if `input_component_shape` can be padded to `padded_shape`, otherwise
   2799     `False`.
   2800   """
   2801 
   2802   if padded_shape.dims is None or input_component_shape.dims is None:
   2803     return True
   2804   if len(padded_shape.dims) != len(input_component_shape.dims):
   2805     return False
   2806   for padded_dim, input_dim in zip(
   2807       padded_shape.dims, input_component_shape.dims):
   2808     if (padded_dim.value is not None and input_dim.value is not None
   2809         and padded_dim.value < input_dim.value):
   2810       return False
   2811   return True
   2812 
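         # A few concrete cases of the compatibility rule above (illustrative):
         #
         #   padded [None, 10] vs. input [3, 7]  -> True   (None pads any size)
         #   padded [5]        vs. input [7]     -> False  (5 < 7)
         #   padded <unknown>  vs. input [3]     -> True   (unknown rank accepted)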
   2813 
   2814 def _padded_shape_to_tensor(padded_shape, input_component_shape):
   2815   """Converts `padded_shape` to a `tf.Tensor` representing that shape.
   2816 
   2817   Args:
   2818     padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python
   2819       sequence, or a 1-D `tf.Tensor` of `tf.int64` elements.
   2820     input_component_shape: A `tf.TensorShape`, with which `padded_shape` must
   2821       be compatible.
   2822 
   2823   Returns:
   2824     A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`.
   2825 
   2826   Raises:
   2827     ValueError: If `padded_shape` is not a shape or not compatible with
   2828       `input_component_shape`.
   2829     TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor.
   2830   """
   2831   try:
   2832     # Try to convert the `padded_shape` to a `tf.TensorShape`
   2833     padded_shape_as_shape = tensor_shape.as_shape(padded_shape)
   2834     # We will return the "canonical" tensor representation, which uses
   2835     # `-1` in place of `None`.
   2836     ret = ops.convert_to_tensor(
   2837         [dim if dim is not None else -1
   2838          for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64)
   2839   except (TypeError, ValueError):
   2840     # The argument was not trivially convertible to a
   2841     # `tf.TensorShape`, so fall back on the conversion to tensor
   2842     # machinery.
   2843     ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
   2844     if ret.shape.dims is not None and len(ret.shape.dims) != 1:
   2845       raise ValueError(
   2846           "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
   2847           "shape was %s." % (padded_shape, ret.shape))
   2848     if ret.dtype != dtypes.int64:
   2849       raise TypeError(
   2850           "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
   2851           "element type was %s." % (padded_shape, ret.dtype.name))
   2852     padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)
   2853 
   2854   if not _is_padded_shape_compatible_with(padded_shape_as_shape,
   2855                                           input_component_shape):
   2856     raise ValueError("The padded shape %s is not compatible with the "
   2857                      "corresponding input component shape %s."
   2858                      % (padded_shape_as_shape, input_component_shape))
   2859 
   2860   return ret
   2861 
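         # Examples of the conversion above (illustrative; `input_shape` stands for
         # any compatible `tf.TensorShape`):
         #
         #   _padded_shape_to_tensor(tf.TensorShape([None, 3]), input_shape)
         #       # -> a tf.int64 tensor with value [-1, 3] (None canonicalized to -1)
         #   _padded_shape_to_tensor([4, 4], input_shape)
         #       # -> a tf.int64 tensor with value [4, 4]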
   2862 
   2863 def _padding_value_to_tensor(value, output_type):
   2864   """Converts the padding value to a tensor.
   2865 
   2866   Args:
   2867     value: The padding value.
   2868     output_type: Its expected dtype.
   2869 
   2870   Returns:
   2871     A scalar `Tensor`.
   2872 
   2873   Raises:
   2874     ValueError: if the padding value is not a scalar.
   2875     TypeError: if the padding value's type does not match `output_type`.
   2876   """
   2877   value = ops.convert_to_tensor(value, name="padding_value")
   2878   if not value.shape.is_compatible_with(tensor_shape.scalar()):
   2879     raise ValueError("Padding value should be a scalar, but is not: %s" % value)
   2880   if value.dtype != output_type:
   2881     raise TypeError("Padding value tensor (%s) does not match output type: %s" %
   2882                     (value, output_type))
   2883   return value
   2884 
   2885 
   2886 def _default_padding(input_dataset):
   2887   """Returns default padding tensors in a structure matching `input_dataset`."""
   2888   def make_zero(t):
   2889     if t.base_dtype == dtypes.string:
   2890       return ""
   2891     elif t.base_dtype == dtypes.variant:
   2892       error_msg = ("Unable to create padding for field of type 'variant' "
   2893                    "because t.base_type == dtypes.variant == "
   2894                    "{}.".format(
   2895                        t.base_dtype))
   2896       raise TypeError(error_msg)
   2897     else:
   2898       return np.zeros_like(t.as_numpy_dtype())
   2899 
   2900   return nest.map_structure(
   2901       make_zero, get_legacy_output_types(input_dataset))
   2902 
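         # Default padding values produced above, by component dtype (illustrative):
         #
         #   tf.string           -> ""         (empty string)
         #   tf.int32/tf.float32 -> 0 / 0.0    (via np.zeros_like)
         #   tf.bool             -> False
         #   tf.variant          -> TypeError  (provide `padding_values` instead)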
   2903 
   2904 class PaddedBatchDataset(UnaryDataset):
   2905   """A `Dataset` that batches and pads contiguous elements from its input."""
   2906 
   2907   def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
   2908                drop_remainder):
   2909     """See `Dataset.batch()` for details."""
   2910     self._input_dataset = input_dataset
   2911     if sparse.any_sparse(get_legacy_output_classes(input_dataset)):
   2912       # TODO(b/63669786): support batching of sparse tensors
   2913       raise TypeError(
   2914           "Batching of padded sparse tensors is not currently supported")
   2916     self._batch_size = ops.convert_to_tensor(
   2917         batch_size, dtype=dtypes.int64, name="batch_size")
   2918     padding_values = (
   2919         padding_values
   2920         if padding_values is not None else _default_padding(input_dataset))
   2921 
   2922     input_shapes = get_legacy_output_shapes(input_dataset)
   2923     flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes)
   2924 
   2925     flat_padded_shapes_as_tensors = []
   2926 
   2927     for input_component_shape, padded_shape in zip(
   2928         nest.flatten(input_shapes), flat_padded_shapes):
   2929       flat_padded_shapes_as_tensors.append(
   2930           _padded_shape_to_tensor(padded_shape, input_component_shape))
   2931 
   2932     self._padded_shapes = nest.pack_sequence_as(input_shapes,
   2933                                                 flat_padded_shapes_as_tensors)
   2934 
   2935     self._padding_values = nest.map_structure_up_to(
   2936         input_shapes, _padding_value_to_tensor, padding_values,
   2937         get_legacy_output_types(input_dataset))
   2938     self._drop_remainder = ops.convert_to_tensor(
   2939         drop_remainder, dtype=dtypes.bool, name="drop_remainder")
   2940 
   2941     def _padded_shape_to_batch_shape(s):
    2942       return tensor_shape.vector(
    2943           tensor_util.constant_value(self._batch_size)
    2944           if smart_cond.smart_constant_value(self._drop_remainder)
    2945           else None).concatenate(tensor_util.constant_value_as_shape(s))
   2946 
   2947     output_shapes = nest.map_structure(
   2948         _padded_shape_to_batch_shape, self._padded_shapes)
   2949     self._structure = structure_lib.convert_legacy_structure(
   2950         get_legacy_output_types(self._input_dataset), output_shapes,
   2951         get_legacy_output_classes(self._input_dataset))
   2952 
   2953     # pylint: disable=protected-access
   2954     # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
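             # NOTE: `is False` (rather than `not ...`) is used below so that a
             # statically unknown `drop_remainder` (i.e. `None`) falls through to the
             # V2 op, which takes `drop_remainder` as a runtime input.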
   2955     if smart_cond.smart_constant_value(self._drop_remainder) is False:
   2956       variant_tensor = gen_dataset_ops.padded_batch_dataset(
   2957           input_dataset._variant_tensor,  # pylint: disable=protected-access
   2958           batch_size=self._batch_size,
   2959           padded_shapes=[
   2960               ops.convert_to_tensor(s, dtype=dtypes.int64)
   2961               for s in nest.flatten(self._padded_shapes)
   2962           ],
   2963           padding_values=nest.flatten(self._padding_values),
   2964           output_shapes=self._structure._flat_shapes)
   2965     else:
   2966       variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
   2967           input_dataset._variant_tensor,  # pylint: disable=protected-access
   2968           batch_size=self._batch_size,
   2969           padded_shapes=[
   2970               ops.convert_to_tensor(s, dtype=dtypes.int64)
   2971               for s in nest.flatten(self._padded_shapes)
   2972           ],
   2973           padding_values=nest.flatten(self._padding_values),
   2974           drop_remainder=self._drop_remainder,
   2975           output_shapes=self._structure._flat_shapes)
   2976     super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
   2977 
   2978   @property
   2979   def _element_structure(self):
   2980     return self._structure
   2981 
   2982 
   2983 def _should_unpack_args(args):
   2984   """Returns `True` if `args` should be `*args` when passed to a callable."""
   2985   return type(args) is tuple  # pylint: disable=unidiomatic-typecheck
   2986 
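         # Illustrative consequences of the exact `type(...) is tuple` check above:
         #
         #   _should_unpack_args((1, 2))       -> True   (plain tuple: passed as *args)
         #   _should_unpack_args([1, 2])       -> False  (list: passed as one argument)
         #   _should_unpack_args(Point(1, 2))  -> False  (namedtuple / tuple subclass)
         #
         # where `Point = collections.namedtuple("Point", ["x", "y"])` is only an
         # example type, not something defined in this module.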
   2987 
   2988 class MapDataset(UnaryDataset):
   2989   """A `Dataset` that maps a function over elements in its input."""
   2990 
   2991   def __init__(self,
   2992                input_dataset,
   2993                map_func,
   2994                use_inter_op_parallelism=True,
   2995                preserve_cardinality=False,
   2996                use_legacy_function=False):
   2997     """See `Dataset.map()` for details."""
   2998     self._input_dataset = input_dataset
   2999     self._use_inter_op_parallelism = use_inter_op_parallelism
   3000     self._preserve_cardinality = preserve_cardinality
   3001     self._map_func = StructuredFunctionWrapper(
   3002         map_func,
   3003         self._transformation_name(),
   3004         dataset=input_dataset,
   3005         use_legacy_function=use_legacy_function)
   3006     variant_tensor = gen_dataset_ops.map_dataset(
   3007         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3008         self._map_func.function.captured_inputs,
   3009         f=self._map_func.function,
   3010         use_inter_op_parallelism=self._use_inter_op_parallelism,
   3011         preserve_cardinality=self._preserve_cardinality,
   3012         **flat_structure(self))
   3013     super(MapDataset, self).__init__(input_dataset, variant_tensor)
   3014 
   3015   def _functions(self):
   3016     return [self._map_func]
   3017 
   3018   @property
   3019   def _element_structure(self):
   3020     return self._map_func.output_structure
   3021 
   3022   def _transformation_name(self):
   3023     return "Dataset.map()"
   3024 
   3025 
   3026 class ParallelMapDataset(UnaryDataset):
   3027   """A `Dataset` that maps a function over elements in its input in parallel."""
   3028 
   3029   def __init__(self,
   3030                input_dataset,
   3031                map_func,
   3032                num_parallel_calls,
   3033                use_inter_op_parallelism=True,
   3034                preserve_cardinality=False,
   3035                use_legacy_function=False):
   3036     """See `Dataset.map()` for details."""
   3037     self._input_dataset = input_dataset
   3038     self._use_inter_op_parallelism = use_inter_op_parallelism
   3039     self._map_func = StructuredFunctionWrapper(
   3040         map_func,
   3041         self._transformation_name(),
   3042         dataset=input_dataset,
   3043         use_legacy_function=use_legacy_function)
   3044     self._num_parallel_calls = ops.convert_to_tensor(
   3045         num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
   3046     self._preserve_cardinality = preserve_cardinality
   3047     variant_tensor = gen_dataset_ops.parallel_map_dataset(
   3048         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3049         self._map_func.function.captured_inputs,
   3050         f=self._map_func.function,
   3051         num_parallel_calls=self._num_parallel_calls,
   3052         use_inter_op_parallelism=self._use_inter_op_parallelism,
   3053         preserve_cardinality=self._preserve_cardinality,
   3054         **flat_structure(self))
   3055     super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)
   3056 
   3057   def _functions(self):
   3058     return [self._map_func]
   3059 
   3060   @property
   3061   def _element_structure(self):
   3062     return self._map_func.output_structure
   3063 
   3064   def _transformation_name(self):
   3065     return "Dataset.map()"
   3066 
   3067 
   3068 class FlatMapDataset(UnaryDataset):
   3069   """A `Dataset` that maps a function over its input and flattens the result."""
   3070 
   3071   def __init__(self, input_dataset, map_func):
   3072     """See `Dataset.flat_map()` for details."""
   3073     self._input_dataset = input_dataset
   3074     self._map_func = StructuredFunctionWrapper(
   3075         map_func, self._transformation_name(), dataset=input_dataset)
   3076     if not isinstance(self._map_func.output_structure, DatasetStructure):
   3077       raise TypeError(
   3078           "`map_func` must return a `Dataset` object. Got {}".format(
   3079               type(self._map_func.output_structure)))
   3080     self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
   3081     variant_tensor = gen_dataset_ops.flat_map_dataset(
   3082         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3083         self._map_func.function.captured_inputs,
   3084         f=self._map_func.function,
   3085         **flat_structure(self))
   3086     super(FlatMapDataset, self).__init__(input_dataset, variant_tensor)
   3087 
   3088   def _functions(self):
   3089     return [self._map_func]
   3090 
   3091   @property
   3092   def _element_structure(self):
   3093     return self._structure
   3094 
   3095   def _transformation_name(self):
   3096     return "Dataset.flat_map()"
   3097 
   3098 
   3099 class InterleaveDataset(UnaryDataset):
   3100   """A `Dataset` that maps a function over its input and interleaves the result.
   3101   """
   3102 
   3103   def __init__(self, input_dataset, map_func, cycle_length, block_length):
   3104     """See `Dataset.interleave()` for details."""
   3105     self._input_dataset = input_dataset
   3106     self._map_func = StructuredFunctionWrapper(
   3107         map_func, self._transformation_name(), dataset=input_dataset)
   3108     if not isinstance(self._map_func.output_structure, DatasetStructure):
   3109       raise TypeError(
   3110           "`map_func` must return a `Dataset` object. Got {}".format(
   3111               type(self._map_func.output_structure)))
   3112     self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
   3113     self._cycle_length = ops.convert_to_tensor(
   3114         cycle_length, dtype=dtypes.int64, name="cycle_length")
   3115     self._block_length = ops.convert_to_tensor(
   3116         block_length, dtype=dtypes.int64, name="block_length")
   3117 
   3118     variant_tensor = gen_dataset_ops.interleave_dataset(
   3119         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3120         self._map_func.function.captured_inputs,  # pylint: disable=protected-access
   3121         self._cycle_length,
   3122         self._block_length,
   3123         f=self._map_func.function,
   3124         **flat_structure(self))
   3125     super(InterleaveDataset, self).__init__(input_dataset, variant_tensor)
   3126 
   3127   def _functions(self):
   3128     return [self._map_func]
   3129 
   3130   @property
   3131   def _element_structure(self):
   3132     return self._structure
   3133 
   3134   def _transformation_name(self):
   3135     return "Dataset.interleave()"
   3136 
   3137 
   3138 class ParallelInterleaveDataset(UnaryDataset):
   3139   """A `Dataset` that maps a function over its input and interleaves the result."""
   3140 
   3141   def __init__(self, input_dataset, map_func, cycle_length, block_length,
   3142                num_parallel_calls):
   3143     """See `Dataset.interleave()` for details."""
   3144     self._input_dataset = input_dataset
   3145     self._map_func = StructuredFunctionWrapper(
   3146         map_func, self._transformation_name(), dataset=input_dataset)
   3147     if not isinstance(self._map_func.output_structure, DatasetStructure):
   3148       raise TypeError(
   3149           "`map_func` must return a `Dataset` object. Got {}".format(
   3150               type(self._map_func.output_structure)))
   3151     self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
   3152     self._cycle_length = ops.convert_to_tensor(
   3153         cycle_length, dtype=dtypes.int64, name="cycle_length")
   3154     self._block_length = ops.convert_to_tensor(
   3155         block_length, dtype=dtypes.int64, name="block_length")
   3156     self._num_parallel_calls = ops.convert_to_tensor(
   3157         num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
   3158     variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2(
   3159         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3160         self._map_func.function.captured_inputs,  # pylint: disable=protected-access
   3161         self._cycle_length,
   3162         self._block_length,
   3163         self._num_parallel_calls,
   3164         f=self._map_func.function,
   3165         **flat_structure(self))
   3166     super(ParallelInterleaveDataset, self).__init__(input_dataset,
   3167                                                     variant_tensor)
   3168 
   3169   def _functions(self):
   3170     return [self._map_func]
   3171 
   3172   @property
   3173   def _element_structure(self):
   3174     return self._structure
   3175 
   3176   def _transformation_name(self):
   3177     return "Dataset.interleave()"
   3178 
   3179 
   3180 class FilterDataset(UnaryUnchangedStructureDataset):
   3181   """A `Dataset` that filters its input according to a predicate function."""
   3182 
   3183   def __init__(self, input_dataset, predicate, use_legacy_function=False):
   3184     """See `Dataset.filter()` for details."""
   3185     self._input_dataset = input_dataset
   3186     wrapped_func = StructuredFunctionWrapper(
   3187         predicate,
   3188         self._transformation_name(),
   3189         dataset=input_dataset,
   3190         use_legacy_function=use_legacy_function)
   3191     if not wrapped_func.output_structure.is_compatible_with(
   3192         structure_lib.TensorStructure(dtypes.bool, [])):
   3193       error_msg = ("`predicate` return type must be convertible to a scalar "
   3194                    "boolean tensor. Was {}.").format(
   3195                        wrapped_func.output_structure)
   3196       raise ValueError(error_msg)
   3197     self._predicate = wrapped_func
   3198     variant_tensor = gen_dataset_ops.filter_dataset(
   3199         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3200         other_arguments=self._predicate.function.captured_inputs,
   3201         predicate=self._predicate.function,
   3202         **flat_structure(self))
   3203     super(FilterDataset, self).__init__(input_dataset, variant_tensor)
   3204 
   3205   def _functions(self):
   3206     return [self._predicate]
   3207 
   3208   def _transformation_name(self):
   3209     return "Dataset.filter()"
   3210 
   3211 
   3212 class PrefetchDataset(UnaryUnchangedStructureDataset):
   3213   """A `Dataset` that asynchronously prefetches its input."""
   3214 
   3215   def __init__(self, input_dataset, buffer_size):
   3216     """See `Dataset.prefetch()` for details."""
   3217     self._input_dataset = input_dataset
   3218     if buffer_size is None:
   3219       buffer_size = -1  # This is the sentinel for auto-tuning.
   3220     self._buffer_size = ops.convert_to_tensor(
   3221         buffer_size, dtype=dtypes.int64, name="buffer_size")
   3222     variant_tensor = gen_dataset_ops.prefetch_dataset(
   3223         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3224         buffer_size=self._buffer_size,
   3225         **flat_structure(self))
   3226     super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
   3227 
   3228 
   3229 class WindowDataset(UnaryDataset):
   3230   """A dataset that creates window datasets from the input elements."""
   3231 
   3232   def __init__(self, input_dataset, size, shift, stride, drop_remainder):
   3233     """See `window_dataset()` for more details."""
   3234     self._input_dataset = input_dataset
   3235     self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
   3236     self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
   3237     self._stride = ops.convert_to_tensor(
   3238         stride, dtype=dtypes.int64, name="stride")
   3239     self._drop_remainder = ops.convert_to_tensor(
   3240         drop_remainder, dtype=dtypes.bool, name="drop_remainder")
   3241     nest_of_structures = nest.pack_sequence_as(
   3242         get_legacy_output_classes(input_dataset),
   3243         [
   3244             DatasetStructure(structure_lib.convert_legacy_structure(
   3245                 output_type, output_shape, output_class))
   3246             for output_class, output_shape, output_type in zip(
   3247                 nest.flatten(get_legacy_output_classes(input_dataset)),
   3248                 nest.flatten(get_legacy_output_shapes(input_dataset)),
   3249                 nest.flatten(get_legacy_output_types(input_dataset)))
   3250         ])
   3251     self._structure = structure_lib.NestedStructure(nest_of_structures)
   3252     variant_tensor = gen_dataset_ops.window_dataset(
   3253         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3254         self._size,
   3255         self._shift,
   3256         self._stride,
   3257         self._drop_remainder,
   3258         **flat_structure(self))
   3259     super(WindowDataset, self).__init__(input_dataset, variant_tensor)
   3260 
   3261   @property
   3262   def _element_structure(self):
   3263     return self._structure
   3264 
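         # Illustrative public-API sketch of the window semantics (assuming the usual
         # size/shift/stride meanings; each window is itself a nested dataset):
         #
         #   ds = tf.data.Dataset.range(7).window(size=3, shift=2, drop_remainder=True)
         #   # yields three window datasets containing {0, 1, 2}, {2, 3, 4}, {4, 5, 6}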
   3265 
   3266 class _OptionsDataset(UnaryUnchangedStructureDataset):
   3267   """An identity `Dataset` that stores options."""
   3268 
   3269   def __init__(self, input_dataset, options):
   3270     self._input_dataset = input_dataset
   3271     self._options = input_dataset.options()
   3272     if self._options:
   3273       self._options = self._options.merge(options)
   3274     else:
   3275       self._options = options
   3276     variant_tensor = input_dataset._variant_tensor  # pylint: disable=protected-access
   3277     super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
   3278 
   3279   def options(self):
   3280     return self._options
   3281 
   3282 
   3283 class _ModelDataset(UnaryUnchangedStructureDataset):
   3284   """A `Dataset` that acts as an identity, and models performance."""
   3285 
   3286   def __init__(self, input_dataset, cpu_budget):
   3287     self._input_dataset = input_dataset
   3288     variant_tensor = gen_dataset_ops.model_dataset(
   3289         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3290         cpu_budget=cpu_budget,
   3291         **flat_structure(self))
   3292     super(_ModelDataset, self).__init__(input_dataset, variant_tensor)
   3293 
   3294 
   3295 class _OptimizeDataset(UnaryUnchangedStructureDataset):
   3296   """A `Dataset` that acts as an identity, and applies optimizations."""
   3297 
   3298   def __init__(self, input_dataset, optimizations):
   3299     self._input_dataset = input_dataset
   3300     if optimizations is None:
   3301       optimizations = []
   3302     self._optimizations = ops.convert_to_tensor(
   3303         optimizations, dtype=dtypes.string, name="optimizations")
   3304     variant_tensor = gen_dataset_ops.optimize_dataset(
   3305         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3306         self._optimizations,
   3307         **flat_structure(self))
   3308     super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)
   3309 
   3310 
   3311 class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
   3312   """A `Dataset` that acts as an identity, and sets a stats aggregator."""
   3313 
   3314   def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
   3315     self._input_dataset = input_dataset
   3316     self._stats_aggregator = aggregator
   3317     self._prefix = prefix
   3318     self._counter_prefix = counter_prefix
   3319     variant_tensor = ged_ops.experimental_set_stats_aggregator_dataset(
   3320         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3321         self._stats_aggregator._resource,  # pylint: disable=protected-access
   3322         self._prefix,
   3323         self._counter_prefix,
   3324         **flat_structure(self))
   3325     super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
   3326                                                      variant_tensor)
   3327 
   3328 
   3329 class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
   3330   """A `Dataset` that acts as an identity, overriding intra-op parallelism."""
   3331 
   3332   def __init__(self, input_dataset, max_intra_op_parallelism):
   3333     self._input_dataset = input_dataset
   3334     self._max_intra_op_parallelism = ops.convert_to_tensor(
   3335         max_intra_op_parallelism,
   3336         dtype=dtypes.int64,
   3337         name="max_intra_op_parallelism")
   3338     variant_tensor = ged_ops.experimental_max_intra_op_parallelism_dataset(
   3339         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3340         self._max_intra_op_parallelism,
   3341         **flat_structure(self))
   3342     super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
   3343                                                         variant_tensor)
   3344 
   3345 
   3346 class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
   3347   """A `Dataset` that acts as an identity, setting a private threadpool."""
   3348 
   3349   def __init__(self, input_dataset, num_threads):
   3350     self._input_dataset = input_dataset
   3351     self._num_threads = ops.convert_to_tensor(
   3352         num_threads, dtype=dtypes.int64, name="num_threads")
   3353     variant_tensor = ged_ops.experimental_private_thread_pool_dataset(
   3354         input_dataset._variant_tensor,  # pylint: disable=protected-access
   3355         self._num_threads,
   3356         **flat_structure(self))
   3357     super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
   3358                                                     variant_tensor)
   3359