      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 #==============================================================================
     15 """Data Flow Operations."""
     16 # pylint: disable=g-bad-name
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import hashlib
     23 import threading
     24 
     25 import six
     26 
     27 from tensorflow.python.eager import context
     28 from tensorflow.python.framework import dtypes as _dtypes
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.framework import random_seed
     31 from tensorflow.python.framework import tensor_shape
     32 from tensorflow.python.framework import tensor_util
     33 from tensorflow.python.lib.io import python_io
     34 from tensorflow.python.ops import array_ops
     35 from tensorflow.python.ops import control_flow_ops
     36 from tensorflow.python.ops import gen_data_flow_ops
     37 from tensorflow.python.ops import math_ops
     38 # go/tf-wildcard-import
     39 # pylint: disable=wildcard-import
     40 from tensorflow.python.ops.gen_data_flow_ops import *
     41 from tensorflow.python.util.tf_export import tf_export
     42 
     43 # pylint: enable=wildcard-import
     44 
     45 
     46 def _as_type_list(dtypes):
     47   """Convert dtypes to a list of types."""
     48   assert dtypes is not None
     49   if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)):
     50     # We have a single type.
     51     return [dtypes]
     52   else:
     53     # We have a list or tuple of types.
     54     return list(dtypes)
     55 
     56 
     57 def _as_shape_list(shapes,
     58                    dtypes,
     59                    unknown_dim_allowed=False,
     60                    unknown_rank_allowed=False):
     61   """Convert shapes to a list of tuples of int (or None)."""
     62   del dtypes
     63   if unknown_dim_allowed:
     64     if (not isinstance(shapes, collections.Sequence) or not shapes or
     65         any(shape is None or isinstance(shape, int) for shape in shapes)):
     66       raise ValueError(
     67           "When providing partial shapes, a list of shapes must be provided.")
     68   if shapes is None:
     69     return None
     70   if isinstance(shapes, tensor_shape.TensorShape):
     71     shapes = [shapes]
     72   if not isinstance(shapes, (tuple, list)):
     73     raise TypeError(
     74         "shapes must be a TensorShape or a list or tuple of TensorShapes.")
     75   if all(shape is None or isinstance(shape, int) for shape in shapes):
     76     # We have a single shape.
     77     shapes = [shapes]
     78   shapes = [tensor_shape.as_shape(shape) for shape in shapes]
     79   if not unknown_dim_allowed:
     80     if any([not shape.is_fully_defined() for shape in shapes]):
     81       raise ValueError("All shapes must be fully defined: %s" % shapes)
     82   if not unknown_rank_allowed:
     83     if any([shape.dims is None for shape in shapes]):
     84       raise ValueError("All shapes must have a defined rank: %s" % shapes)
     85 
     86   return shapes
     87 
     88 
     89 def _as_name_list(names, dtypes):
     90   if names is None:
     91     return None
     92   if not isinstance(names, (list, tuple)):
     93     names = [names]
     94   if len(names) != len(dtypes):
     95     raise ValueError("List of names must have the same length as the list "
     96                      "of dtypes")
     97   return list(names)
     98 
     99 
    100 def _shape_common(s1, s2):
    101   """The greatest lower bound (ordered by specificity) TensorShape."""
    102   s1 = tensor_shape.TensorShape(s1)
    103   s2 = tensor_shape.TensorShape(s2)
    104   if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims:
    105     return tensor_shape.unknown_shape()
    106   d = [
    107       d1 if d1 is not None and d1 == d2 else None
    108       for (d1, d2) in zip(s1.as_list(), s2.as_list())
    109   ]
    110   return tensor_shape.TensorShape(d)
    111 
    112 
    113 # pylint: disable=protected-access
    114 @tf_export("QueueBase")
    115 class QueueBase(object):
    116   """Base class for queue implementations.
    117 
    118   A queue is a TensorFlow data structure that stores tensors across
    119   multiple steps, and exposes operations that enqueue and dequeue
    120   tensors.
    121 
    122   Each queue element is a tuple of one or more tensors, where each
    123   tuple component has a static dtype, and may have a static shape. The
    124   queue implementations support versions of enqueue and dequeue that
     125   handle single elements, as well as versions that enqueue and
     126   dequeue a batch of elements at once.
    127 
    128   See @{tf.FIFOQueue} and
    129   @{tf.RandomShuffleQueue} for concrete
    130   implementations of this class, and instructions on how to create
    131   them.
    132 
    133   @compatibility(eager)
    134   Queues are not compatible with eager execution. Instead, please
    135   use `tf.data` to get data into your model.
    136   @end_compatibility
    137   """
    138 
    139   def __init__(self, dtypes, shapes, names, queue_ref):
    140     """Constructs a queue object from a queue reference.
    141 
    142     The two optional lists, `shapes` and `names`, must be of the same length
    143     as `dtypes` if provided.  The values at a given index `i` indicate the
    144     shape and name to use for the corresponding queue component in `dtypes`.
    145 
    146     Args:
    147       dtypes:  A list of types.  The length of dtypes must equal the number
    148         of tensors in each element.
    149       shapes: Constraints on the shapes of tensors in an element:
    150         A list of shape tuples or None. This list is the same length
     151         as dtypes.  If the shapes of any tensors in the element are constrained,
    152         all must be; shapes can be None if the shapes should not be constrained.
    153       names: Optional list of names.  If provided, the `enqueue()` and
    154         `dequeue()` methods will use dictionaries with these names as keys.
    155         Must be None or a list or tuple of the same length as `dtypes`.
    156       queue_ref: The queue reference, i.e. the output of the queue op.
    157 
    158     Raises:
    159       ValueError: If one of the arguments is invalid.
    160       RuntimeError: If eager execution is enabled.
    161     """
    162     if context.in_eager_mode():
    163       raise RuntimeError(
    164           "Queues are not supported when eager execution is enabled. "
    165           "Instead, please use tf.data to get data into your model.")
    166     self._dtypes = dtypes
    167     if shapes is not None:
    168       if len(shapes) != len(dtypes):
    169         raise ValueError("Queue shapes must have the same length as dtypes")
    170       self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
    171     else:
    172       self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
    173     if names is not None:
    174       if len(names) != len(dtypes):
    175         raise ValueError("Queue names must have the same length as dtypes")
    176       self._names = names
    177     else:
    178       self._names = None
    179     self._queue_ref = queue_ref
    180     if context.in_graph_mode():
    181       self._name = self._queue_ref.op.name.split("/")[-1]
    182     else:
    183       self._name = context.context().scope_name
    184 
    185   @staticmethod
    186   def from_list(index, queues):
    187     """Create a queue using the queue reference from `queues[index]`.
    188 
    189     Args:
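             For example, a sketch (illustrative only; graph mode) of selecting
             between two queues with a scalar index tensor:

               q0 = tf.FIFOQueue(10, tf.float32)
               q1 = tf.FIFOQueue(10, tf.float32)
               which = tf.placeholder(tf.int32, shape=[])
               q = tf.QueueBase.from_list(which, [q0, q1])
               value = q.dequeue()  # Reads from `q0` or `q1` depending on `which`.
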
    190       index: An integer scalar tensor that determines the input that gets
    191         selected.
    192       queues: A list of `QueueBase` objects.
    193 
    194     Returns:
    195       A `QueueBase` object.
    196 
    197     Raises:
    198       TypeError: When `queues` is not a list of `QueueBase` objects,
    199         or when the data types of `queues` are not all the same.
    200     """
    201     if ((not queues) or (not isinstance(queues, list)) or
    202         (not all(isinstance(x, QueueBase) for x in queues))):
    203       raise TypeError("A list of queues expected")
    204 
    205     dtypes = queues[0].dtypes
    206     if not all([dtypes == q.dtypes for q in queues[1:]]):
    207       raise TypeError("Queues do not have matching component dtypes.")
    208 
    209     names = queues[0].names
    210     if not all([names == q.names for q in queues[1:]]):
    211       raise TypeError("Queues do not have matching component names.")
    212 
    213     queue_shapes = [q.shapes for q in queues]
    214     reduced_shapes = [
    215         six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)
    216     ]
    217 
    218     queue_refs = array_ops.stack([x.queue_ref for x in queues])
    219     selected_queue = array_ops.gather(queue_refs, index)
    220     return QueueBase(
    221         dtypes=dtypes,
    222         shapes=reduced_shapes,
    223         names=names,
    224         queue_ref=selected_queue)
    225 
    226   @property
    227   def queue_ref(self):
    228     """The underlying queue reference."""
    229     return self._queue_ref
    230 
    231   @property
    232   def name(self):
    233     """The name of the underlying queue."""
    234     if context.in_graph_mode():
    235       return self._queue_ref.op.name
    236     return self._name
    237 
    238   @property
    239   def dtypes(self):
    240     """The list of dtypes for each component of a queue element."""
    241     return self._dtypes
    242 
    243   @property
    244   def shapes(self):
    245     """The list of shapes for each component of a queue element."""
    246     return self._shapes
    247 
    248   @property
    249   def names(self):
    250     """The list of names for each component of a queue element."""
    251     return self._names
    252 
    253   def _check_enqueue_dtypes(self, vals):
    254     """Validate and convert `vals` to a list of `Tensor`s.
    255 
    256     The `vals` argument can be a Tensor, a list or tuple of tensors, or a
    257     dictionary with tensor values.
    258 
    259     If it is a dictionary, the queue must have been constructed with a
    260     `names` attribute and the dictionary keys must match the queue names.
    261     If the queue was constructed with a `names` attribute, `vals` must
    262     be a dictionary.
    263 
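             For example, a sketch (illustrative only) of a dictionary enqueue
             on a queue constructed with names:

               q = tf.FIFOQueue(10, dtypes=[tf.int32, tf.string],
                                names=["age", "name"])
               enqueue_op = q.enqueue({"age": 21, "name": "Jane"})
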
    264     Args:
     265       vals: A tensor, a list or tuple of tensors, or a dictionary.
    266 
    267     Returns:
    268       A list of `Tensor` objects.
    269 
    270     Raises:
    271       ValueError: If `vals` is invalid.
    272     """
    273     if isinstance(vals, dict):
    274       if not self._names:
    275         raise ValueError("Queue must have names to enqueue a dictionary")
    276       if sorted(self._names, key=str) != sorted(vals.keys(), key=str):
    277         raise ValueError("Keys in dictionary to enqueue do not match "
    278                          "names of Queue.  Dictionary: (%s), Queue: (%s)" %
    279                          (sorted(vals.keys()), sorted(self._names)))
    280       # The order of values in `self._names` indicates the order in which the
    281       # tensors in the dictionary `vals` must be listed.
    282       vals = [vals[k] for k in self._names]
    283     else:
    284       if self._names:
    285         raise ValueError("You must enqueue a dictionary in a Queue with names")
    286       if not isinstance(vals, (list, tuple)):
    287         vals = [vals]
    288 
    289     tensors = []
    290     for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
    291       tensors.append(
    292           ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
    293 
    294     return tensors
    295 
    296   def _scope_vals(self, vals):
    297     """Return a list of values to pass to `name_scope()`.
    298 
    299     Args:
    300       vals: A tensor, a list or tuple of tensors, or a dictionary.
    301 
    302     Returns:
    303       The values in vals as a list.
    304     """
    305     if isinstance(vals, (list, tuple)):
    306       return vals
    307     elif isinstance(vals, dict):
    308       return vals.values()
    309     else:
    310       return [vals]
    311 
    312   def enqueue(self, vals, name=None):
    313     """Enqueues one element to this queue.
    314 
    315     If the queue is full when this operation executes, it will block
    316     until the element has been enqueued.
    317 
    318     At runtime, this operation may raise an error if the queue is
     319     closed (see @{tf.QueueBase.close}) before or during its execution. If the
    320     queue is closed before this operation runs,
    321     `tf.errors.CancelledError` will be raised. If this operation is
    322     blocked, and either (i) the queue is closed by a close operation
    323     with `cancel_pending_enqueues=True`, or (ii) the session is
     324     closed (see @{tf.Session.close}),
    325     `tf.errors.CancelledError` will be raised.
    326 
    327     Args:
    328       vals: A tensor, a list or tuple of tensors, or a dictionary containing
    329         the values to enqueue.
    330       name: A name for the operation (optional).
    331 
    332     Returns:
    333       The operation that enqueues a new tuple of tensors to the queue.
    334     """
    335     with ops.name_scope(name, "%s_enqueue" % self._name,
    336                         self._scope_vals(vals)) as scope:
    337       vals = self._check_enqueue_dtypes(vals)
    338 
    339       # NOTE(mrry): Not using a shape function because we need access to
    340       # the `QueueBase` object.
    341       for val, shape in zip(vals, self._shapes):
    342         val.get_shape().assert_is_compatible_with(shape)
    343 
    344       if self._queue_ref.dtype == _dtypes.resource:
    345         return gen_data_flow_ops._queue_enqueue_v2(
    346             self._queue_ref, vals, name=scope)
    347       else:
    348         return gen_data_flow_ops._queue_enqueue(
    349             self._queue_ref, vals, name=scope)
    350 
    351   def enqueue_many(self, vals, name=None):
    352     """Enqueues zero or more elements to this queue.
    353 
    354     This operation slices each component tensor along the 0th dimension to
    355     make multiple queue elements. All of the tensors in `vals` must have the
    356     same size in the 0th dimension.
    357 
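             For example, a sketch (illustrative only) of enqueuing three
             elements at once into a single-component queue:

               q = tf.FIFOQueue(10, tf.float32)
               enqueue_op = q.enqueue_many([[1.0, 2.0, 3.0]])
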
    358     If the queue is full when this operation executes, it will block
    359     until all of the elements have been enqueued.
    360 
    361     At runtime, this operation may raise an error if the queue is
     362     closed (see @{tf.QueueBase.close}) before or during its execution. If the
    363     queue is closed before this operation runs,
    364     `tf.errors.CancelledError` will be raised. If this operation is
    365     blocked, and either (i) the queue is closed by a close operation
    366     with `cancel_pending_enqueues=True`, or (ii) the session is
     367     closed (see @{tf.Session.close}),
    368     `tf.errors.CancelledError` will be raised.
    369 
    370     Args:
    371       vals: A tensor, a list or tuple of tensors, or a dictionary
    372         from which the queue elements are taken.
    373       name: A name for the operation (optional).
    374 
    375     Returns:
    376       The operation that enqueues a batch of tuples of tensors to the queue.
    377     """
    378     with ops.name_scope(name, "%s_EnqueueMany" % self._name,
    379                         self._scope_vals(vals)) as scope:
    380       vals = self._check_enqueue_dtypes(vals)
    381 
    382       # NOTE(mrry): Not using a shape function because we need access to
    383       # the `QueueBase` object.
    384       batch_dim = vals[0].get_shape().with_rank_at_least(1)[0]
    385       for val, shape in zip(vals, self._shapes):
    386         batch_dim = batch_dim.merge_with(
    387             val.get_shape().with_rank_at_least(1)[0])
    388         val.get_shape()[1:].assert_is_compatible_with(shape)
    389 
    390       return gen_data_flow_ops._queue_enqueue_many_v2(
    391           self._queue_ref, vals, name=scope)
    392 
    393   def _dequeue_return_value(self, tensors):
    394     """Return the value to return from a dequeue op.
    395 
    396     If the queue has names, return a dictionary with the
    397     names as keys.  Otherwise return either a single tensor
    398     or a list of tensors depending on the length of `tensors`.
    399 
    400     Args:
    401       tensors: List of tensors from the dequeue op.
    402 
    403     Returns:
    404       A single tensor, a list of tensors, or a dictionary
    405       of tensors.
    406     """
    407     if self._names:
    408       # The returned values in `tensors` are in the same order as
    409       # the names in `self._names`.
    410       return {n: tensors[i] for i, n in enumerate(self._names)}
    411     elif len(tensors) == 1:
    412       return tensors[0]
    413     else:
    414       return tensors
    415 
    416   def dequeue(self, name=None):
    417     """Dequeues one element from this queue.
    418 
    419     If the queue is empty when this operation executes, it will block
    420     until there is an element to dequeue.
    421 
    422     At runtime, this operation may raise an error if the queue is
     423     closed (see @{tf.QueueBase.close}) before or during its execution. If the
    424     queue is closed, the queue is empty, and there are no pending
    425     enqueue operations that can fulfill this request,
    426     `tf.errors.OutOfRangeError` will be raised. If the session is
     427     closed (see @{tf.Session.close}),
    428     `tf.errors.CancelledError` will be raised.
    429 
    430     Args:
    431       name: A name for the operation (optional).
    432 
    433     Returns:
    434       The tuple of tensors that was dequeued.
    435     """
    436     if name is None:
    437       name = "%s_Dequeue" % self._name
    438     if self._queue_ref.dtype == _dtypes.resource:
    439       ret = gen_data_flow_ops._queue_dequeue_v2(
    440           self._queue_ref, self._dtypes, name=name)
    441     else:
    442       ret = gen_data_flow_ops._queue_dequeue(
    443           self._queue_ref, self._dtypes, name=name)
    444 
    445     # NOTE(mrry): Not using a shape function because we need access to
    446     # the `QueueBase` object.
    447     if context.in_graph_mode():
    448       op = ret[0].op
    449       for output, shape in zip(op.values(), self._shapes):
    450         output.set_shape(shape)
    451 
    452     return self._dequeue_return_value(ret)
    453 
    454   def dequeue_many(self, n, name=None):
    455     """Dequeues and concatenates `n` elements from this queue.
    456 
    457     This operation concatenates queue-element component tensors along
    458     the 0th dimension to make a single component tensor.  All of the
    459     components in the dequeued tuple will have size `n` in the 0th dimension.
    460 
     461     If the queue is closed and there are fewer than `n` elements left, then an
    462     `OutOfRange` exception is raised.
    463 
    464     At runtime, this operation may raise an error if the queue is
     465     closed (see @{tf.QueueBase.close}) before or during its execution. If the
    466     queue is closed, the queue contains fewer than `n` elements, and
    467     there are no pending enqueue operations that can fulfill this
    468     request, `tf.errors.OutOfRangeError` will be raised. If the
     469     session is closed (see @{tf.Session.close}),
    470     `tf.errors.CancelledError` will be raised.
    471 
    472     Args:
    473       n: A scalar `Tensor` containing the number of elements to dequeue.
    474       name: A name for the operation (optional).
    475 
    476     Returns:
    477       The list of concatenated tensors that was dequeued.
    478     """
    479     if name is None:
    480       name = "%s_DequeueMany" % self._name
    481 
    482     ret = gen_data_flow_ops._queue_dequeue_many_v2(
    483         self._queue_ref, n=n, component_types=self._dtypes, name=name)
    484 
    485     # NOTE(mrry): Not using a shape function because we need access to
    486     # the Queue object.
    487     if context.in_graph_mode():
    488       op = ret[0].op
    489       batch_dim = tensor_shape.Dimension(
    490           tensor_util.constant_value(op.inputs[1]))
    491       for output, shape in zip(op.values(), self._shapes):
    492         output.set_shape(
    493             tensor_shape.TensorShape([batch_dim]).concatenate(shape))
    494 
    495     return self._dequeue_return_value(ret)
    496 
    497   def dequeue_up_to(self, n, name=None):
    498     """Dequeues and concatenates `n` elements from this queue.
    499 
     500     **Note**: This operation is not supported by all queues.  If a queue does not
    501     support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.
    502 
    503     This operation concatenates queue-element component tensors along
    504     the 0th dimension to make a single component tensor. If the queue
    505     has not been closed, all of the components in the dequeued tuple
    506     will have size `n` in the 0th dimension.
    507 
    508     If the queue is closed and there are more than `0` but fewer than
    509     `n` elements remaining, then instead of raising a
    510     `tf.errors.OutOfRangeError` like @{tf.QueueBase.dequeue_many},
     511     fewer than `n` elements are returned immediately.  If the queue is
    512     closed and there are `0` elements left in the queue, then a
    513     `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
    514     Otherwise the behavior is identical to `dequeue_many`.
    515 
    516     Args:
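             For example, a sketch (illustrative only) of draining a closed
             queue in batches, where the final batch may hold fewer than `n`
             elements:

               q = tf.FIFOQueue(10, tf.float32, shapes=[[]])
               enqueue_op = q.enqueue_many([[1.0, 2.0, 3.0]])
               close_op = q.close()
               batch = q.dequeue_up_to(2)

               with tf.Session() as sess:
                 sess.run(enqueue_op)
                 sess.run(close_op)
                 print(sess.run(batch))  # ==> [1., 2.]
                 print(sess.run(batch))  # ==> [3.]  (partial final batch)
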
    517       n: A scalar `Tensor` containing the number of elements to dequeue.
    518       name: A name for the operation (optional).
    519 
    520     Returns:
    521       The tuple of concatenated tensors that was dequeued.
    522     """
    523     if name is None:
    524       name = "%s_DequeueUpTo" % self._name
    525 
    526     ret = gen_data_flow_ops._queue_dequeue_up_to_v2(
    527         self._queue_ref, n=n, component_types=self._dtypes, name=name)
    528 
    529     # NOTE(mrry): Not using a shape function because we need access to
    530     # the Queue object.
    531     if context.in_graph_mode():
    532       op = ret[0].op
    533       for output, shape in zip(op.values(), self._shapes):
    534         output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
    535 
    536     return self._dequeue_return_value(ret)
    537 
    538   def close(self, cancel_pending_enqueues=False, name=None):
    539     """Closes this queue.
    540 
    541     This operation signals that no more elements will be enqueued in
    542     the given queue. Subsequent `enqueue` and `enqueue_many`
    543     operations will fail. Subsequent `dequeue` and `dequeue_many`
    544     operations will continue to succeed if sufficient elements remain
     545     in the queue. Subsequent `dequeue` and `dequeue_many` operations
    546     that would otherwise block waiting for more elements (if close
    547     hadn't been called) will now fail immediately.
    548 
    549     If `cancel_pending_enqueues` is `True`, all pending requests will also
    550     be canceled.
    551 
    552     Args:
    553       cancel_pending_enqueues: (Optional.) A boolean, defaulting to
    554         `False` (described above).
    555       name: A name for the operation (optional).
    556 
    557     Returns:
    558       The operation that closes the queue.
    559     """
    560     if name is None:
    561       name = "%s_Close" % self._name
    562     if self._queue_ref.dtype == _dtypes.resource:
    563       return gen_data_flow_ops._queue_close_v2(
    564           self._queue_ref,
    565           cancel_pending_enqueues=cancel_pending_enqueues,
    566           name=name)
    567     else:
    568       return gen_data_flow_ops._queue_close(
    569           self._queue_ref,
    570           cancel_pending_enqueues=cancel_pending_enqueues,
    571           name=name)
    572 
    573   def is_closed(self, name=None):
     574     """Returns true if the queue is closed.
    575 
    576     This operation returns true if the queue is closed and false if the queue
    577     is open.
    578 
    579     Args:
    580       name: A name for the operation (optional).
    581 
    582     Returns:
    583       True if the queue is closed and false if the queue is open.
    584     """
    585     if name is None:
    586       name = "%s_Is_Closed" % self._name
    587     if self._queue_ref.dtype == _dtypes.resource:
    588       return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name)
    589     else:
    590       return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name)
    591 
    592   def size(self, name=None):
    593     """Compute the number of elements in this queue.
    594 
    595     Args:
    596       name: A name for the operation (optional).
    597 
    598     Returns:
    599       A scalar tensor containing the number of elements in this queue.
    600     """
    601     if name is None:
    602       name = "%s_Size" % self._name
    603     if self._queue_ref.dtype == _dtypes.resource:
    604       return gen_data_flow_ops._queue_size_v2(self._queue_ref, name=name)
    605     else:
    606       return gen_data_flow_ops._queue_size(self._queue_ref, name=name)
    607 
    608 
    609 @tf_export("RandomShuffleQueue")
    610 class RandomShuffleQueue(QueueBase):
    611   """A queue implementation that dequeues elements in a random order.
    612 
    613   See @{tf.QueueBase} for a description of the methods on
    614   this class.
    615 
    616   @compatibility(eager)
    617   Queues are not compatible with eager execution. Instead, please
    618   use `tf.data` to get data into your model.
    619   @end_compatibility
    620   """
    621 
    622   def __init__(self,
    623                capacity,
    624                min_after_dequeue,
    625                dtypes,
    626                shapes=None,
    627                names=None,
    628                seed=None,
    629                shared_name=None,
    630                name="random_shuffle_queue"):
    631     """Create a queue that dequeues elements in a random order.
    632 
    633     A `RandomShuffleQueue` has bounded capacity; supports multiple
    634     concurrent producers and consumers; and provides exactly-once
    635     delivery.
    636 
    637     A `RandomShuffleQueue` holds a list of up to `capacity`
    638     elements. Each element is a fixed-length tuple of tensors whose
    639     dtypes are described by `dtypes`, and whose shapes are optionally
    640     described by the `shapes` argument.
    641 
    642     If the `shapes` argument is specified, each component of a queue
    643     element must have the respective fixed shape. If it is
    644     unspecified, different queue elements may have different shapes,
    645     but the use of `dequeue_many` is disallowed.
    646 
    647     The `min_after_dequeue` argument allows the caller to specify a
    648     minimum number of elements that will remain in the queue after a
    649     `dequeue` or `dequeue_many` operation completes, to ensure a
    650     minimum level of mixing of elements. This invariant is maintained
    651     by blocking those operations until sufficient elements have been
    652     enqueued. The `min_after_dequeue` argument is ignored after the
    653     queue has been closed.
    654 
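             For example, a sketch (illustrative only; graph mode):

               q = tf.RandomShuffleQueue(
                   capacity=10, min_after_dequeue=2, dtypes=[tf.int32])
               enqueue_op = q.enqueue_many([[1, 2, 3, 4, 5]])
               # With five elements enqueued, at most three dequeues can
               # complete before this op blocks (until more elements arrive or
               # the queue is closed), keeping `min_after_dequeue` elements
               # behind for shuffling.
               sample = q.dequeue()
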
    655     Args:
    656       capacity: An integer. The upper bound on the number of elements
    657         that may be stored in this queue.
    658       min_after_dequeue: An integer (described above).
    659       dtypes:  A list of `DType` objects. The length of `dtypes` must equal
    660         the number of tensors in each queue element.
    661       shapes: (Optional.) A list of fully-defined `TensorShape` objects
    662         with the same length as `dtypes`, or `None`.
     663       names: (Optional.) A list of strings naming the components in the queue
    664         with the same length as `dtypes`, or `None`.  If specified the dequeue
    665         methods return a dictionary with the names as keys.
    666       seed: A Python integer. Used to create a random seed. See
    667         @{tf.set_random_seed}
    668         for behavior.
    669       shared_name: (Optional.) If non-empty, this queue will be shared under
    670         the given name across multiple sessions.
    671       name: Optional name for the queue operation.
    672     """
    673     dtypes = _as_type_list(dtypes)
    674     shapes = _as_shape_list(shapes, dtypes)
    675     names = _as_name_list(names, dtypes)
    676     seed1, seed2 = random_seed.get_seed(seed)
    677     if seed1 is None and seed2 is None:
    678       seed1, seed2 = 0, 0
    679     elif seed is None and shared_name is not None:
    680       # This means that graph seed is provided but op seed is not provided.
    681       # If shared_name is also provided, make seed2 depend only on the graph
    682       # seed and shared_name. (seed2 from get_seed() is generally dependent on
    683       # the id of the last op created.)
    684       string = (str(seed1) + shared_name).encode("utf-8")
    685       seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
    686     queue_ref = gen_data_flow_ops._random_shuffle_queue_v2(
    687         component_types=dtypes,
    688         shapes=shapes,
    689         capacity=capacity,
    690         min_after_dequeue=min_after_dequeue,
    691         seed=seed1,
    692         seed2=seed2,
    693         shared_name=shared_name,
    694         name=name)
    695 
    696     super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
    697 
    698 
    699 @tf_export("FIFOQueue")
    700 class FIFOQueue(QueueBase):
    701   """A queue implementation that dequeues elements in first-in first-out order.
    702 
    703   See @{tf.QueueBase} for a description of the methods on
    704   this class.
    705 
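           For example, a minimal graph-mode sketch (illustrative only; the ops
           are driven from a `tf.Session`):

             q = tf.FIFOQueue(capacity=3, dtypes=[tf.float32])
             enqueue_op = q.enqueue([10.0])
             dequeued = q.dequeue()

             with tf.Session() as sess:
               sess.run(enqueue_op)
               print(sess.run(dequeued))  # ==> 10.0
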
    706   @compatibility(eager)
    707   Queues are not compatible with eager execution. Instead, please
    708   use `tf.data` to get data into your model.
    709   @end_compatibility
    710   """
    711 
    712   def __init__(self,
    713                capacity,
    714                dtypes,
    715                shapes=None,
    716                names=None,
    717                shared_name=None,
    718                name="fifo_queue"):
    719     """Creates a queue that dequeues elements in a first-in first-out order.
    720 
    721     A `FIFOQueue` has bounded capacity; supports multiple concurrent
    722     producers and consumers; and provides exactly-once delivery.
    723 
    724     A `FIFOQueue` holds a list of up to `capacity` elements. Each
    725     element is a fixed-length tuple of tensors whose dtypes are
    726     described by `dtypes`, and whose shapes are optionally described
    727     by the `shapes` argument.
    728 
    729     If the `shapes` argument is specified, each component of a queue
    730     element must have the respective fixed shape. If it is
    731     unspecified, different queue elements may have different shapes,
    732     but the use of `dequeue_many` is disallowed.
    733 
    734     Args:
    735       capacity: An integer. The upper bound on the number of elements
    736         that may be stored in this queue.
    737       dtypes:  A list of `DType` objects. The length of `dtypes` must equal
    738         the number of tensors in each queue element.
    739       shapes: (Optional.) A list of fully-defined `TensorShape` objects
    740         with the same length as `dtypes`, or `None`.
     741       names: (Optional.) A list of strings naming the components in the queue
    742         with the same length as `dtypes`, or `None`.  If specified the dequeue
    743         methods return a dictionary with the names as keys.
    744       shared_name: (Optional.) If non-empty, this queue will be shared under
    745         the given name across multiple sessions.
    746       name: Optional name for the queue operation.
    747     """
    748     dtypes = _as_type_list(dtypes)
    749     shapes = _as_shape_list(shapes, dtypes)
    750     names = _as_name_list(names, dtypes)
    751     queue_ref = gen_data_flow_ops._fifo_queue_v2(
    752         component_types=dtypes,
    753         shapes=shapes,
    754         capacity=capacity,
    755         shared_name=shared_name,
    756         name=name)
    757 
    758     super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
    759 
    760 
    761 @tf_export("PaddingFIFOQueue")
    762 class PaddingFIFOQueue(QueueBase):
    763   """A FIFOQueue that supports batching variable-sized tensors by padding.
    764 
    765   A `PaddingFIFOQueue` may contain components with dynamic shape, while also
    766   supporting `dequeue_many`.  See the constructor for more details.
    767 
    768   See @{tf.QueueBase} for a description of the methods on
    769   this class.
    770 
    771   @compatibility(eager)
    772   Queues are not compatible with eager execution. Instead, please
    773   use `tf.data` to get data into your model.
    774   @end_compatibility
    775   """
    776 
    777   def __init__(self,
    778                capacity,
    779                dtypes,
    780                shapes,
    781                names=None,
    782                shared_name=None,
    783                name="padding_fifo_queue"):
    784     """Creates a queue that dequeues elements in a first-in first-out order.
    785 
    786     A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent
    787     producers and consumers; and provides exactly-once delivery.
    788 
    789     A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each
    790     element is a fixed-length tuple of tensors whose dtypes are
    791     described by `dtypes`, and whose shapes are described by the `shapes`
    792     argument.
    793 
    794     The `shapes` argument must be specified; each component of a queue
    795     element must have the respective shape.  Shapes of fixed
    796     rank but variable size are allowed by setting any shape dimension to None.
    797     In this case, the inputs' shape may vary along the given dimension, and
    798     `dequeue_many` will pad the given dimension with zeros up to the maximum
    799     shape of all elements in the given batch.
    800 
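             For example, a sketch (illustrative only) of batching
             variable-length vectors, where `dequeue_many` zero-pads each
             vector to the longest length in the batch:

               q = tf.PaddingFIFOQueue(10, dtypes=[tf.int32], shapes=[[None]])
               enqueue_a = q.enqueue([[1, 2]])
               enqueue_b = q.enqueue([[3, 4, 5]])
               # After running `enqueue_a` and then `enqueue_b`, evaluating
               # `batch` yields [[1, 2, 0], [3, 4, 5]].
               batch = q.dequeue_many(2)
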
    801     Args:
    802       capacity: An integer. The upper bound on the number of elements
    803         that may be stored in this queue.
    804       dtypes:  A list of `DType` objects. The length of `dtypes` must equal
    805         the number of tensors in each queue element.
    806       shapes: A list of `TensorShape` objects, with the same length as
    807         `dtypes`.  Any dimension in the `TensorShape` containing value
    808         `None` is dynamic and allows values to be enqueued with
    809          variable size in that dimension.
     810       names: (Optional.) A list of strings naming the components in the queue
    811         with the same length as `dtypes`, or `None`.  If specified the dequeue
    812         methods return a dictionary with the names as keys.
    813       shared_name: (Optional.) If non-empty, this queue will be shared under
    814         the given name across multiple sessions.
    815       name: Optional name for the queue operation.
    816 
    817     Raises:
    818       ValueError: If shapes is not a list of shapes, or the lengths of dtypes
    819         and shapes do not match, or if names is specified and the lengths of
    820         dtypes and names do not match.
    821     """
    822     dtypes = _as_type_list(dtypes)
    823     shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True)
    824     names = _as_name_list(names, dtypes)
    825     if len(dtypes) != len(shapes):
    826       raise ValueError("Shapes must be provided for all components, "
    827                        "but received %d dtypes and %d shapes." % (len(dtypes),
    828                                                                   len(shapes)))
    829 
    830     queue_ref = gen_data_flow_ops._padding_fifo_queue_v2(
    831         component_types=dtypes,
    832         shapes=shapes,
    833         capacity=capacity,
    834         shared_name=shared_name,
    835         name=name)
    836 
    837     super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
    838 
    839 
    840 @tf_export("PriorityQueue")
    841 class PriorityQueue(QueueBase):
    842   """A queue implementation that dequeues elements in prioritized order.
    843 
    844   See @{tf.QueueBase} for a description of the methods on
    845   this class.
    846 
    847   @compatibility(eager)
    848   Queues are not compatible with eager execution. Instead, please
    849   use `tf.data` to get data into your model.
    850   @end_compatibility
    851   """
    852 
    853   def __init__(self,
    854                capacity,
    855                types,
    856                shapes=None,
    857                names=None,
    858                shared_name=None,
    859                name="priority_queue"):
     860     """Creates a queue that dequeues elements in prioritized order.
    861 
    862     A `PriorityQueue` has bounded capacity; supports multiple concurrent
    863     producers and consumers; and provides exactly-once delivery.
    864 
    865     A `PriorityQueue` holds a list of up to `capacity` elements. Each
    866     element is a fixed-length tuple of tensors whose dtypes are
    867     described by `types`, and whose shapes are optionally described
    868     by the `shapes` argument.
    869 
    870     If the `shapes` argument is specified, each component of a queue
    871     element must have the respective fixed shape. If it is
    872     unspecified, different queue elements may have different shapes,
    873     but the use of `dequeue_many` is disallowed.
    874 
    875     Enqueues and Dequeues to the `PriorityQueue` must include an additional
     876     Enqueues to and dequeues from the `PriorityQueue` must include an additional
    877     an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`).
    878 
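             For example, a sketch (illustrative only); note the extra int64
             priority component in both the enqueued tuple and the dequeued
             result:

               q = tf.PriorityQueue(10, types=[tf.string], shapes=[[]])
               enqueue_op = q.enqueue((tf.constant(2, dtype=tf.int64), "b"))
               priority, value = q.dequeue()
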
    879     Args:
    880       capacity: An integer. The upper bound on the number of elements
    881         that may be stored in this queue.
    882       types:  A list of `DType` objects. The length of `types` must equal
     883         the number of tensors in each queue element, not counting the first
     884         priority component.  The first tensor in each element is the priority,
     885         which must be of type int64.
    886       shapes: (Optional.) A list of fully-defined `TensorShape` objects,
    887         with the same length as `types`, or `None`.
    888       names: (Optional.) A list of strings naming the components in the queue
    889         with the same length as `dtypes`, or `None`.  If specified, the dequeue
     890         with the same length as `types`, or `None`.  If specified, the dequeue
    891       shared_name: (Optional.) If non-empty, this queue will be shared under
    892         the given name across multiple sessions.
    893       name: Optional name for the queue operation.
    894     """
    895     types = _as_type_list(types)
    896     shapes = _as_shape_list(shapes, types)
    897 
    898     queue_ref = gen_data_flow_ops._priority_queue_v2(
    899         component_types=types,
    900         shapes=shapes,
    901         capacity=capacity,
    902         shared_name=shared_name,
    903         name=name)
    904 
    905     priority_dtypes = [_dtypes.int64] + types
    906     priority_shapes = [()] + shapes if shapes else shapes
    907 
    908     super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names,
    909                                         queue_ref)
    910 
    911 
    912 # TODO(josh11b): class BatchQueue(QueueBase):
    913 
    914 
    915 class Barrier(object):
    916   """Represents a key-value map that persists across graph executions."""
    917 
    918   def __init__(self, types, shapes=None, shared_name=None, name="barrier"):
    919     """Creates a barrier that persists across different graph executions.
    920 
    921     A barrier represents a key-value map, where each key is a string, and
    922     each value is a tuple of tensors.
    923 
    924     At runtime, the barrier contains 'complete' and 'incomplete'
    925     elements. A complete element has defined tensors for all
    926     components of its value tuple, and may be accessed using
    927     take_many. An incomplete element has some undefined components in
    928     its value tuple, and may be updated using insert_many.
    929 
    930     The barrier call `take_many` outputs values in a particular order.
    931     First, it only outputs completed values.  Second, the order in which
    932     completed values are returned matches the order in which their very
    933     first component was inserted into the barrier.  So, for example, for this
    934     sequence of insertions and removals:
    935 
    936       barrier = Barrier((tf.string, tf.int32), shapes=((), ()))
    937       barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run()
    938       barrier.insert_many(1, keys=["k1"], values=[1]).run()
    939       barrier.insert_many(0, keys=["k3"], values=["c"]).run()
    940       barrier.insert_many(1, keys=["k3"], values=[3]).run()
    941       barrier.insert_many(1, keys=["k2"], values=[2]).run()
    942 
    943       (indices, keys, values) = barrier.take_many(2)
     944       (indices_val, keys_val, values0_val, values1_val) = session.run(
     945           [indices, keys, values[0], values[1]])
    946 
    947     The output will be (up to permutation of "k1" and "k2"):
    948 
    949       indices_val == (-2**63, -2**63)
    950       keys_val == ("k1", "k2")
    951       values0_val == ("a", "b")
    952       values1_val == (1, 2)
    953 
    954     Note the key "k2" was inserted into the barrier before "k3".  Even though
    955     "k3" was completed first, both are complete by the time
    956     take_many is called.  As a result, "k2" is prioritized and "k1" and "k2"
    957     are returned first.  "k3" remains in the barrier until the next execution
    958     of `take_many`.  Since "k1" and "k2" had their first insertions into
    959     the barrier together, their indices are the same (-2**63).  The index
    960     of "k3" will be -2**63 + 1, because it was the next new inserted key.
    961 
    962     Args:
    963       types: A single dtype or a tuple of dtypes, corresponding to the
    964         dtypes of the tensor elements that comprise a value in this barrier.
    965       shapes: Optional. Constraints on the shapes of tensors in the values:
    966         a single tensor shape tuple; a tuple of tensor shape tuples
    967         for each barrier-element tuple component; or None if the shape should
    968         not be constrained.
    969       shared_name: Optional. If non-empty, this barrier will be shared under
    970         the given name across multiple sessions.
    971       name: Optional name for the barrier op.
    972 
    973     Raises:
     974       ValueError: If one of the `shapes` indicates no elements.
    975     """
    976     self._types = _as_type_list(types)
    977 
    978     if shapes is not None:
    979       shapes = _as_shape_list(shapes, self._types)
    980       self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
    981       for i, shape in enumerate(self._shapes):
    982         if shape.num_elements() == 0:
    983           raise ValueError("Empty tensors are not supported, but received "
    984                            "shape '%s' at index %d" % (shape, i))
    985     else:
    986       self._shapes = [tensor_shape.unknown_shape() for _ in self._types]
    987 
    988     self._barrier_ref = gen_data_flow_ops._barrier(
    989         component_types=self._types,
    990         shapes=self._shapes,
    991         shared_name=shared_name,
    992         name=name)
    993     if context.in_graph_mode():
    994       self._name = self._barrier_ref.op.name.split("/")[-1]
    995     else:
    996       self._name = context.context().scope_name
    997 
    998   @property
    999   def barrier_ref(self):
   1000     """Get the underlying barrier reference."""
   1001     return self._barrier_ref
   1002 
   1003   @property
   1004   def name(self):
   1005     """The name of the underlying barrier."""
   1006     if context.in_graph_mode():
   1007       return self._barrier_ref.op.name
   1008     return self._name
   1009 
   1010   def insert_many(self, component_index, keys, values, name=None):
   1011     """For each key, assigns the respective value to the specified component.
   1012 
   1013     This operation updates each element at component_index.
   1014 
   1015     Args:
   1016       component_index: The component of the value that is being assigned.
   1017       keys: A vector of keys, with length n.
   1018       values: An any-dimensional tensor of values, which are associated with the
   1019         respective keys. The first dimension must have length n.
   1020       name: Optional name for the op.
   1021 
   1022     Returns:
   1023       The operation that performs the insertion.
   1024     Raises:
    1025       InvalidArgumentError: If inserting keys and values without elements.
   1026     """
   1027     if name is None:
   1028       name = "%s_BarrierInsertMany" % self._name
   1029     return gen_data_flow_ops._barrier_insert_many(
   1030         self._barrier_ref, keys, values, component_index, name=name)
   1031 
   1032   def take_many(self,
   1033                 num_elements,
   1034                 allow_small_batch=False,
   1035                 timeout=None,
   1036                 name=None):
   1037     """Takes the given number of completed elements from this barrier.
   1038 
   1039     This operation concatenates completed-element component tensors along
   1040     the 0th dimension to make a single component tensor.
   1041 
    1042     If the barrier has fewer than `num_elements` completed elements, this
    1043     operation will block until `num_elements` completed elements are available.
   1044 
   1045     TODO(b/25743580): the semantics of `allow_small_batch` are experimental
   1046     and may be extended to other cases in the future.
   1047 
   1048     TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
    1049     already when the barrier is closed, it will block forever. Fix this
   1050     by using asynchronous operations.
   1051 
   1052     Args:
   1053       num_elements: The number of elements to take.
    1054       allow_small_batch: If the barrier is closed, don't block if there are fewer
   1055         completed elements than requested, but instead return all available
   1056         completed elements.
   1057       timeout: This specifies the number of milliseconds to block
   1058         before returning with DEADLINE_EXCEEDED. (This option is not
   1059         supported yet.)
   1060       name: A name for the operation (optional).
   1061 
   1062     Returns:
   1063       A tuple of (index, key, value_list).
    1064       "index" is an int64 tensor of length num_elements containing the
   1065         index of the insert_many call for which the very first component of
   1066         the given element was inserted into the Barrier, starting with
   1067         the value -2**63.  Note, this value is different from the
   1068         index of the insert_many call for which the element was completed.
   1069       "key" is a string tensor of length num_elements containing the keys.
   1070       "value_list" is a tuple of tensors, each one with size num_elements
   1071         in the 0th dimension for each component in the barrier's values.
   1072 
   1073     """
   1074     if name is None:
   1075       name = "%s_BarrierTakeMany" % self._name
   1076     ret = gen_data_flow_ops._barrier_take_many(
   1077         self._barrier_ref,
   1078         num_elements,
   1079         self._types,
   1080         allow_small_batch,
   1081         timeout,
   1082         name=name)
   1083 
   1084     # NOTE(mrry): Not using a shape function because we need access to
   1085     # the Barrier object.
   1086     if context.in_graph_mode():
   1087       op = ret[0].op
   1088       if allow_small_batch:
   1089         batch_dim = None
   1090       else:
   1091         batch_dim = tensor_shape.Dimension(
   1092             tensor_util.constant_value(op.inputs[1]))
   1093       op.outputs[0].set_shape(tensor_shape.vector(batch_dim))  # indices
   1094       op.outputs[1].set_shape(tensor_shape.vector(batch_dim))  # keys
   1095       for output, shape in zip(op.outputs[2:], self._shapes):  # value_list
   1096         output.set_shape(
   1097             tensor_shape.TensorShape([batch_dim]).concatenate(shape))
   1098 
   1099     return ret
   1100 
   1101   def close(self, cancel_pending_enqueues=False, name=None):
   1102     """Closes this barrier.
   1103 
    1104     This operation signals that no new keys will be inserted into the
    1105     given barrier. Subsequent InsertMany operations with new keys will fail.
    1106     InsertMany operations that merely add components to already existing
    1107     keys will continue to succeed. Subsequent TakeMany operations will
   1108     continue to succeed if sufficient elements remain in the barrier. Subsequent
   1109     TakeMany operations that would block will fail immediately.
   1110 
   1111     If `cancel_pending_enqueues` is `True`, all pending requests to the
    1112     underlying queue will also be canceled, and it will no longer be
    1113     possible to complete values whose insertion has already started.
   1114 
   1115     Args:
   1116       cancel_pending_enqueues: (Optional.) A boolean, defaulting to
   1117         `False` (described above).
   1118       name: Optional name for the op.
   1119 
   1120     Returns:
   1121       The operation that closes the barrier.
   1122     """
   1123     if name is None:
   1124       name = "%s_BarrierClose" % self._name
   1125     return gen_data_flow_ops._barrier_close(
   1126         self._barrier_ref,
   1127         cancel_pending_enqueues=cancel_pending_enqueues,
   1128         name=name)
   1129 
   1130   def ready_size(self, name=None):
   1131     """Compute the number of complete elements in the given barrier.
   1132 
   1133     Args:
   1134       name: A name for the operation (optional).
   1135 
   1136     Returns:
   1137       A single-element tensor containing the number of complete elements in the
   1138       given barrier.
   1139     """
   1140     if name is None:
   1141       name = "%s_BarrierReadySize" % self._name
   1142     return gen_data_flow_ops._barrier_ready_size(self._barrier_ref, name=name)
   1143 
   1144   def incomplete_size(self, name=None):
   1145     """Compute the number of incomplete elements in the given barrier.
   1146 
   1147     Args:
   1148       name: A name for the operation (optional).
   1149 
   1150     Returns:
   1151       A single-element tensor containing the number of incomplete elements in
   1152       the given barrier.
   1153     """
   1154     if name is None:
   1155       name = "%s_BarrierIncompleteSize" % self._name
   1156     return gen_data_flow_ops._barrier_incomplete_size(
   1157         self._barrier_ref, name=name)
   1158 
   1159 
   1160 @tf_export("ConditionalAccumulatorBase")
   1161 class ConditionalAccumulatorBase(object):
   1162   """A conditional accumulator for aggregating gradients.
   1163 
   1164   Up-to-date gradients (i.e., time step at which gradient was computed is
   1165   equal to the accumulator's time step) are added to the accumulator.
   1166 
   1167   Extraction of the average gradient is blocked until the required number of
   1168   gradients has been accumulated.
   1169   """
   1170 
   1171   def __init__(self, dtype, shape, accumulator_ref):
   1172     """Creates a new ConditionalAccumulator.
   1173 
   1174     Args:
   1175       dtype: Datatype of the accumulated gradients.
   1176       shape: Shape of the accumulated gradients.
    1177       accumulator_ref: A handle to the conditional accumulator, created by
    1178         subclasses.
   1179     """
   1180     self._dtype = dtype
   1181     if shape is not None:
   1182       self._shape = tensor_shape.TensorShape(shape)
   1183     else:
   1184       self._shape = tensor_shape.unknown_shape()
   1185     self._accumulator_ref = accumulator_ref
   1186     if context.in_graph_mode():
   1187       self._name = self._accumulator_ref.op.name.split("/")[-1]
   1188     else:
   1189       self._name = context.context().scope_name
   1190 
   1191   @property
   1192   def accumulator_ref(self):
   1193     """The underlying accumulator reference."""
   1194     return self._accumulator_ref
   1195 
   1196   @property
   1197   def name(self):
   1198     """The name of the underlying accumulator."""
   1199     return self._name
   1200 
   1201   @property
   1202   def dtype(self):
   1203     """The datatype of the gradients accumulated by this accumulator."""
   1204     return self._dtype
   1205 
   1206   def num_accumulated(self, name=None):
   1207     """Number of gradients that have currently been aggregated in accumulator.
   1208 
   1209     Args:
   1210       name: Optional name for the operation.
   1211 
   1212     Returns:
    1213       Number of accumulated gradients currently in the accumulator.
   1214     """
   1215     if name is None:
   1216       name = "%s_NumAccumulated" % self._name
   1217     return gen_data_flow_ops.accumulator_num_accumulated(
   1218         self._accumulator_ref, name=name)
   1219 
   1220   def set_global_step(self, new_global_step, name=None):
   1221     """Sets the global time step of the accumulator.
   1222 
   1223     The operation logs a warning if we attempt to set to a time step that is
   1224     lower than the accumulator's own time step.
   1225 
   1226     Args:
   1227       new_global_step: Value of new time step. Can be a variable or a constant
   1228       name: Optional name for the operation.
   1229 
   1230     Returns:
   1231       Operation that sets the accumulator's time step.
   1232     """
   1233     return gen_data_flow_ops.accumulator_set_global_step(
   1234         self._accumulator_ref,
   1235         math_ops.to_int64(ops.convert_to_tensor(new_global_step)),
   1236         name=name)
   1237 
   1238 
   1239 @tf_export("ConditionalAccumulator")
   1240 class ConditionalAccumulator(ConditionalAccumulatorBase):
   1241   """A conditional accumulator for aggregating gradients.
   1242 
   1243   Up-to-date gradients (i.e., time step at which gradient was computed is
   1244   equal to the accumulator's time step) are added to the accumulator.
   1245 
   1246   Extraction of the average gradient is blocked until the required number of
   1247   gradients has been accumulated.
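
  For example, a minimal graph-mode sketch (the values and counts below are
  illustrative only):

  ```python
  acc = tf.ConditionalAccumulator(dtype=tf.float32, shape=[2])
  apply_op = acc.apply_grad([0.5, 1.5], local_step=0)
  avg_grad = acc.take_grad(num_required=1)

  with tf.Session() as sess:
    sess.run(apply_op)         # adds one up-to-date gradient
    print(sess.run(avg_grad))  # returns the average once num_required
                               # gradients have been applied
  ```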
   1248   """
   1249 
   1250   def __init__(self,
   1251                dtype,
   1252                shape=None,
   1253                shared_name=None,
   1254                name="conditional_accumulator"):
   1255     """Creates a new ConditionalAccumulator.
   1256 
   1257     Args:
   1258       dtype: Datatype of the accumulated gradients.
   1259       shape: Shape of the accumulated gradients.
   1260       shared_name: Optional. If non-empty, this accumulator will be shared under
   1261         the given name across multiple sessions.
   1262       name: Optional name for the accumulator.
   1263     """
   1264     accumulator_ref = gen_data_flow_ops.conditional_accumulator(
   1265         dtype=dtype, shape=shape, shared_name=shared_name, name=name)
   1266     super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
   1267 
   1268   def apply_grad(self, grad, local_step=0, name=None):
   1269     """Attempts to apply a gradient to the accumulator.
   1270 
   1271     The attempt is silently dropped if the gradient is stale, i.e., local_step
   1272     is less than the accumulator's global time step.
   1273 
   1274     Args:
   1275       grad: The gradient tensor to be applied.
   1276       local_step: Time step at which the gradient was computed.
   1277       name: Optional name for the operation.
   1278 
   1279     Returns:
   1280       The operation that (conditionally) applies a gradient to the accumulator.
   1281 
   1282     Raises:
   1283       ValueError: If grad is of the wrong shape
   1284     """
   1285     grad = ops.convert_to_tensor(grad, self._dtype)
   1286     grad.get_shape().assert_is_compatible_with(self._shape)
   1287     local_step = math_ops.to_int64(ops.convert_to_tensor(local_step))
   1288     return gen_data_flow_ops.accumulator_apply_gradient(
   1289         self._accumulator_ref, local_step=local_step, gradient=grad, name=name)
   1290 
   1291   def take_grad(self, num_required, name=None):
   1292     """Attempts to extract the average gradient from the accumulator.
   1293 
    The operation blocks until a sufficient number of gradients have been
    successfully applied to the accumulator.
   1296 
   1297     Once successful, the following actions are also triggered:
   1298 
   1299     - Counter of accumulated gradients is reset to 0.
   1300     - Aggregated gradient is reset to 0 tensor.
   1301     - Accumulator's internal time step is incremented by 1.
   1302 
   1303     Args:
      num_required: Number of gradients that need to have been aggregated
   1305       name: Optional name for the operation
   1306 
   1307     Returns:
   1308       A tensor holding the value of the average gradient.
   1309 
   1310     Raises:
   1311       InvalidArgumentError: If num_required < 1
   1312     """
   1313     out = gen_data_flow_ops.accumulator_take_gradient(
   1314         self._accumulator_ref, num_required, dtype=self._dtype, name=name)
   1315     out.set_shape(self._shape)
   1316     return out
   1317 
   1318 
   1319 @tf_export("SparseConditionalAccumulator")
   1320 class SparseConditionalAccumulator(ConditionalAccumulatorBase):
   1321   """A conditional accumulator for aggregating sparse gradients.
   1322 
   1323   Sparse gradients are represented by IndexedSlices.
   1324 
   1325   Up-to-date gradients (i.e., time step at which gradient was computed is
   1326   equal to the accumulator's time step) are added to the accumulator.
   1327 
   1328   Extraction of the average gradient is blocked until the required number of
   1329   gradients has been accumulated.
   1330 
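  For example, a minimal sketch (the sparse values below are illustrative
  only):

  ```python
  acc = tf.SparseConditionalAccumulator(dtype=tf.float32, shape=[3, 2])
  grad = tf.IndexedSlices(
      indices=tf.constant([1, 2], dtype=tf.int64),
      values=tf.constant([[0.0, 1.0], [2.0, 3.0]]),
      dense_shape=tf.constant([3, 2], dtype=tf.int64))
  apply_op = acc.apply_indexed_slices_grad(grad)
  avg_grad = acc.take_indexed_slices_grad(num_required=1)
  ```
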
   1331   Args:
   1332     dtype: Datatype of the accumulated gradients.
   1333     shape: Shape of the accumulated gradients.
   1334     shared_name: Optional. If non-empty, this accumulator will be shared under
   1335       the given name across multiple sessions.
   1336     name: Optional name for the accumulator.
   1337   """
   1338 
   1339   def __init__(self,
   1340                dtype,
   1341                shape=None,
   1342                shared_name=None,
   1343                name="sparse_conditional_accumulator"):
   1344     accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
   1345         dtype=dtype, shape=shape, shared_name=shared_name, name=name)
   1346     super(SparseConditionalAccumulator, self).__init__(dtype, shape,
   1347                                                        accumulator_ref)
   1348 
   1349   def apply_indexed_slices_grad(self, grad, local_step=0, name=None):
   1350     """Attempts to apply a gradient to the accumulator.
   1351 
   1352     The attempt is silently dropped if the gradient is stale, i.e., local_step
   1353     is less than the accumulator's global time step.
   1354 
   1355     Args:
   1356       grad: The gradient IndexedSlices to be applied.
   1357       local_step: Time step at which the gradient was computed.
   1358       name: Optional name for the operation.
   1359 
   1360     Returns:
   1361       The operation that (conditionally) applies a gradient to the accumulator.
   1362 
   1363     Raises:
   1364       InvalidArgumentError: If grad is of the wrong shape
   1365     """
   1366     return self.apply_grad(
   1367         grad_indices=grad.indices,
   1368         grad_values=grad.values,
   1369         grad_shape=grad.dense_shape,
   1370         local_step=local_step,
   1371         name=name)
   1372 
   1373   def apply_grad(self,
   1374                  grad_indices,
   1375                  grad_values,
   1376                  grad_shape=None,
   1377                  local_step=0,
   1378                  name=None):
   1379     """Attempts to apply a sparse gradient to the accumulator.
   1380 
   1381     The attempt is silently dropped if the gradient is stale, i.e., local_step
   1382     is less than the accumulator's global time step.
   1383 
   1384     A sparse gradient is represented by its indices, values and possibly empty
   1385     or None shape. Indices must be a vector representing the locations of
   1386     non-zero entries in the tensor. Values are the non-zero slices of the
   1387     gradient, and must have the same first dimension as indices, i.e., the nnz
   1388     represented by indices and values must be consistent. Shape, if not empty or
   1389     None, must be consistent with the accumulator's shape (if also provided).
   1390 
   1391     Example:
      A tensor [[0, 0], [0, 1], [2, 3]] can be represented as
        indices: [1, 2]
        values: [[0, 1], [2, 3]]
        shape: [3, 2]
   1396 
   1397     Args:
   1398       grad_indices: Indices of the sparse gradient to be applied.
   1399       grad_values: Values of the sparse gradient to be applied.
   1400       grad_shape: Shape of the sparse gradient to be applied.
   1401       local_step: Time step at which the gradient was computed.
   1402       name: Optional name for the operation.
   1403 
   1404     Returns:
   1405       The operation that (conditionally) applies a gradient to the accumulator.
   1406 
   1407     Raises:
   1408       InvalidArgumentError: If grad is of the wrong shape
   1409     """
   1410     local_step = math_ops.to_int64(ops.convert_to_tensor(local_step))
   1411     return gen_data_flow_ops.sparse_accumulator_apply_gradient(
   1412         self._accumulator_ref,
   1413         local_step=local_step,
   1414         gradient_indices=math_ops.to_int64(grad_indices),
   1415         gradient_values=grad_values,
   1416         gradient_shape=math_ops.to_int64([]
   1417                                          if grad_shape is None else grad_shape),
   1418         has_known_shape=(grad_shape is not None),
   1419         name=name)
   1420 
   1421   def take_grad(self, num_required, name=None):
   1422     """Attempts to extract the average gradient from the accumulator.
   1423 
    The operation blocks until a sufficient number of gradients have been
    successfully applied to the accumulator.
   1426 
   1427     Once successful, the following actions are also triggered:
   1428     - Counter of accumulated gradients is reset to 0.
   1429     - Aggregated gradient is reset to 0 tensor.
   1430     - Accumulator's internal time step is incremented by 1.
   1431 
   1432     Args:
      num_required: Number of gradients that need to have been aggregated
   1434       name: Optional name for the operation
   1435 
   1436     Returns:
   1437       A tuple of indices, values, and shape representing the average gradient.
   1438 
   1439     Raises:
   1440       InvalidArgumentError: If num_required < 1
   1441     """
   1442     return gen_data_flow_ops.sparse_accumulator_take_gradient(
   1443         self._accumulator_ref, num_required, dtype=self._dtype, name=name)
   1444 
   1445   def take_indexed_slices_grad(self, num_required, name=None):
   1446     """Attempts to extract the average gradient from the accumulator.
   1447 
    The operation blocks until a sufficient number of gradients have been
    successfully applied to the accumulator.
   1450 
   1451     Once successful, the following actions are also triggered:
   1452     - Counter of accumulated gradients is reset to 0.
   1453     - Aggregated gradient is reset to 0 tensor.
   1454     - Accumulator's internal time step is incremented by 1.
   1455 
   1456     Args:
      num_required: Number of gradients that need to have been aggregated
   1458       name: Optional name for the operation
   1459 
   1460     Returns:
   1461       An IndexedSlices holding the value of the average gradient.
   1462 
   1463     Raises:
   1464       InvalidArgumentError: If num_required < 1
   1465     """
   1466     return_val = gen_data_flow_ops.sparse_accumulator_take_gradient(
   1467         self._accumulator_ref, num_required, dtype=self._dtype, name=name)
   1468     return ops.IndexedSlices(
   1469         indices=return_val.indices,
   1470         values=return_val.values,
   1471         dense_shape=return_val.shape)
   1472 
   1473 
   1474 class BaseStagingArea(object):
   1475   """Base class for Staging Areas."""
   1476   _identifier = 0
   1477   _lock = threading.Lock()
   1478 
   1479   def __init__(self,
   1480                dtypes,
   1481                shapes=None,
   1482                names=None,
   1483                shared_name=None,
   1484                capacity=0,
   1485                memory_limit=0):
   1486     if shared_name is None:
   1487       self._name = (
   1488           ops.get_default_graph().unique_name(self.__class__.__name__))
   1489     elif isinstance(shared_name, six.string_types):
   1490       self._name = shared_name
   1491     else:
   1492       raise ValueError("shared_name must be a string")
   1493 
   1494     self._dtypes = dtypes
   1495 
   1496     if shapes is not None:
   1497       if len(shapes) != len(dtypes):
   1498         raise ValueError("StagingArea shapes must be the same length as dtypes")
   1499       self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
   1500     else:
   1501       self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
   1502 
   1503     if names is not None:
   1504       if len(names) != len(dtypes):
   1505         raise ValueError("StagingArea names must be the same length as dtypes")
   1506       self._names = names
   1507     else:
   1508       self._names = None
   1509 
   1510     self._capacity = capacity
   1511     self._memory_limit = memory_limit
   1512 
   1513     # all get and put ops must colocate with this op
   1514     with ops.name_scope("%s_root" % self._name):
   1515       self._coloc_op = control_flow_ops.no_op()
   1516 
   1517   @property
   1518   def name(self):
   1519     """The name of the staging area."""
   1520     return self._name
   1521 
   1522   @property
   1523   def dtypes(self):
   1524     """The list of dtypes for each component of a staging area element."""
   1525     return self._dtypes
   1526 
   1527   @property
   1528   def shapes(self):
   1529     """The list of shapes for each component of a staging area element."""
   1530     return self._shapes
   1531 
   1532   @property
   1533   def names(self):
   1534     """The list of names for each component of a staging area element."""
   1535     return self._names
   1536 
   1537   @property
   1538   def capacity(self):
   1539     """The maximum number of elements of this staging area."""
   1540     return self._capacity
   1541 
   1542   @property
   1543   def memory_limit(self):
   1544     """The maximum number of bytes of this staging area."""
   1545     return self._memory_limit
   1546 
   1547   def _check_put_dtypes(self, vals, indices=None):
   1548     """Validate and convert `vals` to a list of `Tensor`s.
   1549 
   1550     The `vals` argument can be a Tensor, a list or tuple of tensors, or a
   1551     dictionary with tensor values.
   1552 
   1553     If `vals` is a list, then the appropriate indices associated with the
   1554     values must be provided.
   1555 
   1556     If it is a dictionary, the staging area must have been constructed with a
   1557     `names` attribute and the dictionary keys must match the staging area names.
   1558     `indices` will be inferred from the dictionary keys.
   1559     If the staging area was constructed with a `names` attribute, `vals` must
   1560     be a dictionary.
   1561 
   1562     Checks that the dtype and shape of each value matches that
   1563     of the staging area.
   1564 
   1565     Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.
   1567 
   1568     Returns:
   1569       A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects
      and `indices` is a list of indices associated with the tensors.
   1571 
   1572     Raises:
   1573       ValueError: If `vals` or `indices` is invalid.
   1574     """
   1575     if isinstance(vals, dict):
   1576       if not self._names:
   1577         raise ValueError(
   1578             "Staging areas must have names to enqueue a dictionary")
   1579       if not set(vals.keys()).issubset(self._names):
   1580         raise ValueError("Keys in dictionary to put do not match names "
   1581                          "of staging area. Dictionary: (%s), Queue: (%s)" %
   1582                          (sorted(vals.keys()), sorted(self._names)))
   1583       # The order of values in `self._names` indicates the order in which the
   1584       # tensors in the dictionary `vals` must be listed.
   1585       vals, indices, n = zip(*[(vals[k], i, k)
   1586                                for i, k in enumerate(self._names)
   1587                                if k in vals])
   1588     else:
   1589       if self._names:
   1590         raise ValueError("You must enqueue a dictionary in a staging area "
   1591                          "with names")
   1592 
   1593       if indices is None:
   1594         raise ValueError("Indices must be supplied when inserting a list "
   1595                          "of tensors")
   1596 
   1597       if len(indices) != len(vals):
   1598         raise ValueError("Number of indices '%s' doesn't match "
                         "number of values '%s'" % (len(indices), len(vals)))
   1600 
   1601       if not isinstance(vals, (list, tuple)):
   1602         vals = [vals]
   1603         indices = [0]
   1604 
   1605     # Sanity check number of values
   1606     if not len(vals) <= len(self._dtypes):
   1607       raise ValueError("Unexpected number of inputs '%s' vs '%s'" %
   1608                        (len(vals), len(self._dtypes)))
   1609 
   1610     tensors = []
   1611 
   1612     for val, i in zip(vals, indices):
   1613       dtype, shape = self._dtypes[i], self._shapes[i]
   1614       # Check dtype
   1615       if not val.dtype == dtype:
   1616         raise ValueError("Datatypes do not match. '%s' != '%s'" %
   1617                          (str(val.dtype), str(dtype)))
   1618 
   1619       # Check shape
   1620       val.get_shape().assert_is_compatible_with(shape)
   1621 
   1622       tensors.append(
   1623           ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
   1624 
   1625     return tensors, indices
   1626 
   1627   def _create_device_transfers(self, tensors):
   1628     """Encode inter-device transfers if the current device
    is not the same as the Staging Area's device.
   1630     """
   1631 
   1632     if not isinstance(tensors, (tuple, list)):
   1633       tensors = [tensors]
   1634 
   1635     curr_device_scope = control_flow_ops.no_op().device
   1636 
   1637     if curr_device_scope != self._coloc_op.device:
   1638       tensors = [array_ops.identity(t) for t in tensors]
   1639 
   1640     return tensors
   1641 
   1642   def _get_return_value(self, tensors, indices):
   1643     """Return the value to return from a get op.
   1644 
   1645     If the staging area has names, return a dictionary with the
   1646     names as keys.  Otherwise return either a single tensor
   1647     or a list of tensors depending on the length of `tensors`.
   1648 
   1649     Args:
   1650       tensors: List of tensors from the get op.
   1651       indices: Indices of associated names and shapes
   1652 
   1653     Returns:
   1654       A single tensor, a list of tensors, or a dictionary
   1655       of tensors.
   1656     """
   1657 
   1658     tensors = self._create_device_transfers(tensors)
   1659 
   1660     # Sets shape
   1661     for output, i in zip(tensors, indices):
   1662       output.set_shape(self._shapes[i])
   1663 
   1664     if self._names:
   1665       # The returned values in `tensors` are in the same order as
   1666       # the names in `self._names`.
   1667       return {self._names[i]: t for t, i in zip(tensors, indices)}
   1668     return tensors
   1669 
   1670   def _scope_vals(self, vals):
   1671     """Return a list of values to pass to `name_scope()`.
   1672 
   1673     Args:
   1674       vals: A tensor, a list or tuple of tensors, or a dictionary.
   1675 
   1676     Returns:
   1677       The values in vals as a list.
   1678     """
   1679     if isinstance(vals, (list, tuple)):
   1680       return vals
   1681     elif isinstance(vals, dict):
   1682       return vals.values()
   1683     else:
   1684       return [vals]
   1685 
   1686 
   1687 class StagingArea(BaseStagingArea):
   1688   """Class for staging inputs. No ordering guarantees.
   1689 
   1690   A `StagingArea` is a TensorFlow data structure that stores tensors across
   1691   multiple steps, and exposes operations that can put and get tensors.
   1692 
   1693   Each `StagingArea` element is a tuple of one or more tensors, where each
   1694   tuple component has a static dtype, and may have a static shape.
   1695 
   1696   The capacity of a `StagingArea` may be bounded or unbounded.
   1697   It supports multiple concurrent producers and consumers; and
   1698   provides exactly-once delivery.
   1699 
   1700   Each element of a `StagingArea` is a fixed-length tuple of tensors whose
   1701   dtypes are described by `dtypes`, and whose shapes are optionally described
   1702   by the `shapes` argument.
   1703 
   1704   If the `shapes` argument is specified, each component of a staging area
   1705   element must have the respective fixed shape. If it is
   1706   unspecified, different elements may have different shapes,
  unspecified, different elements may have different shapes.
   1708   It can be configured with a capacity in which case
   1709   put(values) will block until space becomes available.
   1710 
   1711   Similarly, it can be configured with a memory limit which
   1712   will block put(values) until space is available.
   1713   This is mostly useful for limiting the number of tensors on
   1714   devices such as GPUs.
   1715 
   1716   All get() and peek() commands block if the requested data
   1717   is not present in the Staging Area.
   1718 
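  For example, a minimal sketch of staging and retrieving a pair of tensors
  (dtypes and shapes are illustrative only; `StagingArea` here refers to this
  class, which in TF 1.x is typically accessed as
  `tf.contrib.staging.StagingArea`):

  ```python
  area = StagingArea(dtypes=[tf.float32, tf.int32], shapes=[[2], []])
  put_op = area.put([tf.constant([1.0, 2.0]), tf.constant(7)])
  get_op = area.get()
  ```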
   1719   """
   1720 
   1721   def __init__(self,
   1722                dtypes,
   1723                shapes=None,
   1724                names=None,
   1725                shared_name=None,
   1726                capacity=0,
   1727                memory_limit=0):
   1728     """Constructs a staging area object.
   1729 
   1730     The two optional lists, `shapes` and `names`, must be of the same length
   1731     as `dtypes` if provided.  The values at a given index `i` indicate the
   1732     shape and name to use for the corresponding queue component in `dtypes`.
   1733 
   1734     The device scope at the time of object creation determines where the
   1735     storage for the `StagingArea` will reside.  Calls to `put` will incur a copy
   1736     to this memory space, if necessary.  Tensors returned by `get` will be
   1737     placed according to the device scope when `get` is called.
   1738 
   1739     Args:
   1740       dtypes:  A list of types.  The length of dtypes must equal the number
   1741         of tensors in each element.
   1742       capacity: (Optional.) Maximum number of elements.
   1743         An integer. If zero, the Staging Area is unbounded
   1744       memory_limit: (Optional.) Maximum number of bytes of all tensors
   1745         in the Staging Area.
   1746         An integer. If zero, the Staging Area is unbounded
   1747       shapes: (Optional.) Constraints on the shapes of tensors in an element.
   1748         A list of shape tuples or None. This list is the same length
   1749         as dtypes.  If the shape of any tensors in the element are constrained,
   1750         all must be; shapes can be None if the shapes should not be constrained.
   1751       names: (Optional.) If provided, the `get()` and
   1752         `put()` methods will use dictionaries with these names as keys.
   1753         Must be None or a list or tuple of the same length as `dtypes`.
   1754       shared_name: (Optional.) A name to be used for the shared object. By
   1755         passing the same name to two different python objects they will share
   1756         the underlying staging area. Must be a string.
   1757 
   1758     Raises:
   1759       ValueError: If one of the arguments is invalid.
   1760     """
   1761 
   1762     super(StagingArea, self).__init__(dtypes, shapes, names, shared_name,
   1763                                       capacity, memory_limit)
   1764 
   1765   def put(self, values, name=None):
   1766     """Create an op that places a value into the staging area.
   1767 
   1768     This operation will block if the `StagingArea` has reached
   1769     its capacity.
   1770 
   1771     Args:
   1772       values: Tensor (or a tuple of Tensors) to place into the staging area.
   1773       name: A name for the operation (optional).
   1774 
   1775     Returns:
   1776         The created op.
   1777 
   1778     Raises:
   1779       ValueError: If the number or type of inputs don't match the staging area.
   1780     """
   1781     with ops.name_scope(name, "%s_put" % self._name,
   1782                         self._scope_vals(values)) as scope:
   1783 
   1784       # Hard-code indices for this staging area
   1785       indices = (
   1786           list(six.moves.range(len(values)))
   1787           if isinstance(values, (list, tuple)) else None)
   1788       vals, _ = self._check_put_dtypes(values, indices)
   1789 
   1790       with ops.colocate_with(self._coloc_op):
   1791         op = gen_data_flow_ops.stage(
   1792             values=vals,
   1793             shared_name=self._name,
   1794             name=scope,
   1795             capacity=self._capacity,
   1796             memory_limit=self._memory_limit)
   1797 
   1798       return op
   1799 
   1800   def __internal_get(self, get_fn, name):
   1801     with ops.colocate_with(self._coloc_op):
   1802       ret = get_fn()
   1803 
   1804     indices = list(six.moves.range(len(self._dtypes)))  # Hard coded
   1805     return self._get_return_value(ret, indices)
   1806 
   1807   def get(self, name=None):
   1808     """Gets one element from this staging area.
   1809 
   1810     If the staging area is empty when this operation executes, it will block
   1811     until there is an element to dequeue.
   1812 
    Note that unlike other ops that can block, like the queue Dequeue
   1814     operations, this can stop other work from happening.  To avoid this, the
   1815     intended use is for this to be called only when there will be an element
   1816     already available.  One method for doing this in a training loop would be to
   1817     run a `put()` call during a warmup session.run call, and then call both
   1818     `get()` and `put()` in each subsequent step.
   1819 
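
    For example, a rough sketch of that warmup pattern (graph mode; the
    staging area and loop below are illustrative only):

    ```python
    area = StagingArea(dtypes=[tf.float32], shapes=[[2]])
    put_op = area.put([tf.constant([1.0, 2.0])])
    get_op = area.get()

    with tf.Session() as sess:
      sess.run(put_op)                  # warmup: stage the first element
      for _ in range(10):
        # each step consumes one staged element and stages the next one
        _ = sess.run([get_op, put_op])
    ```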
   1820     The placement of the returned tensor will be determined by the current
   1821     device scope when this function is called.
   1822 
   1823     Args:
   1824       name: A name for the operation (optional).
   1825 
   1826     Returns:
   1827       The tuple of tensors that was gotten.
   1828     """
   1829     if name is None:
   1830       name = "%s_get" % self._name
   1831 
   1832     # pylint: disable=bad-continuation
   1833     fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes,
   1834                     shared_name=self._name, name=name,
   1835                     capacity=self._capacity,
   1836                     memory_limit=self._memory_limit)
   1837     # pylint: enable=bad-continuation
   1838 
   1839     return self.__internal_get(fn, name)
   1840 
   1841   def peek(self, index, name=None):
   1842     """Peeks at an element in the staging area.
   1843 
   1844     If the staging area is too small to contain the element at
   1845     the specified index, it will block until enough elements
   1846     are inserted to complete the operation.
   1847 
   1848     The placement of the returned tensor will be determined by
   1849     the current device scope when this function is called.
   1850 
   1851     Args:
   1852       index: The index of the tensor within the staging area
   1853               to look up.
   1854       name: A name for the operation (optional).
   1855 
   1856     Returns:
   1857       The tuple of tensors that was gotten.
   1858     """
   1859     if name is None:
   1860       name = "%s_peek" % self._name
   1861 
   1862     # pylint: disable=bad-continuation
   1863     fn = lambda: gen_data_flow_ops.stage_peek(index,
   1864                     dtypes=self._dtypes, shared_name=self._name,
   1865                     name=name, capacity=self._capacity,
   1866                     memory_limit=self._memory_limit)
   1867     # pylint: enable=bad-continuation
   1868 
   1869     return self.__internal_get(fn, name)
   1870 
   1871   def size(self, name=None):
   1872     """Returns the number of elements in the staging area.
   1873 
   1874     Args:
   1875         name: A name for the operation (optional)
   1876 
   1877     Returns:
   1878         The created op
   1879     """
   1880     if name is None:
   1881       name = "%s_size" % self._name
   1882 
   1883     return gen_data_flow_ops.stage_size(
   1884         name=name,
   1885         shared_name=self._name,
   1886         dtypes=self._dtypes,
   1887         capacity=self._capacity,
   1888         memory_limit=self._memory_limit)
   1889 
   1890   def clear(self, name=None):
   1891     """Clears the staging area.
   1892 
   1893     Args:
   1894         name: A name for the operation (optional)
   1895 
   1896     Returns:
   1897         The created op
   1898     """
   1899     if name is None:
   1900       name = "%s_clear" % self._name
   1901 
   1902     return gen_data_flow_ops.stage_clear(
   1903         name=name,
   1904         shared_name=self._name,
   1905         dtypes=self._dtypes,
   1906         capacity=self._capacity,
   1907         memory_limit=self._memory_limit)
   1908 
   1909 
   1910 class MapStagingArea(BaseStagingArea):
   1911   """A `MapStagingArea` is a TensorFlow data structure that stores tensors across multiple steps, and exposes operations that can put and get tensors.
   1912 
   1913   Each `MapStagingArea` element is a (key, value) pair.
   1914   Only int64 keys are supported, other types should be
   1915   hashed to produce a key.
   1916   Values are a tuple of one or more tensors.
   1917   Each tuple component has a static dtype,
   1918   and may have a static shape.
   1919 
   1920   The capacity of a `MapStagingArea` may be bounded or unbounded.
   1921   It supports multiple concurrent producers and consumers; and
   1922   provides exactly-once delivery.
   1923 
   1924   Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors
   1925   whose
   1926   dtypes are described by `dtypes`, and whose shapes are optionally described
   1927   by the `shapes` argument.
   1928 
   1929   If the `shapes` argument is specified, each component of a staging area
   1930   element must have the respective fixed shape. If it is
  unspecified, different elements may have different shapes.
   1932 
   1933   It behaves like an associative container with support for:
   1934 
   1935    - put(key, values)
   1936    - peek(key)         like dict.get(key)
   1937    - get(key)          like dict.pop(key)
   1938    - get(key=None)     like dict.popitem()
   1939    - size()
   1940    - clear()
   1941 
  If ordered, a tree structure ordered by key will be used and
  get(key=None) will remove (key, value) pairs in increasing key order.
  Otherwise an unordered hashtable is used.
   1945 
   1946   It can be configured with a capacity in which case
   1947   put(key, values) will block until space becomes available.
   1948 
   1949   Similarly, it can be configured with a memory limit which
   1950   will block put(key, values) until space is available.
   1951   This is mostly useful for limiting the number of tensors on
   1952   devices such as GPUs.
   1953 
   1954   All get() and peek() commands block if the requested
   1955   (key, value) pair is not present in the staging area.
   1956 
   1957   Partial puts are supported and will be placed in an incomplete
   1958   map until such time as all values associated with the key have
   1959   been inserted. Once completed, this (key, value) pair will be
   1960   inserted into the map. Data in the incomplete map
   1961   counts towards the memory limit, but not towards capacity limit.
   1962 
   1963   Partial gets from the map are also supported.
   1964   This removes the partially requested tensors from the entry,
   1965   but the entry is only removed from the map once all tensors
   1966   associated with it are removed.
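
  For example, a minimal sketch (keys and values are illustrative only;
  `MapStagingArea` refers to this class):

  ```python
  area = MapStagingArea(dtypes=[tf.float32, tf.int32], ordered=True)
  key = tf.constant(0, dtype=tf.int64)
  put_op = area.put(key, [tf.constant(1.0), tf.constant(2)], indices=[0, 1])
  popped_key, values = area.get()  # pops the smallest key when ordered=True
  size_op = area.size()
  ```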
   1967   """
   1968 
   1969   def __init__(self,
   1970                dtypes,
   1971                shapes=None,
   1972                names=None,
   1973                shared_name=None,
   1974                ordered=False,
   1975                capacity=0,
   1976                memory_limit=0):
    """Constructs a `MapStagingArea` object.

    Args:
   1979       dtypes:  A list of types.  The length of dtypes must equal the number
   1980         of tensors in each element.
   1981       capacity: (Optional.) Maximum number of elements.
   1982         An integer. If zero, the Staging Area is unbounded
   1983       memory_limit: (Optional.) Maximum number of bytes of all tensors
   1984         in the Staging Area (excluding keys).
   1985         An integer. If zero, the Staging Area is unbounded
   1986       ordered: (Optional.) If True the underlying data structure
   1987         is a tree ordered on key. Otherwise assume a hashtable.
   1988       shapes: (Optional.) Constraints on the shapes of tensors in an element.
   1989         A list of shape tuples or None. This list is the same length
   1990         as dtypes.  If the shape of any tensors in the element are constrained,
   1991         all must be; shapes can be None if the shapes should not be constrained.
   1992       names: (Optional.) If provided, the `get()` and
   1993         `put()` methods will use dictionaries with these names as keys.
   1994         Must be None or a list or tuple of the same length as `dtypes`.
   1995       shared_name: (Optional.) A name to be used for the shared object. By
   1996         passing the same name to two different python objects they will share
   1997         the underlying staging area. Must be a string.
   1998 
   1999     Raises:
   2000       ValueError: If one of the arguments is invalid.
   2001 
   2002     """
   2003 
   2004     super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name,
   2005                                          capacity, memory_limit)
   2006 
   2007     # Defer to different methods depending if the map is ordered
   2008     self._ordered = ordered
   2009 
   2010     if ordered:
   2011       self._put_fn = gen_data_flow_ops.ordered_map_stage
   2012       self._pop_fn = gen_data_flow_ops.ordered_map_unstage
   2013       self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key
   2014       self._peek_fn = gen_data_flow_ops.ordered_map_peek
   2015       self._size_fn = gen_data_flow_ops.ordered_map_size
   2016       self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size
   2017       self._clear_fn = gen_data_flow_ops.ordered_map_clear
   2018     else:
   2019       self._put_fn = gen_data_flow_ops.map_stage
   2020       self._pop_fn = gen_data_flow_ops.map_unstage
   2021       self._popitem_fn = gen_data_flow_ops.map_unstage_no_key
   2022       self._peek_fn = gen_data_flow_ops.map_peek
   2023       self._size_fn = gen_data_flow_ops.map_size
   2024       self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size
   2025       self._clear_fn = gen_data_flow_ops.map_clear
   2026 
   2027   def put(self, key, vals, indices=None, name=None):
   2028     """Create an op that stores the (key, vals) pair in the staging area.
   2029 
    Incomplete puts are possible, preferably using a dictionary for vals, as
    the appropriate dtypes and shapes can then be inferred from the dictionary
    keys. If vals is a list or tuple, indices must also be specified so that
    the op knows at which element position to perform the insert.
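
    For example, a hedged sketch of an incomplete (partial) put using names
    (the names, key and values below are illustrative only):

    ```python
    area = MapStagingArea(dtypes=[tf.float32, tf.int32],
                          names=["image", "label"])
    key = tf.constant(42, dtype=tf.int64)
    # The entry moves from the incomplete map to the main map only once both
    # named components have been supplied for this key.
    put_image = area.put(key, {"image": tf.constant(0.5)})
    put_label = area.put(key, {"label": tf.constant(3)})
    ```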
   2035 
   2036     This operation will block if the capacity or memory limit of this
   2037     container is reached.
   2038 
   2039     Args:
   2040         key: Key associated with the data
   2041         vals: Tensor (or a dict/tuple of Tensors) to place
   2042                 into the staging area.
   2043         indices: (Optional) if vals is a tuple/list, this is required.
   2044         name: A name for the operation (optional)
   2045 
   2046     Returns:
   2047         The created op
   2048 
   2049     Raises:
   2050         ValueError: If the number or type of inputs don't match the staging
   2051         area.
   2052     """
   2053 
   2054     with ops.name_scope(name, "%s_put" % self._name,
   2055                         self._scope_vals(vals)) as scope:
   2056 
   2057       vals, indices = self._check_put_dtypes(vals, indices)
   2058 
   2059       with ops.colocate_with(self._coloc_op):
   2060         op = self._put_fn(
   2061             key,
   2062             indices,
   2063             vals,
   2064             dtypes=self._dtypes,
   2065             shared_name=self._name,
   2066             name=scope,
   2067             capacity=self._capacity,
   2068             memory_limit=self._memory_limit)
   2069     return op
   2070 
   2071   def _get_indices_and_dtypes(self, indices=None):
   2072     if indices is None:
   2073       indices = list(six.moves.range(len(self._dtypes)))
   2074 
   2075     if not isinstance(indices, (tuple, list)):
   2076       raise TypeError("Invalid indices type '%s'" % type(indices))
   2077 
   2078     if len(indices) == 0:
   2079       raise ValueError("Empty indices")
   2080 
   2081     if all(isinstance(i, str) for i in indices):
   2082       if self._names is None:
   2083         raise ValueError("String indices provided '%s', but this Staging Area "
   2084                          "was not created with names." % indices)
   2085 
   2086       try:
   2087         indices = [self._names.index(n) for n in indices]
   2088       except ValueError:
        raise ValueError("One of the named indices in '%s' is not in "
                         "Staging Area names '%s'" % (indices, self._names))
   2091     elif all(isinstance(i, int) for i in indices):
   2092       pass
   2093     else:
   2094       raise TypeError("Mixed types in indices '%s'. "
   2095                       "May only be str or int" % indices)
   2096 
   2097     dtypes = [self._dtypes[i] for i in indices]
   2098 
   2099     return indices, dtypes
   2100 
   2101   def peek(self, key, indices=None, name=None):
   2102     """Peeks at staging area data associated with the key.
   2103 
   2104     If the key is not in the staging area, it will block
   2105     until the associated (key, value) is inserted.
   2106 
   2107     Args:
   2108         key: Key associated with the required data
   2109         indices: Partial list of tensors to retrieve (optional).
   2110                 A list of integer or string indices.
   2111                 String indices are only valid if the Staging Area
   2112                 has names associated with it.
   2113         name: A name for the operation (optional)
   2114 
   2115     Returns:
        The values associated with the key, as a tensor, a list of tensors,
        or a dictionary of tensors.
   2117     """
   2118 
   2119     if name is None:
   2120       name = "%s_pop" % self._name
   2121 
   2122     indices, dtypes = self._get_indices_and_dtypes(indices)
   2123 
   2124     with ops.colocate_with(self._coloc_op):
   2125       result = self._peek_fn(
   2126           key,
   2127           shared_name=self._name,
   2128           indices=indices,
   2129           dtypes=dtypes,
   2130           name=name,
   2131           capacity=self._capacity,
   2132           memory_limit=self._memory_limit)
   2133 
   2134     return self._get_return_value(result, indices)
   2135 
   2136   def get(self, key=None, indices=None, name=None):
   2137     """If the key is provided, the associated (key, value) is returned from the staging area.
   2138 
   2139     If the key is not in the staging area, this method will block until
   2140     the associated (key, value) is inserted.
   2141     If no key is provided and the staging area is ordered,
   2142     the (key, value) with the smallest key will be returned.
   2143     Otherwise, a random (key, value) will be returned.
   2144 
   2145     If the staging area is empty when this operation executes,
   2146     it will block until there is an element to dequeue.
   2147 
   2148     Args:
   2149         key: Key associated with the required data (Optional)
   2150         indices: Partial list of tensors to retrieve (optional).
   2151                 A list of integer or string indices.
   2152                 String indices are only valid if the Staging Area
   2153                 has names associated with it.
   2154         name: A name for the operation (optional)
   2155 
   2156     Returns:
        A (key, value) tuple.
   2158     """
   2159     if key is None:
   2160       return self._popitem(indices=indices, name=name)
   2161     else:
   2162       return self._pop(key, indices=indices, name=name)
   2163 
   2164   def _pop(self, key, indices=None, name=None):
    """Removes and returns the (key, value) pair associated with the key.
   2166 
   2167     If the key is not in the staging area, this method will block until
   2168     the associated (key, value) is inserted.
   2169     Args:
   2170         key: Key associated with the required data
   2171         indices: Partial list of tensors to retrieve (optional).
   2172                 A list of integer or string indices.
   2173                 String indices are only valid if the Staging Area
   2174                 has names associated with it.
   2175         name: A name for the operation (optional)
   2176 
   2177     Returns:
        A (key, value) tuple.
   2179     """
   2180     if name is None:
   2181       name = "%s_get" % self._name
   2182 
   2183     indices, dtypes = self._get_indices_and_dtypes(indices)
   2184 
   2185     with ops.colocate_with(self._coloc_op):
   2186       result = self._pop_fn(
   2187           key,
   2188           shared_name=self._name,
   2189           indices=indices,
   2190           dtypes=dtypes,
   2191           name=name,
   2192           capacity=self._capacity,
   2193           memory_limit=self._memory_limit)
   2194 
   2195     return key, self._get_return_value(result, indices)
   2196 
   2197   def _popitem(self, indices=None, name=None):
   2198     """If the staging area is ordered, the (key, value) with the smallest key will be returned.
   2199 
   2200     Otherwise, a random (key, value) will be returned.
   2201     If the staging area is empty when this operation executes,
   2202     it will block until there is an element to dequeue.
   2203 
   2204     Args:
   2206         indices: Partial list of tensors to retrieve (optional).
   2207                 A list of integer or string indices.
   2208                 String indices are only valid if the Staging Area
   2209                 has names associated with it.
   2210         name: A name for the operation (optional)
   2211 
   2212     Returns:
        A (key, value) tuple.
   2214     """
   2215     if name is None:
   2216       name = "%s_get_nokey" % self._name
   2217 
   2218     indices, dtypes = self._get_indices_and_dtypes(indices)
   2219 
   2220     with ops.colocate_with(self._coloc_op):
   2221       key, result = self._popitem_fn(
   2222           shared_name=self._name,
   2223           indices=indices,
   2224           dtypes=dtypes,
   2225           name=name,
   2226           capacity=self._capacity,
   2227           memory_limit=self._memory_limit)
   2228 
   2229     # Separate keys and results out from
   2230     # underlying namedtuple
   2231     key = self._create_device_transfers(key)[0]
   2232     result = self._get_return_value(result, indices)
   2233 
   2234     return key, result
   2235 
   2236   def size(self, name=None):
   2237     """Returns the number of elements in the staging area.
   2238 
   2239     Args:
   2240         name: A name for the operation (optional)
   2241 
   2242     Returns:
   2243         The created op
   2244     """
   2245     if name is None:
   2246       name = "%s_size" % self._name
   2247 
   2248     return self._size_fn(
   2249         shared_name=self._name,
   2250         name=name,
   2251         dtypes=self._dtypes,
   2252         capacity=self._capacity,
   2253         memory_limit=self._memory_limit)
   2254 
   2255   def incomplete_size(self, name=None):
   2256     """Returns the number of incomplete elements in the staging area.
   2257 
   2258     Args:
   2259         name: A name for the operation (optional)
   2260 
   2261     Returns:
   2262         The created op
   2263     """
   2264     if name is None:
   2265       name = "%s_incomplete_size" % self._name
   2266 
   2267     return self._incomplete_size_fn(
   2268         shared_name=self._name,
   2269         name=name,
   2270         dtypes=self._dtypes,
   2271         capacity=self._capacity,
   2272         memory_limit=self._memory_limit)
   2273 
   2274   def clear(self, name=None):
   2275     """Clears the staging area.
   2276 
   2277     Args:
   2278         name: A name for the operation (optional)
   2279 
   2280     Returns:
   2281         The created op
   2282     """
   2283     if name is None:
   2284       name = "%s_clear" % self._name
   2285 
   2286     return self._clear_fn(
   2287         shared_name=self._name,
   2288         name=name,
   2289         dtypes=self._dtypes,
   2290         capacity=self._capacity,
   2291         memory_limit=self._memory_limit)
   2292 
   2293 
   2294 class RecordInput(object):
   2295   """RecordInput asynchronously reads and randomly yields TFRecords.
   2296 
   2297   A RecordInput Op will continuously read a batch of records asynchronously
   2298   into a buffer of some fixed capacity. It can also asynchronously yield
   2299   random records from this buffer.
   2300 
   2301   It will not start yielding until at least `buffer_size / 2` elements have been
   2302   placed into the buffer so that sufficient randomization can take place.
   2303 
   2304   The order the files are read will be shifted each epoch by `shift_amount` so
   2305   that the data is presented in a different order every epoch.
   2306   """
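
  For example, a minimal sketch (the file pattern and sizes are illustrative
  only):

  ```python
  record_input = RecordInput(file_pattern="/tmp/train-*.tfrecord",
                             batch_size=32,
                             buffer_size=10000,
                             parallelism=4)
  records = record_input.get_yield_op()  # 1-D string tensor of 32 records
  ```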
   2307 
   2308   def __init__(self,
   2309                file_pattern,
   2310                batch_size=1,
   2311                buffer_size=1,
   2312                parallelism=1,
   2313                shift_ratio=0,
   2314                seed=0,
   2315                name=None,
   2316                batches=None,
   2317                compression_type=None):
   2318     """Constructs a RecordInput Op.
   2319 
   2320     Args:
   2321       file_pattern: File path to the dataset, possibly containing wildcards.
   2322         All matching files will be iterated over each epoch.
   2323       batch_size: How many records to return at a time.
   2324       buffer_size: The maximum number of records the buffer will contain.
   2325       parallelism: How many reader threads to use for reading from files.
      shift_ratio: What percentage of the total number of files to move the
        start file forward by each epoch.
      seed: Specify the random number seed used by the generator that
        randomizes records.
   2330       name: Optional name for the operation.
   2331       batches: None by default, creating a single batch op. Otherwise specifies
   2332         how many batches to create, which are returned as a list when
   2333         `get_yield_op()` is called. An example use case is to split processing
   2334         between devices on one computer.
   2335       compression_type: The type of compression for the file. Currently ZLIB and
   2336         GZIP are supported. Defaults to none.
   2337 
   2338     Raises:
   2339       ValueError: If one of the arguments is invalid.
   2340     """
   2341     self._batch_size = batch_size
   2342     if batches is not None:
   2343       self._batch_size *= batches
   2344     self._batches = batches
   2345     self._file_pattern = file_pattern
   2346     self._buffer_size = buffer_size
   2347     self._parallelism = parallelism
   2348     self._shift_ratio = shift_ratio
   2349     self._seed = seed
   2350     self._name = name
   2351     self._compression_type = python_io.TFRecordCompressionType.NONE
   2352     if compression_type is not None:
   2353       self._compression_type = compression_type
   2354 
   2355   def get_yield_op(self):
   2356     """Adds a node that yields a group of records every time it is executed.

    If the `batches` parameter is not None, it yields a list of record
    batches with the specified `batch_size`.
   2359     """
   2360     compression_type = python_io.TFRecordOptions.get_compression_type_string(
   2361         python_io.TFRecordOptions(self._compression_type))
   2362     records = gen_data_flow_ops.record_input(
   2363         file_pattern=self._file_pattern,
   2364         file_buffer_size=self._buffer_size,
   2365         file_parallelism=self._parallelism,
   2366         file_shuffle_shift_ratio=self._shift_ratio,
   2367         batch_size=self._batch_size,
   2368         file_random_seed=self._seed,
   2369         compression_type=compression_type,
   2370         name=self._name)
   2371     if self._batches is None:
   2372       return records
   2373     else:
   2374       with ops.name_scope(self._name):
   2375         batch_list = [[] for i in six.moves.range(self._batches)]
   2376         records = array_ops.split(records, self._batch_size, 0)
   2377         records = [array_ops.reshape(record, []) for record in records]
   2378         for index, protobuf in zip(six.moves.range(len(records)), records):
   2379           batch_index = index % self._batches
   2380           batch_list[batch_index].append(protobuf)
   2381         return batch_list
   2382