# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Datasets and Iterators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.ops import gen_dataset_ops


def get_single_element(dataset):
  """Extracts the one and only element of `dataset` as tensors.

  Use this to evaluate a @{tf.data.Dataset} as a stateless
  "tensor-in tensor-out" computation without building a @{tf.data.Iterator}.
  A typical use is reusing a `Dataset`-based preprocessing pipeline at
  serving time:

  ```python
  input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])

  def preprocessing_fn(input_str):
    # ...
    return image, label

  dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  image_batch, label_batch = tf.contrib.data.get_single_element(dataset)
  ```

  Args:
    dataset: A @{tf.data.Dataset} expected to hold exactly one element.

  Returns:
    The dataset's single element. It is a nested structure of @{tf.Tensor}
    objects with the same nesting as `dataset.output_types`.

  Raises:
    TypeError: if `dataset` is not an instance of `tf.data.Dataset`.
    InvalidArgumentError (at runtime): if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, dataset_ops.Dataset):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  # The nested type structure is used twice: once flattened to describe the
  # op's outputs, once to re-nest the op's flat list of result tensors.
  element_types = dataset.output_types
  variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
  flat_components = gen_dataset_ops.dataset_to_single_element(
      variant,
      output_types=nest.flatten(element_types),
      output_shapes=nest.flatten(dataset.output_shapes))
  return nest.pack_sequence_as(element_types, flat_components)