Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Python wrappers for Datasets and Iterators."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.contrib.data.python.ops import batching
     21 from tensorflow.contrib.data.python.ops import enumerate_ops
     22 from tensorflow.contrib.data.python.ops import error_ops
     23 from tensorflow.contrib.data.python.ops import grouping
     24 from tensorflow.python.data.ops import dataset_ops
     25 from tensorflow.python.data.util import nest
     26 from tensorflow.python.ops import gen_dataset_ops
     27 from tensorflow.python.ops import gen_io_ops
     28 from tensorflow.python.util import deprecation
     29 
     30 
     31 class Dataset(dataset_ops.Dataset):
     32   """Represents a potentially large set of elements.
     33 
     34   A `Dataset` can be used to represent an input pipeline as a
     35   collection of elements (nested structures of tensors) and a "logical
     36   plan" of transformations that act on those elements.
     37   """
     38 
     39   def __init__(self, dataset):
     40     super(Dataset, self).__init__()
     41     self._dataset = dataset
     42 
     43   @deprecation.deprecated(None, "Use `ds._as_variant_tensor()`.")
     44   def make_dataset_resource(self):
     45     return self._as_variant_tensor()
     46 
     47   def _as_variant_tensor(self):
     48     return self._dataset._as_variant_tensor()  # pylint: disable=protected-access
     49 
     50   @property
     51   def output_classes(self):
     52     return self._dataset.output_classes
     53 
     54   @property
     55   def output_shapes(self):
     56     return self._dataset.output_shapes
     57 
     58   @property
     59   def output_types(self):
     60     return self._dataset.output_types
     61 
     62   @staticmethod
     63   @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensors()`.")
     64   def from_tensors(tensors):
     65     """Creates a `Dataset` with a single element, comprising the given tensors.
     66 
     67     Args:
     68       tensors: A nested structure of tensors.
     69 
     70     Returns:
     71       A `Dataset`.
     72     """
     73     return Dataset(dataset_ops.TensorDataset(tensors))
     74 
     75   @staticmethod
     76   @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
     77   def from_tensor_slices(tensors):
     78     """Creates a `Dataset` whose elements are slices of the given tensors.
     79 
     80     Args:
     81       tensors: A nested structure of tensors, each having the same size in the
     82         0th dimension.
     83 
     84     Returns:
     85       A `Dataset`.
     86     """
     87     return Dataset(dataset_ops.TensorSliceDataset(tensors))
     88 
     89   @staticmethod
     90   @deprecation.deprecated(None,
     91                           "Use `tf.data.Dataset.from_sparse_tensor_slices()`.")
     92   def from_sparse_tensor_slices(sparse_tensor):
     93     """Splits each rank-N `tf.SparseTensor` in this dataset row-wise.
     94 
     95     Args:
     96       sparse_tensor: A `tf.SparseTensor`.
     97 
     98     Returns:
     99       A `Dataset` of rank-(N-1) sparse tensors.
    100     """
    101     return Dataset(dataset_ops.SparseTensorSliceDataset(sparse_tensor))
    102 
    103   @staticmethod
    104   @deprecation.deprecated(None, "Use `tf.data.Dataset.from_generator()`.")
    105   def from_generator(generator, output_types, output_shapes=None):
    106     """Creates a `Dataset` whose elements are generated by `generator`.
    107 
    108     The `generator` argument must be a callable object that returns
    109     an object that support the `iter()` protocol (e.g. a generator function).
    110     The elements generated by `generator` must be compatible with the given
    111     `output_types` and (optional) `output_shapes` arguments.
    112 
    113     For example:
    114 
    115     ```python
    116     import itertools
    117 
    118     def gen():
    119       for i in itertools.count(1):
    120         yield (i, [1] * i)
    121 
    122     ds = Dataset.from_generator(
    123         gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
    124     value = ds.make_one_shot_iterator().get_next()
    125 
    126     sess.run(value)  # (1, array([1]))
    127     sess.run(value)  # (2, array([1, 1]))
    128     ```
    129 
    130     Args:
    131       generator: A callable object that takes no arguments and returns an
    132         object that supports the `iter()` protocol.
    133       output_types: A nested structure of `tf.DType` objects corresponding to
    134         each component of an element yielded by `generator`.
    135       output_shapes: (Optional.) A nested structure of `tf.TensorShape`
    136         objects corresponding to each component of an element yielded by
    137         `generator`.
    138 
    139     Returns:
    140       A `Dataset`.
    141     """
    142     return Dataset(dataset_ops.Dataset.from_generator(
    143         generator, output_types, output_shapes))
    144 
    145   @staticmethod
    146   @deprecation.deprecated(None, "Use `tf.data.Dataset.range()`.")
    147   def range(*args):
    148     """Creates a `Dataset` of a step-separated range of values.
    149 
    150     For example:
    151 
    152     ```python
    153     Dataset.range(5) == [0, 1, 2, 3, 4]
    154     Dataset.range(2, 5) == [2, 3, 4]
    155     Dataset.range(1, 5, 2) == [1, 3]
    156     Dataset.range(1, 5, -2) == []
    157     Dataset.range(5, 1) == []
    158     Dataset.range(5, 1, -2) == [5, 3]
    159     ```
    160 
    161     Args:
    162       *args: follow same semantics as python's xrange.
    163         len(args) == 1 -> start = 0, stop = args[0], step = 1
    164         len(args) == 2 -> start = args[0], stop = args[1], step = 1
    165         len(args) == 3 -> start = args[0], stop = args[1, stop = args[2]
    166 
    167     Returns:
    168       A `RangeDataset`.
    169 
    170     Raises:
    171       ValueError: if len(args) == 0.
    172     """
    173     return Dataset(dataset_ops.RangeDataset(*args))
    174 
    175   @staticmethod
    176   @deprecation.deprecated(None, "Use `tf.data.Dataset.zip()`.")
    177   def zip(datasets):
    178     """Creates a `Dataset` by zipping together the given datasets.
    179 
    180     This method has similar semantics to the built-in `zip()` function
    181     in Python, with the main difference being that the `datasets`
    182     argument can be an arbitrary nested structure of `Dataset` objects.
    183     For example:
    184 
    185     ```python
    186     # NOTE: The following examples use `{ ... }` to represent the
    187     # contents of a dataset.
    188     a = { 1, 2, 3 }
    189     b = { 4, 5, 6 }
    190     c = { (7, 8), (9, 10), (11, 12) }
    191     d = { 13, 14 }
    192 
    193     # The nested structure of the `datasets` argument determines the
    194     # structure of elements in the resulting dataset.
    195     Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
    196     Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }
    197 
    198     # The `datasets` argument may contain an arbitrary number of
    199     # datasets.
    200     Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
    201                                 (2, 5, (9, 10)),
    202                                 (3, 6, (11, 12)) }
    203 
    204     # The number of elements in the resulting dataset is the same as
    205     # the size of the smallest dataset in `datasets`.
    206     Dataset.zip((a, d)) == { (1, 13), (2, 14) }
    207     ```
    208 
    209     Args:
    210       datasets: A nested structure of datasets.
    211 
    212     Returns:
    213       A `Dataset`.
    214     """
    215     return Dataset(dataset_ops.ZipDataset(datasets))
    216 
    217   def concatenate(self, dataset):
    218     """Creates a `Dataset` by concatenating given dataset with this dataset.
    219 
    220     ```python
    221     # NOTE: The following examples use `{ ... }` to represent the
    222     # contents of a dataset.
    223     a = { 1, 2, 3 }
    224     b = { 4, 5, 6, 7 }
    225 
    226     # Input dataset and dataset to be concatenated should have same
    227     # nested structures and output types.
    228     # c = { (8, 9), (10, 11), (12, 13) }
    229     # d = { 14.0, 15.0, 16.0 }
    230     # a.concatenate(c) and a.concatenate(d) would result in error.
    231 
    232     a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 }
    233     ```
    234 
    235     Args:
    236       dataset: `Dataset` to be concatenated.
    237 
    238     Returns:
    239       A `Dataset`.
    240     """
    241     return Dataset(dataset_ops.ConcatenateDataset(self._dataset, dataset))
    242 
    243   def prefetch(self, buffer_size):
    244     """Creates a `Dataset` that prefetches elements from this dataset.
    245 
    246     Args:
    247       buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
    248         maximum number elements that will be buffered when prefetching.
    249 
    250     Returns:
    251       A `Dataset`.
    252     """
    253     return Dataset(dataset_ops.PrefetchDataset(self._dataset, buffer_size))
    254 
    255   @staticmethod
    256   @deprecation.deprecated(None, "Use `tf.data.Dataset.list_files()`.")
    257   def list_files(file_pattern):
    258     """A dataset of all files matching a pattern.
    259 
    260     Example:
    261       If we had the following files on our filesystem:
    262         - /path/to/dir/a.txt
    263         - /path/to/dir/b.py
    264         - /path/to/dir/c.py
    265       If we pass "/path/to/dir/*.py" as the directory, the dataset would
    266       produce:
    267         - /path/to/dir/b.py
    268         - /path/to/dir/c.py
    269 
    270     Args:
    271       file_pattern: A string or scalar string `tf.Tensor`, representing
    272         the filename pattern that will be matched.
    273 
    274     Returns:
    275      A `Dataset` of strings corresponding to file names.
    276     """
    277     return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern))
    278 
    279   def repeat(self, count=None):
    280     """Repeats this dataset `count` times.
    281 
    282     Args:
    283       count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
    284         number of times the elements of this dataset should be repeated. The
    285         default behavior (if `count` is `None` or `-1`) is for the elements to
    286         be repeated indefinitely.
    287 
    288     Returns:
    289       A `Dataset`.
    290     """
    291     return Dataset(dataset_ops.RepeatDataset(self._dataset, count))
    292 
    293   @deprecation.deprecated(
    294       None, "Use `ds.apply(tf.contrib.data.enumerate_dataset())`.")
    295   def enumerate(self, start=0):
    296     """Deprecated: Use `Dataset.apply(tf.contrib.data.enumerate_dataset(..)`."""
    297 
    298     return self.apply(enumerate_ops.enumerate_dataset(start))
    299 
    300   def shuffle(self, buffer_size, seed=None):
    301     """Randomly shuffles the elements of this dataset.
    302 
    303     Args:
    304       buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
    305         number of elements from this dataset from which the new
    306         dataset will sample.
    307       seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
    308         random seed that will be used to create the distribution. See
    309         @{tf.set_random_seed} for behavior.
    310 
    311     Returns:
    312       A `Dataset`.
    313     """
    314     return Dataset(dataset_ops.ShuffleDataset(self._dataset, buffer_size, seed))
    315 
    316   def cache(self, filename=""):
    317     """Caches the elements in this dataset.
    318 
    319     Args:
    320       filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
    321         directory on the filesystem to use for caching tensors in this Dataset.
    322         If a filename is not provided, the dataset will be cached in memory.
    323 
    324     Returns:
    325       A `Dataset`.
    326     """
    327     return Dataset(dataset_ops.CacheDataset(self._dataset, filename))
    328 
    329   def take(self, count):
    330     """Creates a `Dataset` with at most `count` elements from this dataset.
    331 
    332     Args:
    333       count: A `tf.int64` scalar `tf.Tensor`, representing the number of
    334         elements of this dataset that should be taken to form the new dataset.
    335         If `count` is -1, or if `count` is greater than the size of this
    336         dataset, the new dataset will contain all elements of this dataset.
    337 
    338     Returns:
    339       A `Dataset`.
    340     """
    341     return Dataset(dataset_ops.TakeDataset(self._dataset, count))
    342 
    343   def skip(self, count):
    344     """Creates a `Dataset` that skips `count` elements from this dataset.
    345 
    346     Args:
    347       count: A `tf.int64` scalar `tf.Tensor`, representing the number
    348         of elements of this dataset that should be skipped to form the
    349         new dataset.  If `count` is greater than the size of this
    350         dataset, the new dataset will contain no elements.  If `count`
    351         is -1, skips the entire dataset.
    352 
    353     Returns:
    354       A `Dataset`.
    355     """
    356     return Dataset(dataset_ops.SkipDataset(self._dataset, count))
    357 
    358   def shard(self, num_shards, index):
    359     """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
    360 
    361     This dataset operator is very useful when running distributed training, as
    362     it allows each worker to read a unique subset.
    363 
    364     When reading a single input file, you can skip elements as follows:
    365 
    366     ```python
    367     d = tf.data.TFRecordDataset(FLAGS.input_file)
    368     d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
    369     d = d.repeat(FLAGS.num_epochs)
    370     d = d.shuffle(FLAGS.shuffle_buffer_size)
    371     d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
    372     ```
    373 
    374     Important caveats:
    375 
    376     - Be sure to shard before you use any randomizing operator (such as
    377       shuffle).
    378     - Generally it is best if the shard operator is used early in the dataset
    379       pipeline. For example, when reading from a set of TFRecord files, shard
    380       before converting the dataset to input samples. This avoids reading every
    381       file on every worker. The following is an example of an efficient
    382       sharding strategy within a complete pipeline:
    383 
    384     ```python
    385     d = tf.data.Dataset.list_files(FLAGS.pattern)
    386     d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
    387     d = d.repeat(FLAGS.num_epochs)
    388     d = d.shuffle(FLAGS.shuffle_buffer_size)
    389     d = d.interleave(tf.data.TFRecordDataset,
    390                      cycle_length=FLAGS.num_readers, block_length=1)
    391     d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
    392     ```
    393 
    394     Args:
    395       num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
    396         shards operating in parallel.
    397       index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
    398 
    399     Returns:
    400       A `Dataset`.
    401 
    402     Raises:
    403       ValueError: if `num_shards` or `index` are illegal values. Note: error
    404         checking is done on a best-effort basis, and aren't guaranteed to be
    405         caught upon dataset creation. (e.g. providing in a placeholder tensor
    406         bypasses the early checking, and will instead result in an error during
    407         a session.run call.)
    408     """
    409     return Dataset(self._dataset.shard(num_shards, index))
    410 
    411   @deprecation.deprecated(
    412       None, "Use `ds.apply(tf.contrib.data.ignore_errors())`.")
    413   def ignore_errors(self):
    414     """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors())`."""
    415 
    416     return self.apply(error_ops.ignore_errors())
    417 
    418   def batch(self, batch_size):
    419     """Combines consecutive elements of this dataset into batches.
    420 
    421     Args:
    422       batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
    423         consecutive elements of this dataset to combine in a single batch.
    424 
    425     Returns:
    426       A `Dataset`.
    427     """
    428     return Dataset(dataset_ops.BatchDataset(self._dataset, batch_size))
    429 
    430   def padded_batch(self, batch_size, padded_shapes, padding_values=None):
    431     """Combines consecutive elements of this dataset into padded batches.
    432 
    433     Like `Dataset.dense_to_sparse_batch()`, this method combines
    434     multiple consecutive elements of this dataset, which might have
    435     different shapes, into a single element. The tensors in the
    436     resulting element have an additional outer dimension, and are
    437     padded to the respective shape in `padded_shapes`.
    438 
    439     Args:
    440       batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
    441         consecutive elements of this dataset to combine in a single batch.
    442       padded_shapes: A nested structure of `tf.TensorShape` or
    443         `tf.int64` vector tensor-like objects representing the shape
    444         to which the respective component of each input element should
    445         be padded prior to batching. Any unknown dimensions
    446         (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a
    447         tensor-like object) will be padded to the maximum size of that
    448         dimension in each batch.
    449       padding_values: (Optional.) A nested structure of scalar-shaped
    450         `tf.Tensor`, representing the padding values to use for the
    451         respective components.  Defaults are `0` for numeric types and
    452         the empty string for string types.
    453 
    454     Returns:
    455       A `Dataset`.
    456     """
    457     return Dataset(
    458         dataset_ops.PaddedBatchDataset(self._dataset, batch_size, padded_shapes,
    459                                        padding_values))
    460 
    461   @deprecation.deprecated(
    462       None, "Use `ds.apply(tf.contrib.data.dense_to_sparse_batch())`.")
    463   def dense_to_sparse_batch(self, batch_size, row_shape):
    464     """Use: `Dataset.apply(tf.contrib.data.dense_to_sparse_batch(...))`."""
    465 
    466     return self.apply(batching.dense_to_sparse_batch(batch_size, row_shape))
    467 
    468   @deprecation.deprecated(
    469       None, "Use `ds.apply(tf.contrib.data.group_by_window())`.")
    470   def group_by_window(self, key_func, reduce_func, window_size):
    471     """Deprecated: Use `Dataset.apply(tf.contrib.data.group_by_window(...))`."""
    472 
    473     return self.apply(
    474         grouping.group_by_window(key_func, reduce_func, window_size))
    475 
    476   @deprecation.deprecated_args(
    477       None,
    478       "`output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.",
    479       "num_threads", "output_buffer_size")
    480   def map(self,
    481           map_func,
    482           num_threads=None,
    483           output_buffer_size=None,
    484           num_parallel_calls=None):
    485     """Maps `map_func` across this dataset.
    486 
    487     Args:
    488       map_func: A function mapping a nested structure of tensors (having
    489         shapes and types defined by `self.output_shapes` and
    490        `self.output_types`) to another nested structure of tensors.
    491       num_threads: (Optional.) Deprecated, use `num_parallel_calls` instead.
    492       output_buffer_size: (Optional.) A `tf.int64` scalar `tf.Tensor`,
    493         representing the maximum number of processed elements that will be
    494         buffered.
    495       num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
    496         representing the number elements to process in parallel. If not
    497         specified, elements will be processed sequentially.
    498 
    499     Returns:
    500       A `Dataset`.
    501     """
    502     if num_threads is None and num_parallel_calls is None:
    503       ret = Dataset(dataset_ops.MapDataset(self._dataset, map_func))
    504     else:
    505       if num_threads is None:
    506         ret = Dataset(
    507             dataset_ops.ParallelMapDataset(self._dataset, map_func,
    508                                            num_parallel_calls))
    509       else:
    510         ret = Dataset(
    511             dataset_ops.ParallelMapDataset(self._dataset, map_func,
    512                                            num_threads))
    513     if output_buffer_size is not None:
    514       ret = ret.prefetch(output_buffer_size)
    515     return ret
    516 
    517   def flat_map(self, map_func):
    518     """Maps `map_func` across this dataset and flattens the result.
    519 
    520     Args:
    521       map_func: A function mapping a nested structure of tensors (having shapes
    522         and types defined by `self.output_shapes` and `self.output_types`) to a
    523         `Dataset`.
    524 
    525     Returns:
    526       A `Dataset`.
    527     """
    528     return Dataset(dataset_ops.FlatMapDataset(self._dataset, map_func))
    529 
    530   def interleave(self, map_func, cycle_length, block_length=1):
    531     """Maps `map_func` across this dataset, and interleaves the results.
    532 
    533     For example, you can use `Dataset.interleave()` to process many input files
    534     concurrently:
    535 
    536     ```python
    537     # Preprocess 4 files concurrently, and interleave blocks of 16 records from
    538     # each file.
    539     filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
    540     dataset = (Dataset.from_tensor_slices(filenames)
    541                .interleave(lambda x:
    542                    TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
    543                    cycle_length=4, block_length=16))
    544     ```
    545 
    546     The `cycle_length` and `block_length` arguments control the order in which
    547     elements are produced. `cycle_length` controls the number of input elements
    548     that are processed concurrently. If you set `cycle_length` to 1, this
    549     transformation will handle one input element at a time, and will produce
    550     identical results = to @{tf.data.Dataset.flat_map}. In general,
    551     this transformation will apply `map_func` to `cycle_length` input elements,
    552     open iterators on the returned `Dataset` objects, and cycle through them
    553     producing `block_length` consecutive elements from each iterator, and
    554     consuming the next input element each time it reaches the end of an
    555     iterator.
    556 
    557     For example:
    558 
    559     ```python
    560     # NOTE: The following examples use `{ ... }` to represent the
    561     # contents of a dataset.
    562     a = { 1, 2, 3, 4, 5 }
    563 
    564     # NOTE: New lines indicate "block" boundaries.
    565     a.interleave(lambda x: Dataset.from_tensors(x).repeat(6),
    566                  cycle_length=2, block_length=4) == {
    567         1, 1, 1, 1,
    568         2, 2, 2, 2,
    569         1, 1,
    570         2, 2,
    571         3, 3, 3, 3,
    572         4, 4, 4, 4,
    573         3, 3,
    574         4, 4,
    575         5, 5, 5, 5,
    576         5, 5,
    577     }
    578     ```
    579 
    580     NOTE: The order of elements yielded by this transformation is
    581     deterministic, as long as `map_func` is a pure function. If
    582     `map_func` contains any stateful operations, the order in which
    583     that state is accessed is undefined.
    584 
    585     Args:
    586       map_func: A function mapping a nested structure of tensors (having shapes
    587         and types defined by `self.output_shapes` and `self.output_types`) to a
    588         `Dataset`.
    589       cycle_length: The number of elements from this dataset that will be
    590         processed concurrently.
    591       block_length: The number of consecutive elements to produce from each
    592         input element before cycling to another input element.
    593 
    594     Returns:
    595       A `Dataset`.
    596     """
    597     return Dataset(
    598         dataset_ops.InterleaveDataset(self._dataset, map_func, cycle_length,
    599                                       block_length))
    600 
    601   @deprecation.deprecated(None, "Use `ds.apply(tf.contrib.data.unbatch())`.")
    602   def unbatch(self):
    603     """Deprecated: Use `Dataset.apply(tf.contrib.data.unbatch()`."""
    604 
    605     return self.apply(batching.unbatch())
    606 
    607   def filter(self, predicate):
    608     """Filters this dataset according to `predicate`.
    609 
    610     Args:
    611       predicate: A function mapping a nested structure of tensors (having shapes
    612         and types defined by `self.output_shapes` and `self.output_types`) to a
    613         scalar `tf.bool` tensor.
    614 
    615     Returns:
    616       A `Dataset`.
    617     """
    618     return Dataset(dataset_ops.FilterDataset(self._dataset, predicate))
    619 
    620   def apply(self, transformation_func):
    621     """Apply a transformation function to this dataset.
    622 
    623     `apply` enables chaining of custom `Dataset` transformations, which are
    624     represented as functions that take one `Dataset` argument and return a
    625     transformed `Dataset`.
    626 
    627     For example:
    628 
    629     ```
    630     dataset = (dataset.map(lambda x: x ** 2)
    631                .(group_by_window(key_func, reduce_func, window_size))
    632                .map(lambda x: x ** 3))
    633     ```
    634 
    635     Args:
    636       transformation_func: A function that takes one `Dataset` argument and
    637         returns a `Dataset`.
    638 
    639     Returns:
    640       The `Dataset` returned by applying `transformation_func` to this dataset.
    641     """
    642     dataset = transformation_func(self)
    643     if not isinstance(dataset, dataset_ops.Dataset):
    644       raise TypeError("`transformation_func` must return a Dataset.")
    645     return Dataset(dataset)
    646 
    647 
    648 def get_single_element(dataset):
    649   """Returns the single element in `dataset` as a nested structure of tensors.
    650 
    651   This function enables you to use a @{tf.data.Dataset} in a stateless
    652   "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}.
    653   This can be useful when your preprocessing transformations are expressed
    654   as a `Dataset`, and you want to use the transformation at serving time.
    655   For example:
    656 
    657   ```python
    658   input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])
    659 
    660   def preprocessing_fn(input_str):
    661     # ...
    662     return image, label
    663 
    664   dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
    665              .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
    666              .batch(BATCH_SIZE))
    667 
    668   image_batch, label_batch = tf.contrib.data.get_single_element(dataset)
    669   ```
    670 
    671   Args:
    672     dataset: A @{tf.data.Dataset} object containing a single element.
    673 
    674   Returns:
    675     A nested structure of @{tf.Tensor} objects, corresponding to the single
    676     element of `dataset`.
    677 
    678   Raises:
    679     TypeError: if `dataset` is not a `tf.data.Dataset` object.
    680     InvalidArgumentError (at runtime): if `dataset` does not contain exactly
    681       one element.
    682   """
    683   if not isinstance(dataset, dataset_ops.Dataset):
    684     raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
    685   return nest.pack_sequence_as(
    686       dataset.output_types,
    687       gen_dataset_ops.dataset_to_single_element(
    688           dataset._as_variant_tensor(),  # pylint: disable=protected-access
    689           output_types=nest.flatten(dataset.output_types),
    690           output_shapes=nest.flatten(dataset.output_shapes)))
    691