1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Python wrappers for Datasets and Iterators.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 from tensorflow.contrib.data.python.ops import batching 21 from tensorflow.contrib.data.python.ops import enumerate_ops 22 from tensorflow.contrib.data.python.ops import error_ops 23 from tensorflow.contrib.data.python.ops import grouping 24 from tensorflow.python.data.ops import dataset_ops 25 from tensorflow.python.data.util import nest 26 from tensorflow.python.ops import gen_dataset_ops 27 from tensorflow.python.ops import gen_io_ops 28 from tensorflow.python.util import deprecation 29 30 31 class Dataset(dataset_ops.Dataset): 32 """Represents a potentially large set of elements. 33 34 A `Dataset` can be used to represent an input pipeline as a 35 collection of elements (nested structures of tensors) and a "logical 36 plan" of transformations that act on those elements. 37 """ 38 39 def __init__(self, dataset): 40 super(Dataset, self).__init__() 41 self._dataset = dataset 42 43 @deprecation.deprecated(None, "Use `ds._as_variant_tensor()`.") 44 def make_dataset_resource(self): 45 return self._as_variant_tensor() 46 47 def _as_variant_tensor(self): 48 return self._dataset._as_variant_tensor() # pylint: disable=protected-access 49 50 @property 51 def output_classes(self): 52 return self._dataset.output_classes 53 54 @property 55 def output_shapes(self): 56 return self._dataset.output_shapes 57 58 @property 59 def output_types(self): 60 return self._dataset.output_types 61 62 @staticmethod 63 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensors()`.") 64 def from_tensors(tensors): 65 """Creates a `Dataset` with a single element, comprising the given tensors. 66 67 Args: 68 tensors: A nested structure of tensors. 69 70 Returns: 71 A `Dataset`. 72 """ 73 return Dataset(dataset_ops.TensorDataset(tensors)) 74 75 @staticmethod 76 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") 77 def from_tensor_slices(tensors): 78 """Creates a `Dataset` whose elements are slices of the given tensors. 79 80 Args: 81 tensors: A nested structure of tensors, each having the same size in the 82 0th dimension. 83 84 Returns: 85 A `Dataset`. 86 """ 87 return Dataset(dataset_ops.TensorSliceDataset(tensors)) 88 89 @staticmethod 90 @deprecation.deprecated(None, 91 "Use `tf.data.Dataset.from_sparse_tensor_slices()`.") 92 def from_sparse_tensor_slices(sparse_tensor): 93 """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. 94 95 Args: 96 sparse_tensor: A `tf.SparseTensor`. 97 98 Returns: 99 A `Dataset` of rank-(N-1) sparse tensors. 100 """ 101 return Dataset(dataset_ops.SparseTensorSliceDataset(sparse_tensor)) 102 103 @staticmethod 104 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_generator()`.") 105 def from_generator(generator, output_types, output_shapes=None): 106 """Creates a `Dataset` whose elements are generated by `generator`. 107 108 The `generator` argument must be a callable object that returns 109 an object that support the `iter()` protocol (e.g. a generator function). 110 The elements generated by `generator` must be compatible with the given 111 `output_types` and (optional) `output_shapes` arguments. 112 113 For example: 114 115 ```python 116 import itertools 117 118 def gen(): 119 for i in itertools.count(1): 120 yield (i, [1] * i) 121 122 ds = Dataset.from_generator( 123 gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None]))) 124 value = ds.make_one_shot_iterator().get_next() 125 126 sess.run(value) # (1, array([1])) 127 sess.run(value) # (2, array([1, 1])) 128 ``` 129 130 Args: 131 generator: A callable object that takes no arguments and returns an 132 object that supports the `iter()` protocol. 133 output_types: A nested structure of `tf.DType` objects corresponding to 134 each component of an element yielded by `generator`. 135 output_shapes: (Optional.) A nested structure of `tf.TensorShape` 136 objects corresponding to each component of an element yielded by 137 `generator`. 138 139 Returns: 140 A `Dataset`. 141 """ 142 return Dataset(dataset_ops.Dataset.from_generator( 143 generator, output_types, output_shapes)) 144 145 @staticmethod 146 @deprecation.deprecated(None, "Use `tf.data.Dataset.range()`.") 147 def range(*args): 148 """Creates a `Dataset` of a step-separated range of values. 149 150 For example: 151 152 ```python 153 Dataset.range(5) == [0, 1, 2, 3, 4] 154 Dataset.range(2, 5) == [2, 3, 4] 155 Dataset.range(1, 5, 2) == [1, 3] 156 Dataset.range(1, 5, -2) == [] 157 Dataset.range(5, 1) == [] 158 Dataset.range(5, 1, -2) == [5, 3] 159 ``` 160 161 Args: 162 *args: follow same semantics as python's xrange. 163 len(args) == 1 -> start = 0, stop = args[0], step = 1 164 len(args) == 2 -> start = args[0], stop = args[1], step = 1 165 len(args) == 3 -> start = args[0], stop = args[1, stop = args[2] 166 167 Returns: 168 A `RangeDataset`. 169 170 Raises: 171 ValueError: if len(args) == 0. 172 """ 173 return Dataset(dataset_ops.RangeDataset(*args)) 174 175 @staticmethod 176 @deprecation.deprecated(None, "Use `tf.data.Dataset.zip()`.") 177 def zip(datasets): 178 """Creates a `Dataset` by zipping together the given datasets. 179 180 This method has similar semantics to the built-in `zip()` function 181 in Python, with the main difference being that the `datasets` 182 argument can be an arbitrary nested structure of `Dataset` objects. 183 For example: 184 185 ```python 186 # NOTE: The following examples use `{ ... }` to represent the 187 # contents of a dataset. 188 a = { 1, 2, 3 } 189 b = { 4, 5, 6 } 190 c = { (7, 8), (9, 10), (11, 12) } 191 d = { 13, 14 } 192 193 # The nested structure of the `datasets` argument determines the 194 # structure of elements in the resulting dataset. 195 Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) } 196 Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) } 197 198 # The `datasets` argument may contain an arbitrary number of 199 # datasets. 200 Dataset.zip((a, b, c)) == { (1, 4, (7, 8)), 201 (2, 5, (9, 10)), 202 (3, 6, (11, 12)) } 203 204 # The number of elements in the resulting dataset is the same as 205 # the size of the smallest dataset in `datasets`. 206 Dataset.zip((a, d)) == { (1, 13), (2, 14) } 207 ``` 208 209 Args: 210 datasets: A nested structure of datasets. 211 212 Returns: 213 A `Dataset`. 214 """ 215 return Dataset(dataset_ops.ZipDataset(datasets)) 216 217 def concatenate(self, dataset): 218 """Creates a `Dataset` by concatenating given dataset with this dataset. 219 220 ```python 221 # NOTE: The following examples use `{ ... }` to represent the 222 # contents of a dataset. 223 a = { 1, 2, 3 } 224 b = { 4, 5, 6, 7 } 225 226 # Input dataset and dataset to be concatenated should have same 227 # nested structures and output types. 228 # c = { (8, 9), (10, 11), (12, 13) } 229 # d = { 14.0, 15.0, 16.0 } 230 # a.concatenate(c) and a.concatenate(d) would result in error. 231 232 a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 } 233 ``` 234 235 Args: 236 dataset: `Dataset` to be concatenated. 237 238 Returns: 239 A `Dataset`. 240 """ 241 return Dataset(dataset_ops.ConcatenateDataset(self._dataset, dataset)) 242 243 def prefetch(self, buffer_size): 244 """Creates a `Dataset` that prefetches elements from this dataset. 245 246 Args: 247 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the 248 maximum number elements that will be buffered when prefetching. 249 250 Returns: 251 A `Dataset`. 252 """ 253 return Dataset(dataset_ops.PrefetchDataset(self._dataset, buffer_size)) 254 255 @staticmethod 256 @deprecation.deprecated(None, "Use `tf.data.Dataset.list_files()`.") 257 def list_files(file_pattern): 258 """A dataset of all files matching a pattern. 259 260 Example: 261 If we had the following files on our filesystem: 262 - /path/to/dir/a.txt 263 - /path/to/dir/b.py 264 - /path/to/dir/c.py 265 If we pass "/path/to/dir/*.py" as the directory, the dataset would 266 produce: 267 - /path/to/dir/b.py 268 - /path/to/dir/c.py 269 270 Args: 271 file_pattern: A string or scalar string `tf.Tensor`, representing 272 the filename pattern that will be matched. 273 274 Returns: 275 A `Dataset` of strings corresponding to file names. 276 """ 277 return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern)) 278 279 def repeat(self, count=None): 280 """Repeats this dataset `count` times. 281 282 Args: 283 count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 284 number of times the elements of this dataset should be repeated. The 285 default behavior (if `count` is `None` or `-1`) is for the elements to 286 be repeated indefinitely. 287 288 Returns: 289 A `Dataset`. 290 """ 291 return Dataset(dataset_ops.RepeatDataset(self._dataset, count)) 292 293 @deprecation.deprecated( 294 None, "Use `ds.apply(tf.contrib.data.enumerate_dataset())`.") 295 def enumerate(self, start=0): 296 """Deprecated: Use `Dataset.apply(tf.contrib.data.enumerate_dataset(..)`.""" 297 298 return self.apply(enumerate_ops.enumerate_dataset(start)) 299 300 def shuffle(self, buffer_size, seed=None): 301 """Randomly shuffles the elements of this dataset. 302 303 Args: 304 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the 305 number of elements from this dataset from which the new 306 dataset will sample. 307 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 308 random seed that will be used to create the distribution. See 309 @{tf.set_random_seed} for behavior. 310 311 Returns: 312 A `Dataset`. 313 """ 314 return Dataset(dataset_ops.ShuffleDataset(self._dataset, buffer_size, seed)) 315 316 def cache(self, filename=""): 317 """Caches the elements in this dataset. 318 319 Args: 320 filename: A `tf.string` scalar `tf.Tensor`, representing the name of a 321 directory on the filesystem to use for caching tensors in this Dataset. 322 If a filename is not provided, the dataset will be cached in memory. 323 324 Returns: 325 A `Dataset`. 326 """ 327 return Dataset(dataset_ops.CacheDataset(self._dataset, filename)) 328 329 def take(self, count): 330 """Creates a `Dataset` with at most `count` elements from this dataset. 331 332 Args: 333 count: A `tf.int64` scalar `tf.Tensor`, representing the number of 334 elements of this dataset that should be taken to form the new dataset. 335 If `count` is -1, or if `count` is greater than the size of this 336 dataset, the new dataset will contain all elements of this dataset. 337 338 Returns: 339 A `Dataset`. 340 """ 341 return Dataset(dataset_ops.TakeDataset(self._dataset, count)) 342 343 def skip(self, count): 344 """Creates a `Dataset` that skips `count` elements from this dataset. 345 346 Args: 347 count: A `tf.int64` scalar `tf.Tensor`, representing the number 348 of elements of this dataset that should be skipped to form the 349 new dataset. If `count` is greater than the size of this 350 dataset, the new dataset will contain no elements. If `count` 351 is -1, skips the entire dataset. 352 353 Returns: 354 A `Dataset`. 355 """ 356 return Dataset(dataset_ops.SkipDataset(self._dataset, count)) 357 358 def shard(self, num_shards, index): 359 """Creates a `Dataset` that includes only 1/`num_shards` of this dataset. 360 361 This dataset operator is very useful when running distributed training, as 362 it allows each worker to read a unique subset. 363 364 When reading a single input file, you can skip elements as follows: 365 366 ```python 367 d = tf.data.TFRecordDataset(FLAGS.input_file) 368 d = d.shard(FLAGS.num_workers, FLAGS.worker_index) 369 d = d.repeat(FLAGS.num_epochs) 370 d = d.shuffle(FLAGS.shuffle_buffer_size) 371 d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) 372 ``` 373 374 Important caveats: 375 376 - Be sure to shard before you use any randomizing operator (such as 377 shuffle). 378 - Generally it is best if the shard operator is used early in the dataset 379 pipeline. For example, when reading from a set of TFRecord files, shard 380 before converting the dataset to input samples. This avoids reading every 381 file on every worker. The following is an example of an efficient 382 sharding strategy within a complete pipeline: 383 384 ```python 385 d = tf.data.Dataset.list_files(FLAGS.pattern) 386 d = d.shard(FLAGS.num_workers, FLAGS.worker_index) 387 d = d.repeat(FLAGS.num_epochs) 388 d = d.shuffle(FLAGS.shuffle_buffer_size) 389 d = d.interleave(tf.data.TFRecordDataset, 390 cycle_length=FLAGS.num_readers, block_length=1) 391 d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) 392 ``` 393 394 Args: 395 num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of 396 shards operating in parallel. 397 index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. 398 399 Returns: 400 A `Dataset`. 401 402 Raises: 403 ValueError: if `num_shards` or `index` are illegal values. Note: error 404 checking is done on a best-effort basis, and aren't guaranteed to be 405 caught upon dataset creation. (e.g. providing in a placeholder tensor 406 bypasses the early checking, and will instead result in an error during 407 a session.run call.) 408 """ 409 return Dataset(self._dataset.shard(num_shards, index)) 410 411 @deprecation.deprecated( 412 None, "Use `ds.apply(tf.contrib.data.ignore_errors())`.") 413 def ignore_errors(self): 414 """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors())`.""" 415 416 return self.apply(error_ops.ignore_errors()) 417 418 def batch(self, batch_size): 419 """Combines consecutive elements of this dataset into batches. 420 421 Args: 422 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 423 consecutive elements of this dataset to combine in a single batch. 424 425 Returns: 426 A `Dataset`. 427 """ 428 return Dataset(dataset_ops.BatchDataset(self._dataset, batch_size)) 429 430 def padded_batch(self, batch_size, padded_shapes, padding_values=None): 431 """Combines consecutive elements of this dataset into padded batches. 432 433 Like `Dataset.dense_to_sparse_batch()`, this method combines 434 multiple consecutive elements of this dataset, which might have 435 different shapes, into a single element. The tensors in the 436 resulting element have an additional outer dimension, and are 437 padded to the respective shape in `padded_shapes`. 438 439 Args: 440 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 441 consecutive elements of this dataset to combine in a single batch. 442 padded_shapes: A nested structure of `tf.TensorShape` or 443 `tf.int64` vector tensor-like objects representing the shape 444 to which the respective component of each input element should 445 be padded prior to batching. Any unknown dimensions 446 (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a 447 tensor-like object) will be padded to the maximum size of that 448 dimension in each batch. 449 padding_values: (Optional.) A nested structure of scalar-shaped 450 `tf.Tensor`, representing the padding values to use for the 451 respective components. Defaults are `0` for numeric types and 452 the empty string for string types. 453 454 Returns: 455 A `Dataset`. 456 """ 457 return Dataset( 458 dataset_ops.PaddedBatchDataset(self._dataset, batch_size, padded_shapes, 459 padding_values)) 460 461 @deprecation.deprecated( 462 None, "Use `ds.apply(tf.contrib.data.dense_to_sparse_batch())`.") 463 def dense_to_sparse_batch(self, batch_size, row_shape): 464 """Use: `Dataset.apply(tf.contrib.data.dense_to_sparse_batch(...))`.""" 465 466 return self.apply(batching.dense_to_sparse_batch(batch_size, row_shape)) 467 468 @deprecation.deprecated( 469 None, "Use `ds.apply(tf.contrib.data.group_by_window())`.") 470 def group_by_window(self, key_func, reduce_func, window_size): 471 """Deprecated: Use `Dataset.apply(tf.contrib.data.group_by_window(...))`.""" 472 473 return self.apply( 474 grouping.group_by_window(key_func, reduce_func, window_size)) 475 476 @deprecation.deprecated_args( 477 None, 478 "`output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.", 479 "num_threads", "output_buffer_size") 480 def map(self, 481 map_func, 482 num_threads=None, 483 output_buffer_size=None, 484 num_parallel_calls=None): 485 """Maps `map_func` across this dataset. 486 487 Args: 488 map_func: A function mapping a nested structure of tensors (having 489 shapes and types defined by `self.output_shapes` and 490 `self.output_types`) to another nested structure of tensors. 491 num_threads: (Optional.) Deprecated, use `num_parallel_calls` instead. 492 output_buffer_size: (Optional.) A `tf.int64` scalar `tf.Tensor`, 493 representing the maximum number of processed elements that will be 494 buffered. 495 num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, 496 representing the number elements to process in parallel. If not 497 specified, elements will be processed sequentially. 498 499 Returns: 500 A `Dataset`. 501 """ 502 if num_threads is None and num_parallel_calls is None: 503 ret = Dataset(dataset_ops.MapDataset(self._dataset, map_func)) 504 else: 505 if num_threads is None: 506 ret = Dataset( 507 dataset_ops.ParallelMapDataset(self._dataset, map_func, 508 num_parallel_calls)) 509 else: 510 ret = Dataset( 511 dataset_ops.ParallelMapDataset(self._dataset, map_func, 512 num_threads)) 513 if output_buffer_size is not None: 514 ret = ret.prefetch(output_buffer_size) 515 return ret 516 517 def flat_map(self, map_func): 518 """Maps `map_func` across this dataset and flattens the result. 519 520 Args: 521 map_func: A function mapping a nested structure of tensors (having shapes 522 and types defined by `self.output_shapes` and `self.output_types`) to a 523 `Dataset`. 524 525 Returns: 526 A `Dataset`. 527 """ 528 return Dataset(dataset_ops.FlatMapDataset(self._dataset, map_func)) 529 530 def interleave(self, map_func, cycle_length, block_length=1): 531 """Maps `map_func` across this dataset, and interleaves the results. 532 533 For example, you can use `Dataset.interleave()` to process many input files 534 concurrently: 535 536 ```python 537 # Preprocess 4 files concurrently, and interleave blocks of 16 records from 538 # each file. 539 filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...] 540 dataset = (Dataset.from_tensor_slices(filenames) 541 .interleave(lambda x: 542 TextLineDataset(x).map(parse_fn, num_parallel_calls=1), 543 cycle_length=4, block_length=16)) 544 ``` 545 546 The `cycle_length` and `block_length` arguments control the order in which 547 elements are produced. `cycle_length` controls the number of input elements 548 that are processed concurrently. If you set `cycle_length` to 1, this 549 transformation will handle one input element at a time, and will produce 550 identical results = to @{tf.data.Dataset.flat_map}. In general, 551 this transformation will apply `map_func` to `cycle_length` input elements, 552 open iterators on the returned `Dataset` objects, and cycle through them 553 producing `block_length` consecutive elements from each iterator, and 554 consuming the next input element each time it reaches the end of an 555 iterator. 556 557 For example: 558 559 ```python 560 # NOTE: The following examples use `{ ... }` to represent the 561 # contents of a dataset. 562 a = { 1, 2, 3, 4, 5 } 563 564 # NOTE: New lines indicate "block" boundaries. 565 a.interleave(lambda x: Dataset.from_tensors(x).repeat(6), 566 cycle_length=2, block_length=4) == { 567 1, 1, 1, 1, 568 2, 2, 2, 2, 569 1, 1, 570 2, 2, 571 3, 3, 3, 3, 572 4, 4, 4, 4, 573 3, 3, 574 4, 4, 575 5, 5, 5, 5, 576 5, 5, 577 } 578 ``` 579 580 NOTE: The order of elements yielded by this transformation is 581 deterministic, as long as `map_func` is a pure function. If 582 `map_func` contains any stateful operations, the order in which 583 that state is accessed is undefined. 584 585 Args: 586 map_func: A function mapping a nested structure of tensors (having shapes 587 and types defined by `self.output_shapes` and `self.output_types`) to a 588 `Dataset`. 589 cycle_length: The number of elements from this dataset that will be 590 processed concurrently. 591 block_length: The number of consecutive elements to produce from each 592 input element before cycling to another input element. 593 594 Returns: 595 A `Dataset`. 596 """ 597 return Dataset( 598 dataset_ops.InterleaveDataset(self._dataset, map_func, cycle_length, 599 block_length)) 600 601 @deprecation.deprecated(None, "Use `ds.apply(tf.contrib.data.unbatch())`.") 602 def unbatch(self): 603 """Deprecated: Use `Dataset.apply(tf.contrib.data.unbatch()`.""" 604 605 return self.apply(batching.unbatch()) 606 607 def filter(self, predicate): 608 """Filters this dataset according to `predicate`. 609 610 Args: 611 predicate: A function mapping a nested structure of tensors (having shapes 612 and types defined by `self.output_shapes` and `self.output_types`) to a 613 scalar `tf.bool` tensor. 614 615 Returns: 616 A `Dataset`. 617 """ 618 return Dataset(dataset_ops.FilterDataset(self._dataset, predicate)) 619 620 def apply(self, transformation_func): 621 """Apply a transformation function to this dataset. 622 623 `apply` enables chaining of custom `Dataset` transformations, which are 624 represented as functions that take one `Dataset` argument and return a 625 transformed `Dataset`. 626 627 For example: 628 629 ``` 630 dataset = (dataset.map(lambda x: x ** 2) 631 .(group_by_window(key_func, reduce_func, window_size)) 632 .map(lambda x: x ** 3)) 633 ``` 634 635 Args: 636 transformation_func: A function that takes one `Dataset` argument and 637 returns a `Dataset`. 638 639 Returns: 640 The `Dataset` returned by applying `transformation_func` to this dataset. 641 """ 642 dataset = transformation_func(self) 643 if not isinstance(dataset, dataset_ops.Dataset): 644 raise TypeError("`transformation_func` must return a Dataset.") 645 return Dataset(dataset) 646 647 648 def get_single_element(dataset): 649 """Returns the single element in `dataset` as a nested structure of tensors. 650 651 This function enables you to use a @{tf.data.Dataset} in a stateless 652 "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. 653 This can be useful when your preprocessing transformations are expressed 654 as a `Dataset`, and you want to use the transformation at serving time. 655 For example: 656 657 ```python 658 input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) 659 660 def preprocessing_fn(input_str): 661 # ... 662 return image, label 663 664 dataset = (tf.data.Dataset.from_tensor_slices(input_batch) 665 .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) 666 .batch(BATCH_SIZE)) 667 668 image_batch, label_batch = tf.contrib.data.get_single_element(dataset) 669 ``` 670 671 Args: 672 dataset: A @{tf.data.Dataset} object containing a single element. 673 674 Returns: 675 A nested structure of @{tf.Tensor} objects, corresponding to the single 676 element of `dataset`. 677 678 Raises: 679 TypeError: if `dataset` is not a `tf.data.Dataset` object. 680 InvalidArgumentError (at runtime): if `dataset` does not contain exactly 681 one element. 682 """ 683 if not isinstance(dataset, dataset_ops.Dataset): 684 raise TypeError("`dataset` must be a `tf.data.Dataset` object.") 685 return nest.pack_sequence_as( 686 dataset.output_types, 687 gen_dataset_ops.dataset_to_single_element( 688 dataset._as_variant_tensor(), # pylint: disable=protected-access 689 output_types=nest.flatten(dataset.output_types), 690 output_shapes=nest.flatten(dataset.output_shapes))) 691