Home | History | Annotate | Download | only in python
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Iteration over tf.data.Datasets when eager execution is enabled."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import threading
     22 
     23 from tensorflow.contrib.data.python.ops import prefetching_ops
     24 from tensorflow.python.data.ops import iterator_ops
     25 from tensorflow.python.data.util import nest
     26 from tensorflow.python.data.util import sparse
     27 from tensorflow.python.eager import context
     28 from tensorflow.python.framework import constant_op
     29 from tensorflow.python.framework import dtypes
     30 from tensorflow.python.framework import errors
     31 from tensorflow.python.framework import function
     32 from tensorflow.python.framework import ops
     33 from tensorflow.python.ops import gen_dataset_ops
     34 from tensorflow.python.ops import resource_variable_ops
     35 
     36 _uid_counter = 0
     37 _uid_lock = threading.Lock()
     38 
     39 
     40 def _generate_shared_name(prefix):
     41   with _uid_lock:
     42     global _uid_counter
     43     uid = _uid_counter
     44     _uid_counter += 1
     45   return "{}{}".format(prefix, uid)
     46 
     47 
     48 class Iterator(object):
     49   """An iterator producing tf.Tensor objects from a tf.data.Dataset."""
     50 
     51   def __init__(self, dataset):
     52     """Creates a new iterator over the given dataset.
     53 
     54     For example:
     55     ```python
     56     dataset = tf.data.Dataset.range(4)
     57     for x in Iterator(dataset):
     58       print(x)
     59     ```
     60 
     61     Tensors produced will be placed on the device on which this iterator object
     62     was created.
     63 
     64     Args:
     65       dataset: A `tf.data.Dataset` object.
     66 
     67     Raises:
     68       RuntimeError: When invoked without eager execution enabled.
     69     """
     70 
     71     if not context.in_eager_mode():
     72       raise RuntimeError(
     73           "{} objects can only be used when eager execution is enabled, use "
     74           "tf.data.Dataset.make_iterator or "
     75           "tf.data.Dataset.make_one_shot_iterator for graph construction".
     76           format(type(self)))
     77     with ops.device("/device:CPU:0"):
     78       ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
     79       self._output_classes = dataset.output_classes
     80       self._output_types = dataset.output_types
     81       self._output_shapes = dataset.output_shapes
     82       self._flat_output_types = nest.flatten(
     83           sparse.as_dense_types(self._output_types, self._output_classes))
     84       self._flat_output_shapes = nest.flatten(
     85           sparse.as_dense_shapes(self._output_shapes, self._output_classes))
     86       self._resource = gen_dataset_ops.iterator(
     87           shared_name="",
     88           container=_generate_shared_name("eageriterator"),
     89           output_types=self._flat_output_types,
     90           output_shapes=self._flat_output_shapes)
     91       gen_dataset_ops.make_iterator(ds_variant, self._resource)
     92       # Delete the resource when this object is deleted
     93       self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
     94           handle=self._resource, handle_device="/device:CPU:0")
     95     self._device = context.context().device_name
     96     self._buffer_resource_handle = None
     97     if not context.context().device_spec.device_type:
     98       is_remote_device = False
     99     else:
    100       is_remote_device = context.context().device_spec.device_type != "CPU"
    101     if is_remote_device:
    102       with ops.device("/device:CPU:0"):
    103         iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
    104             self._resource)
    105 
    106         @function.Defun(dtypes.string)
    107         def remote_fn(h):
    108           remote_iterator = iterator_ops.Iterator.from_string_handle(
    109               h, self._output_types, self._output_shapes)
    110           return remote_iterator.get_next()
    111 
    112         remote_fn.add_to_graph(None)
    113         target = constant_op.constant("/device:CPU:0")
    114       with ops.device(self._device):
    115         self._buffer_resource_handle = prefetching_ops.function_buffering_resource(  # pylint: disable=line-too-long
    116             string_arg=iter_string_handle,
    117             f=remote_fn,
    118             target_device=target,
    119             buffer_size=10,
    120             thread_pool_size=1,
    121             container="",
    122             shared_name=_generate_shared_name("function_buffer_resource"))
    123         self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(  # pylint: disable=line-too-long
    124             handle=self._buffer_resource_handle,
    125             handle_device=self._device)
    126 
    127   def __iter__(self):
    128     return self
    129 
    130   def __next__(self):  # For Python 3 compatibility
    131     return self.next()
    132 
    133   def _next_internal(self):
    134     """Returns a nested structure of `tf.Tensor`s containing the next element.
    135     """
    136     with ops.device(self._device):
    137       if self._buffer_resource_handle is not None:
    138         ret = prefetching_ops.function_buffering_resource_get_next(
    139             function_buffer_resource=self._buffer_resource_handle,
    140             output_types=self._flat_output_types)
    141       else:
    142         # TODO(ashankar): Consider removing this ops.device() contextmanager
    143         # and instead mimic ops placement in graphs: Operations on resource
    144         # handles execute on the same device as where the resource is placed.
    145         # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
    146         # because in eager mode this code will run synchronously on the calling
    147         # thread. Therefore we do not need to make a defensive context switch
    148         # to a background thread, and can achieve a small constant performance
    149         # boost by invoking the iterator synchronously.
    150         ret = gen_dataset_ops.iterator_get_next_sync(
    151             self._resource,
    152             output_types=self._flat_output_types,
    153             output_shapes=self._flat_output_shapes)
    154 
    155     return sparse.deserialize_sparse_tensors(
    156         nest.pack_sequence_as(self._output_types, ret), self._output_types,
    157         self._output_shapes, self._output_classes)
    158 
    159   def next(self):
    160     """Returns a nested structure of `tf.Tensor`s containing the next element.
    161     """
    162     try:
    163       return self._next_internal()
    164     except errors.OutOfRangeError:
    165       raise StopIteration
    166 
    167   @property
    168   def output_classes(self):
    169     """Returns the class of each component of an element of this iterator.
    170 
    171     The expected values are `tf.Tensor` and `tf.SparseTensor`.
    172 
    173     Returns:
    174       A nested structure of Python `type` objects corresponding to each
    175       component of an element of this dataset.
    176     """
    177     return self._output_classes
    178 
    179   @property
    180   def output_shapes(self):
    181     """Returns the shape of each component of an element of this iterator.
    182 
    183     Returns:
    184       A nested structure of `tf.TensorShape` objects corresponding to each
    185       component of an element of this dataset.
    186     """
    187     return self._output_shapes
    188 
    189   @property
    190   def output_types(self):
    191     """Returns the type of each component of an element of this iterator.
    192 
    193     Returns:
    194       A nested structure of `tf.DType` objects corresponding to each component
    195       of an element of this dataset.
    196     """
    197     return self._output_types
    198 
    199   def get_next(self, name=None):
    200     """Returns a nested structure of `tf.Tensor`s containing the next element.
    201 
    202     Args:
    203       name: (Optional.) A name for the created operation. Currently unused.
    204 
    205     Returns:
    206       A nested structure of `tf.Tensor` objects.
    207 
    208     Raises:
    209       `tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
    210     """
    211     del name
    212     return self._next_internal()
    213