# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorArray: a dynamically sized array of Tensors.

@@TensorArray
"""
# Mixture of pep8 and non-pep8 names, so disable pylint bad-name
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export


# _GraphTensorArray accesses many of the hidden generated ops, but is in
# fact built to wrap them.
# pylint: disable=protected-access
class _GraphTensorArray(object):
  """Graph-mode implementation of TensorArray."""

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference
        is enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray.
        Need not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the
        device context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if handle is not None and tensor_array_name:
      raise ValueError(
          "Cannot construct with both handle and tensor_array_name")
    if handle is not None and not isinstance(handle, ops.Tensor):
      raise TypeError("Handle must be a Tensor")
    if handle is None and size is None:
      raise ValueError("Size must be provided if handle is not provided")
    if handle is not None and size is not None:
      raise ValueError("Cannot provide both a handle and size "
                       "at the same time")
    if handle is not None and element_shape is not None:
      raise ValueError("Cannot provide both a handle and element_shape "
                       "at the same time")
    if handle is not None and dynamic_size is not None:
      raise ValueError("Cannot provide both a handle and dynamic_size "
                       "at the same time")
    if handle is not None and clear_after_read is not None:
      raise ValueError("Cannot provide both a handle and clear_after_read "
                       "at the same time")

    if clear_after_read is None:
      clear_after_read = True
    dynamic_size = dynamic_size or False

    self._dtype = dtype

    # Used to keep track of what tensors the TensorArray should be
    # colocated with.  We choose to colocate the TensorArray with the
    # first tensor written to it.
    self._colocate_with_first_write_call = colocate_with_first_write_call
    if colocate_with_first_write_call:
      self._colocate_with = []
    else:
      self._colocate_with = None

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, every write is checked
    # for shape equality.
    if element_shape is None:
      self._infer_shape = infer_shape
      self._element_shape = []
    else:
      self._infer_shape = True
      self._element_shape = [tensor_shape.TensorShape(element_shape)]
    with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
      if handle is not None:
        self._handle = handle
        if flow is None:
          raise ValueError("flow must not be None if handle is not None.")
        self._flow = flow
      else:
        # Construct the TensorArray with an empty device.  The first
        # write into the TensorArray from a Tensor with a set device
        # will retroactively set the device value of this op.
        def create():
          """Create the TensorArray op."""
          return gen_data_flow_ops._tensor_array_v3(
              dtype=dtype,
              size=size,
              element_shape=element_shape,
              identical_element_shapes=infer_shape,
              dynamic_size=dynamic_size,
              clear_after_read=clear_after_read,
              tensor_array_name=tensor_array_name,
              name=scope)
        if colocate_with_first_write_call:
          with ops.device(None), ops.colocate_with(None, ignore_existing=True):
            self._handle, self._flow = create()
        else:
          self._handle, self._flow = create()

  @property
  def flow(self):
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    return self._handle

  def _merge_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
          element shape of the `TensorArray`.
    """

    if self._element_shape:
      if not shape.is_compatible_with(self._element_shape[0]):
        raise ValueError(
            "Inconsistent shapes: saw %s but expected %s "
            "(and infer_shape=True)" % (shape, self._element_shape[0]))
      self._element_shape[0] = self._element_shape[0].merge_with(shape)
    else:
      self._element_shape.append(shape)

  @contextlib.contextmanager
  def _maybe_colocate_with(self, value):
    """Colocate operations with an internal colocation group or `value`.

    If no internal colocation group is set, colocate with `value` and make
    `value` the internal colocation group.

    Args:
      value: `Tensor`, the tensor to try to colocate with.

    Yields:
      Nothing; the context manager simply establishes a colocation context.
    """
    if not self._colocate_with_first_write_call:
      yield
    else:
      if not self._colocate_with:
        self._colocate_with.append(value)
      with ops.colocate_with(self._colocate_with[0]):
        yield

  def identity(self):
    """See TensorArray."""
    flow = array_ops.identity(self._flow)
    ta = TensorArray(
        dtype=self._dtype, handle=self._handle, flow=flow,
        infer_shape=self._infer_shape,
        colocate_with_first_write_call=self._colocate_with_first_write_call)
    ta._element_shape = self._element_shape
    ta._colocate_with = self._colocate_with
    return ta

  def grad(self, source, flow=None, name=None):
    """See TensorArray."""
    # tensor_array_grad requires a flow input when forward
    # TensorArrays are dynamically sized.  This forces the creation
    # of the grad TensorArray only once the final forward array's size
    # is fixed.
    if flow is None:
      flow = self.flow
    with ops.name_scope(name, "TensorArrayGrad", [self._handle]):
      with ops.colocate_with(self._handle):
        g_handle, unused_flow = gen_data_flow_ops._tensor_array_grad_v3(
            handle=self._handle, source=source, flow_in=flow, name=name)
        with ops.control_dependencies([g_handle]):
          flow = array_ops.identity(flow, name="gradient_flow")
        g = TensorArray(
            dtype=self._dtype,
            handle=g_handle,
            flow=flow,
            infer_shape=self._infer_shape,
            colocate_with_first_write_call=False)
        g._element_shape = self._element_shape
        return g

  def read(self, index, name=None):
    """See TensorArray."""
    value = gen_data_flow_ops._tensor_array_read_v3(
        handle=self._handle,
        index=index,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name)
    if self._element_shape:
      value.set_shape(self._element_shape[0].dims)
    return value

  @tf_should_use.should_use_result
  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
      value = ops.convert_to_tensor(value, name="value")
      if self._infer_shape:
        self._merge_element_shape(value.shape)
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops._tensor_array_write_v3(
            handle=self._handle,
            index=index,
            value=value,
            flow_in=self._flow,
            name=name)
      ta = TensorArray(
          dtype=self._dtype, handle=self._handle, flow=flow_out,
          colocate_with_first_write_call=self._colocate_with_first_write_call)
      ta._infer_shape = self._infer_shape
      ta._element_shape = self._element_shape
      ta._colocate_with = self._colocate_with
      return ta

  def stack(self, name=None):
    """See TensorArray."""
    with ops.colocate_with(self._handle):
      with ops.name_scope(name, "TensorArrayStack", [self._handle]):
        return self.gather(math_ops.range(0, self.size()), name=name)

  def gather(self, indices, name=None):
    """See TensorArray."""
    if self._element_shape:
      element_shape = self._element_shape[0]
    else:
      element_shape = tensor_shape.TensorShape(None)
    value = gen_data_flow_ops._tensor_array_gather_v3(
        handle=self._handle,
        indices=indices,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape=element_shape)
    if self._element_shape and self._element_shape[0].dims is not None:
      value.set_shape([None] + self._element_shape[0].dims)
    return value

  def concat(self, name=None):
    """See TensorArray."""
    if self._element_shape and self._element_shape[0].dims is not None:
      element_shape_except0 = (
          tensor_shape.TensorShape(self._element_shape[0].dims[1:]))
    else:
      element_shape_except0 = tensor_shape.TensorShape(None)
    value, _ = gen_data_flow_ops._tensor_array_concat_v3(
        handle=self._handle,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape_except0=element_shape_except0)
    if self._element_shape and self._element_shape[0].dims is not None:
      value.set_shape([None] + self._element_shape[0].dims[1:])
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
      num_elements = array_ops.shape(value)[0]
      return self.scatter(
          indices=math_ops.range(0, num_elements), value=value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._handle, value, indices]):
      value = ops.convert_to_tensor(value, name="value")
      if self._infer_shape and context.in_graph_mode():
        self._merge_element_shape(value.shape[1:])
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops._tensor_array_scatter_v3(
            handle=self._handle,
            indices=indices,
            value=value,
            flow_in=self._flow,
            name=name)
      ta = TensorArray(
          dtype=self._dtype, handle=self._handle, flow=flow_out,
          colocate_with_first_write_call=self._colocate_with_first_write_call)
      ta._infer_shape = self._infer_shape
      ta._element_shape = self._element_shape
      ta._colocate_with = self._colocate_with
      return ta

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit",
                        [self._handle, value, lengths]):
      value = ops.convert_to_tensor(value, name="value")
      with self._maybe_colocate_with(value):
        lengths_64 = math_ops.to_int64(lengths)
        if self._infer_shape and context.in_graph_mode():
          clengths = tensor_util.constant_value(lengths_64)
          if value.shape.dims is not None:
            if clengths is not None and clengths.max() == clengths.min():
              self._merge_element_shape(
                  tensor_shape.TensorShape([clengths[0]]).concatenate(
                      value.shape[1:]))
        flow_out = gen_data_flow_ops._tensor_array_split_v3(
            handle=self._handle,
            value=value,
            lengths=lengths_64,
            flow_in=self._flow,
            name=name)
      ta = TensorArray(
          dtype=self._dtype, handle=self._handle, flow=flow_out,
          colocate_with_first_write_call=self._colocate_with_first_write_call)
      ta._infer_shape = self._infer_shape
      ta._element_shape = self._element_shape
      ta._colocate_with = self._colocate_with
      return ta

  def size(self, name=None):
    """See TensorArray."""
    return gen_data_flow_ops._tensor_array_size_v3(
        handle=self._handle, flow_in=self.flow, name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """See TensorArray."""
    return gen_data_flow_ops._tensor_array_close_v3(
        handle=self._handle, name=name)
# pylint: enable=protected-access


# pylint: disable=protected-access
def _eager_write_no_copy(ta, index, value):
  """Writes value into an _EagerTensorArray without creating a new TensorArray.

  Args:
    ta: _EagerTensorArray into which to write value.
    index: 0-D.  int32 scalar with the index to write to.
    value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.

  Raises:
    errors_impl.AlreadyExistsError: attempting to overwrite an entry.
    errors_impl.InvalidArgumentError: value dtype does not match `ta`'s dtype.
    errors_impl.OutOfRangeError: `index` is out of bounds.
    ValueError: shape of `value` is not consistent with inferred shape.
  """

  if isinstance(index, ops.EagerTensor):
    index = index.numpy()

  if index < 0:
    raise errors_impl.OutOfRangeError(
        None, None,
        "Writing to negative indices (index %d) is not allowed." % index)

  tensor_array = ta._tensor_array
  size = len(tensor_array)
  if index >= size:
    if not ta._dynamic_size:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Tried to write to index %d but array is not resizeable and size "
          "is: %d" % (index, size))
    tensor_array.extend([None for _ in range(index - size + 1)])

  if not isinstance(value, ops.EagerTensor):
    value = constant_op.constant(value)

  if ta._infer_shape:
    if ta._element_shape is None:
      ta._element_shape = value.shape
    elif ta._element_shape != value.shape:
      raise ValueError("Incompatible shape for value (%s), expected (%s)" %
                       (value.shape.as_list(), ta._element_shape.as_list()))

  if ta._dtype != value.dtype:
    raise errors_impl.InvalidArgumentError(
        None, None,
        "TensorArray dtype is %s but Op is trying to write dtype %s" %
        (ta._dtype.name, value.dtype.name))

  if ta._tensor_array[index] is not None:
    raise errors_impl.AlreadyExistsError(
        None, None,
        "Could not write to TensorArray index %d because it has already been "
        "written to." % index)

  tensor_array[index] = value

# pylint: enable=protected-access


class _EagerTensorArray(object):
  """Eager-mode implementation of TensorArray."""

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs an Eager mode TensorArray.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: unused.
      handle: unsupported.
      flow: unsupported.
      infer_shape: used for error checking, same semantics as TensorArray.
      element_shape: used for error checking, same semantics as TensorArray.
      colocate_with_first_write_call: unsupported.
      name: unsupported.

    Raises:
      ValueError: if handle or flow is supplied, or if size is not supplied.
    """

    del (flow, tensor_array_name, name)  # not meaningful in Eager

    if handle is not None:
      raise ValueError("TensorArray handles are not supported in Eager mode.")
    if size is None:
      raise ValueError("Size must be declared for TensorArrays in Eager mode.")

    # These attributes are not meaningful in Eager, but some library functions
    # (e.g., those in control_flow_ops.py) access them to create new tensor
    # arrays; as such, we define them for the sake of compatibility.
    self._handle = None
    # We assign a dummy value to _flow in case other code assumes it to be
    # a Tensor.
    self._flow = constant_op.constant(0, dtype=dtypes.int32)
    self._infer_shape = infer_shape
    self._element_shape = element_shape
    self._colocate_with_first_write_call = colocate_with_first_write_call

    self._dtype = dtype
    self._dynamic_size = dynamic_size or False
    self._clear_after_read = (
        True if clear_after_read is None else clear_after_read)
    self._previously_read_indices = []

    if isinstance(size, ops.EagerTensor):
      size = size.numpy()
    self._tensor_array = [None for _ in range(size)]

  @property
  def flow(self):
    """Flows are not meaningful in Eager; this exists for compatibility."""
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    """Handles are not meaningful in Eager; this exists for compatibility."""
    return self._handle

  def _identity_without_array(self):
    """Returns a new TensorArray with the same properties as this Eager one.

    NB: Does not set the underlying _tensor_array attribute.
    """
    ta = TensorArray(
        dtype=self._dtype,
        size=len(self._tensor_array),
        dynamic_size=self._dynamic_size,
        clear_after_read=self._clear_after_read,
        handle=self._handle,
        flow=self._flow,
        infer_shape=self._infer_shape,
        element_shape=self._element_shape,
        colocate_with_first_write_call=self._colocate_with_first_write_call)
    ta._implementation._previously_read_indices = self._previously_read_indices  # pylint: disable=protected-access
    return ta

  def identity(self):
    """See TensorArray."""
    ta = self._identity_without_array()
    ta._implementation._tensor_array = [t for t in self._tensor_array]  # pylint: disable=protected-access
    return ta

  def grad(self, source, flow=None, name=None):
    raise NotImplementedError(
        "TensorArray.grad is not supported in Eager mode; Eager's gradient "
        "implementation does not use/need this function to compute gradients "
        "of operations that use TensorArrays.")

  def read(self, index, name=None):
    """See TensorArray."""
    del name  # not meaningful in Eager mode

    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Reading from negative indices (index %d) is not allowed." % index)

    if index >= len(self._tensor_array):
      raise errors_impl.OutOfRangeError(
          None, None, "Tried to read from index %d but array size is: %d" %
          (index, len(self._tensor_array)))

    tensor = self._tensor_array[index]
    if tensor is None:
      if index in self._previously_read_indices:
        raise errors_impl.InvalidArgumentError(
            None, None,
            "Could not read index %d twice because it was cleared after "
            "a previous read (perhaps try setting clear_after_read = false?)" %
            index)
      else:
        tensor = self._maybe_zero(index)

    if self._clear_after_read:
      self._tensor_array[index] = None
      self._previously_read_indices.append(index)
    return tensor

  def write(self, index, value, name=None):
    """See TensorArray."""
    del name  # not meaningful in Eager mode
    ta = self.identity()
    _eager_write_no_copy(ta._implementation, index, value)  # pylint: disable=protected-access
    return ta

  def _maybe_zero(self, ix):
    val = self._tensor_array[ix]
    if val is None:
      val = self._tensor_array[ix] = array_ops.zeros(
          shape=self._element_shape, dtype=self._dtype)
    return val

  def stack(self, name=None):
    """See TensorArray."""
    if self._tensor_array:
      for ix in range(len(self._tensor_array)):
        self._maybe_zero(ix)
    return array_ops.stack(self._tensor_array, name=name)

  def gather(self, indices, name=None):
    """See TensorArray."""
    del name  # not meaningful in Eager mode
    return array_ops.stack([self._maybe_zero(i) for i in indices.numpy()])

  def concat(self, name=None):
    """See TensorArray."""
    try:
      return array_ops.concat(
          [self._maybe_zero(ix) for ix in range(len(self._tensor_array))],
          0, name=name)
    except errors_impl.OpError:
      # Reproduce a subset of the error-handling for graph-mode TensorArrays.
      shapes = [t.shape for t in self._tensor_array]
      ndims = [s.ndims for s in shapes]
      if 0 in ndims:
        idx = ndims.index(0)
        raise errors_impl.InvalidArgumentError(
            None, None, "Concat saw a scalar shape at index %d but requires "
            "at least vectors." % idx)
      else:
        raise

  def unstack(self, value, name=None):
    """See TensorArray."""
    tensors = array_ops.unstack(value, name=name)
    if len(tensors) > len(self._tensor_array) and not self._dynamic_size:
      raise ValueError(
          "Cannot unstack %d tensors into a TensorArray of static size %d" %
          (len(tensors), len(self._tensor_array)))
    ta = self._identity_without_array()
    ta._implementation._tensor_array = tensors  # pylint: disable=protected-access
    return ta

  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    del name  # unused in Eager
    ta = self.identity()
    for index, val in zip(indices.numpy(), array_ops.unstack(value)):
      _eager_write_no_copy(ta._implementation, index, val)  # pylint: disable=protected-access
    return ta

  def split(self, value, lengths, name=None):
    """See TensorArray."""
    # error checking to match graph-mode errors
    value = constant_op.constant(value)
    lengths = constant_op.constant(lengths)
    sum_lengths = math_ops.reduce_sum(lengths)
    if lengths.shape.ndims != 1:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected lengths to be a vector, received shape: %s" %
          lengths.shape.as_list())
    elif value.shape.ndims == 0:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected value to be at least a vector, "
          "but received shape: %s" % value.shape.as_list())
    elif sum_lengths.numpy() != value.shape.as_list()[0]:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected sum of lengths to be equal to "
          "values.shape[0], but sum of lengths is %d and "
          "value's shape is: %s " % (sum_lengths.numpy(),
                                     value.shape.as_list()))
    elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array):
      raise errors_impl.InvalidArgumentError(
          None, None, "TensorArray's size is not equal to the size of "
          "lengths (%d vs. %d), and the TensorArray is not marked as "
          "dynamically resizeable" % (len(self._tensor_array),
                                      lengths.shape[0]))
    else:
      ta = self._identity_without_array()
      tensor_array = array_ops.split(value, lengths, name=name)
      ta._implementation._tensor_array = tensor_array  # pylint: disable=protected-access
      return ta

  def size(self, name=None):
    """See TensorArray."""
    del name  # not meaningful in Eager mode
    return constant_op.constant(len(self._tensor_array))

  def close(self, name=None):
    del name  # not meaningful in Eager mode
    del self._tensor_array[:]
    return

# TensorArray is designed to hide an underlying implementation object
# and as such accesses many of that object's hidden fields.
# pylint: disable=protected-access
@tf_export("TensorArray")
class TensorArray(object):
  """Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.

  This class is meant to be used with dynamic iteration primitives such as
  `while_loop` and `map_fn`.  It supports gradient back-propagation via special
  "flow" control flow dependencies.
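
  For example, a `TensorArray` can accumulate per-step values inside a
  `tf.while_loop` (a minimal sketch; `tf` here denotes the public TensorFlow
  API):

  ```python
  def body(i, ta):
    # Each write returns a new TensorArray carrying an updated flow.
    return i + 1, ta.write(i, tf.cast(i, tf.float32))

  ta = tf.TensorArray(dtype=tf.float32, size=5)
  _, ta = tf.while_loop(lambda i, ta: i < 5, body, [0, ta])
  values = ta.stack()  # => [0., 1., 2., 3., 4.]
  ```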
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Construct a new TensorArray or wrap an existing TensorArray handle.

    A note about the parameter `name`:

    The name of the `TensorArray` (even if passed in) is uniquified: each time
    a new `TensorArray` is created at runtime it is assigned its own name for
    the duration of the run.  This avoids name collisions if a `TensorArray`
    is created within a `while_loop`.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference
        is enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray.
        Need not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the
        device context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if context.in_graph_mode():
      implementation = _GraphTensorArray
    else:
      implementation = _EagerTensorArray

    self._implementation = implementation(
        dtype,
        size=size,
        dynamic_size=dynamic_size,
        clear_after_read=clear_after_read,
        tensor_array_name=tensor_array_name,
        handle=handle,
        flow=flow,
        infer_shape=infer_shape,
        element_shape=element_shape,
        colocate_with_first_write_call=colocate_with_first_write_call,
        name=name)

  @property
  def flow(self):
    """The flow `Tensor` forcing ops leading to this TensorArray state."""
    return self._implementation._flow

  @property
  def dtype(self):
    """The data type of this TensorArray."""
    return self._implementation._dtype

  @property
  def handle(self):
    """The reference to the TensorArray."""
    return self._implementation._handle

  @property
  def _infer_shape(self):
    return self._implementation._infer_shape

  @_infer_shape.setter
  def _infer_shape(self, infer_shape):
    self._implementation._infer_shape = infer_shape

  @property
  def _element_shape(self):
    return self._implementation._element_shape

  @_element_shape.setter
  def _element_shape(self, element_shape):
    self._implementation._element_shape = element_shape

  @property
  def _colocate_with_first_write_call(self):
    return self._implementation._colocate_with_first_write_call

  @property
  def _colocate_with(self):
    return self._implementation._colocate_with

  @_colocate_with.setter
  def _colocate_with(self, colocate_with):
    self._implementation._colocate_with = colocate_with

  def identity(self):
    """Returns a TensorArray with the same content and properties.

    Returns:
      A new TensorArray object with flow that ensures the control dependencies
      from the contexts will become control dependencies for writes, reads, etc.
      Use this object for all subsequent operations.
    """
    return self._implementation.identity()

  def grad(self, source, flow=None, name=None):
    return self._implementation.grad(source, flow=flow, name=name)

  def read(self, index, name=None):
    """Read the value at location `index` in the TensorArray.

    Args:
      index: 0-D.  int32 tensor with the index to read from.
      name: A name for the operation (optional).

    Returns:
      The tensor at index `index`.
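
    For example (a minimal sketch; with the default `clear_after_read=True`,
    each index may be read only once):

    ```python
    ta = tf.TensorArray(dtype=tf.float32, size=2)
    ta = ta.write(0, 1.0)
    v = ta.read(0)  # => 1.0; index 0 is then cleared
    ```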
    """
    return self._implementation.read(index, name=name)

  @tf_should_use.should_use_result
  def write(self, index, value, name=None):
    """Write `value` into index `index` of the TensorArray.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the write occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if there are more writers than specified.
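
    For example (a minimal sketch): the array is write-once, and each write
    returns a new `TensorArray` that must be used for subsequent operations:

    ```python
    ta = tf.TensorArray(dtype=tf.float32, size=2)
    ta = ta.write(0, 10.0)  # use the returned object, not the original `ta`
    ta = ta.write(1, 20.0)
    ```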
    """
    return self._implementation.write(index, value, name=name)

  def stack(self, name=None):
    """Return the values in the TensorArray as a stacked `Tensor`.

    All of the values must have been written and their shapes must all match.
    If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray stacked into one tensor.
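
    For example (a minimal sketch): stacking two rank-1 elements of shape
    `[2]` produces a rank-2 tensor of shape `[2, 2]`:

    ```python
    ta = tf.TensorArray(dtype=tf.float32, size=2)
    ta = ta.write(0, tf.constant([1.0, 2.0]))
    ta = ta.write(1, tf.constant([3.0, 4.0]))
    t = ta.stack()  # => [[1.0, 2.0], [3.0, 4.0]]
    ```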
    """
    return self._implementation.stack(name=name)

  def gather(self, indices, name=None):
    """Return selected values in the TensorArray as a packed `Tensor`.

    All of the selected values must have been written and their shapes
    must all match.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If
        the `TensorArray` is not dynamic, `max_value=size()`.
      name: A name for the operation (optional).

    Returns:
      The tensors in the `TensorArray` selected by `indices`, packed into one
      tensor.
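
    For example (a minimal sketch):

    ```python
    ta = tf.TensorArray(dtype=tf.float32, size=3)
    ta = ta.write(0, 1.0).write(1, 2.0).write(2, 3.0)
    t = ta.gather([0, 2])  # => [1.0, 3.0]
    ```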
    """
    return self._implementation.gather(indices, name=name)

  def concat(self, name=None):
    """Return the values in the TensorArray as a concatenated `Tensor`.

    All of the values must have been written, their ranks must match, and
    their shapes must all match for all dimensions except the first.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray concatenated into one tensor.
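
    For example (a minimal sketch): elements may differ in their first
    dimension, so `infer_shape=False` is used here:

    ```python
    ta = tf.TensorArray(dtype=tf.float32, size=2, infer_shape=False)
    ta = ta.write(0, tf.zeros([2, 3]))
    ta = ta.write(1, tf.zeros([1, 3]))
    t = ta.concat()  # shape [3, 3]
    ```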
    """
    return self._implementation.concat(name=name)

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """Unstack the values of a `Tensor` in the TensorArray.

    If input value shapes have rank-`R`, then the output TensorArray will
    contain elements whose shapes are rank-`(R-1)`.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unstack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the unstack occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
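
    For example (a minimal sketch): unstacking a `[3, 2]` tensor writes three
    `[2]`-shaped elements:

    ```python
    ta = tf.TensorArray(dtype=tf.float32, size=3)
    ta = ta.unstack(tf.constant([[1., 2.], [3., 4.], [5., 6.]]))
    v = ta.read(1)  # => [3., 4.]
    ```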
    """
    return self._implementation.unstack(value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """Scatter the values of a `Tensor` in specific indices of a `TensorArray`.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If
        the `TensorArray` is not dynamic, `max_value=size()`.
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unpack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the scatter occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
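
    For example (a minimal sketch): row `i` of `value` is written to index
    `indices[i]`:

    ```python
    ta = tf.TensorArray(dtype=tf.float32, size=4)
    ta = ta.scatter(tf.constant([3, 1]), tf.constant([[1., 2.], [3., 4.]]))
    v = ta.read(3)  # => [1., 2.]
    ```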
    """
    return self._implementation.scatter(indices, value, name=name)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """Split the values of a `Tensor` into the TensorArray.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to split.
      lengths: 1-D.  int32 vector with the lengths to use when splitting
        `value` along its first dimension.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the split occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
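
    For example (a minimal sketch): splitting a `[5, 4]` value with
    `lengths=[2, 3]` stores a `[2, 4]` element at index 0 and a `[3, 4]`
    element at index 1:

    ```python
    ta = tf.TensorArray(dtype=tf.float32, size=2, infer_shape=False)
    ta = ta.split(tf.zeros([5, 4]), lengths=[2, 3])
    v = ta.read(1)  # shape [3, 4]
    ```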
    """
    return self._implementation.split(value, lengths, name=name)

  def size(self, name=None):
    """Return the size of the TensorArray."""
    return self._implementation.size(name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """Close the current TensorArray."""
    return self._implementation.close(name=name)

# pylint: enable=protected-access
    994