/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_TENSOR_ARRAY_H_
#define TENSORFLOW_KERNELS_TENSOR_ARRAY_H_

#include <limits.h>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/aggregate_ops.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace tensor_array {

// Full implementations are in tensor_array.cc
template <typename Device, typename T>
Status AddToTensor(OpKernelContext* ctx, Tensor* sum, const Tensor* current,
                   const Tensor* add) {
  return errors::InvalidArgument(
      "tensor_array::AddToTensor type not supported: ",
      DataTypeString(DataTypeToEnum<T>::value));
}

#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T)                         \
  template <>                                                        \
  Status AddToTensor<Device, T>(OpKernelContext * ctx, Tensor * sum, \
                                const Tensor* current, const Tensor* add);

#define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)
#undef TENSOR_ARRAY_WRITE_OR_ADD_CPU

#if GOOGLE_CUDA

#define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
TF_CALL_complex64(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
#undef TENSOR_ARRAY_WRITE_OR_ADD_GPU

#endif  // GOOGLE_CUDA

#undef TENSOR_ARRAY_WRITE_OR_ADD
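
// For example, for float on the CPU the macro above declares the
// specialization (this is just the macro's expansion; the definition itself
// lives in tensor_array.cc):
//
//   template <>
//   Status AddToTensor<CPUDevice, float>(OpKernelContext* ctx, Tensor* sum,
//                                        const Tensor* current,
//                                        const Tensor* add);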

template <typename Device, typename T>
Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
  return errors::InvalidArgument(
      "tensor_array::TensorSetZero type not supported: ",
      DataTypeString(DataTypeToEnum<T>::value));
}

#define TENSOR_ARRAY_SET_ZERO(Device, T) \
  template <>                            \
  Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);

#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
#undef TENSOR_ARRAY_SET_ZERO_CPU

#if GOOGLE_CUDA

#define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
TF_CALL_complex64(TENSOR_ARRAY_SET_ZERO_GPU);
TF_CALL_complex128(TENSOR_ARRAY_SET_ZERO_GPU);
#undef TENSOR_ARRAY_SET_ZERO_GPU

#endif  // GOOGLE_CUDA

#undef TENSOR_ARRAY_SET_ZERO

}  // namespace tensor_array

// The TensorArray object keeps an array of PersistentTensors.  It
// allows reading from the array and writing to the array.
//
// Important properties:
//   * Usually, writing to a particular index in the TensorArray is allowed at
//     most once per index.  In a special case, writes with the flag
//     multiple_writes_aggregate allow multiple writes to the same
//     index.  In this case, the writes are summed.
//   * Multiple reads are supported.
//   * Deep copies of PersistentTensors are rarely made.  The only
//     time they are made is when WriteOrAggregate is called at least twice
//     on the same index with the flag multiple_writes_aggregate = True.
//   * Reading from and writing to the array is protected by a mutex.
//     All operations on a TensorArray are thread-safe.
//   * A TensorArray may be preemptively closed, which releases all
//     memory associated with it.
//
// These properties together allow the TensorArray to work as a
// functional object and make gradient computation easy.  For
// example:
//   * Write-Once semantics mean the gradient of a TensorArray Read never has
//     to worry about which of multiple writes to that index the gradient
//     value is meant for.
//   * Read-Many semantics (when using clear_after_read=false) allow the
//     TensorArray to be read, packed, or concatenated multiple times;
//     and the gradient operations use the multiple_writes_aggregate
//     flag to aggregate the backprop writes.  Multiple backprop writes to
//     the same index are partial gradients corresponding to the
//     multiple reads of that index in the forward phase.
//
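// A minimal usage sketch (illustrative only; 'ta' is an assumed TensorArray*
// holding float elements, and 'ctx' an OpKernelContext* supplied by the
// enclosing kernel):
//
//   PersistentTensor value = ...;  // a float Tensor of the element shape
//   TF_RETURN_IF_ERROR(
//       ta->WriteOrAggregate<CPUDevice, float>(ctx, 0 /* index */, &value));
//   PersistentTensor result;
//   TF_RETURN_IF_ERROR(
//       ta->Read<CPUDevice, float>(ctx, 0 /* index */, &result));
//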
class TensorArray : public ResourceBase {
 public:
  static std::atomic<int64> tensor_array_counter;

  // Construct a TensorArray for holding Tensors of type 'dtype' with
  // 'N' elements.  While the underlying storage is a std::vector and
  // can hold more than MAX_INT entries, in practice we do not expect
  // users to construct this many Tensors for storage in a TensorArray.
  TensorArray(const string& key, const DataType& dtype, const Tensor& handle,
              int32 N, const PartialTensorShape& element_shape,
              bool identical_element_shapes, bool dynamic_size,
              bool multiple_writes_aggregate, bool is_grad, int32 marked_size,
              bool clear_after_read)
      : key_(key),
        dtype_(dtype),
        handle_(handle),
        closed_(false),
        dynamic_size_(dynamic_size),
        multiple_writes_aggregate_(multiple_writes_aggregate),
        gradients_disallowed_(false),
        clear_after_read_(clear_after_read),
        is_grad_(is_grad),
        marked_size_(marked_size),
        element_shape_(element_shape),
        identical_element_shapes_(identical_element_shapes),
        tensors_(N) {}

  // Write PersistentTensor 'value' to index 'index'.
  //
  // Preconditions:
  //  * The TensorArray is not closed
  //  * If the array has dynamic size:
  //      The index is >= 0
  //    Otherwise:
  //      The index is in [0, N) where N == Size()
  //  * The dtype of the Tensor in 'value' matches the TensorArray's dtype.
  //  * If multiple_writes_aggregate is false:
  //    The Tensor at 'index' has not yet been written to.
  //  * If multiple_writes_aggregate is true:
  //    The Tensor at 'index' has the same shape as value.
  //
  // Side effects:
  //  * On the first write to 'index':
  //    - The underlying Tensor in 'value' has a new reference to it.
  //    - The index 'index' is marked as written.
  //  * If multiple_writes_aggregate is false, subsequent writes to 'index'
  //    raise an InvalidArgument error.
  //  * If multiple_writes_aggregate is true, subsequent writes to 'index':
  //    - The underlying Tensors in 'value' and from the first write
  //      are released and a local PersistentTensor is created.
  //    - Index 'index' is also marked as local_copy.
  //    - The gradients_disallowed flag is set true (GradientsAllowed()
  //      will now return false).
  //
  // Note: 'value' is passed as a pointer because its underlying Tensor's
  // shape is accessed.  It is not otherwise modified.
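  //
  // For example (a sketch, assuming multiple_writes_aggregate is true and the
  // element dtype is float): writing the value 1.0 and then the value 2.0 to
  // the same index leaves that index holding their sum, 3.0, and
  // GradientsAllowed() will return false from then on.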
  template <typename Device, typename T>
  Status WriteOrAggregate(OpKernelContext* ctx, const int32 index,
                          PersistentTensor* value) {
    mutex_lock l(mu_);
    return LockedWriteOrAggregate<Device, T>(ctx, index, value);
  }

  template <typename Device, typename T>
  Status WriteOrAggregateMany(OpKernelContext* ctx,
                              const std::vector<int32>& indices,
                              std::vector<PersistentTensor>* values) {
    mutex_lock l(mu_);
    int32 i = 0;
    for (const int32 ix : indices) {
      Status s = LockedWriteOrAggregate<Device, T>(ctx, ix, &(*values)[i]);
      ++i;
      TF_RETURN_IF_ERROR(s);
    }
    return Status::OK();
  }

  // Read from index 'index' into PersistentTensor 'value'.
  //
  // Preconditions:
  //  * The TensorArray is not closed
  //  * The index is in [0, N)
  //  * The Tensor at 'index' has been written to.
  //  * The Tensor at 'index' has not been read from with flag
  //    clear_after_read = true.
  //
  // Side effects:
  //  * If clear_after_read is true, the reference to the underlying
  //    Tensor is deleted.
  //  * The reference to the underlying Tensor at 'index' is copied to
  //    the returned '*value'.
  //  * The index is marked as read (it cannot be rewritten to).
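  //
  // For example (a sketch): with clear_after_read = true, a second Read of
  // the same index fails with InvalidArgument, since the first Read released
  // the underlying Tensor; with clear_after_read = false, the index may be
  // read any number of times.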
  template <typename Device, typename T>
  Status Read(OpKernelContext* ctx, const int32 index,
              PersistentTensor* value) {
    mutex_lock l(mu_);
    return LockedRead<Device, T>(ctx, index, value);
  }

  template <typename Device, typename T>
  Status ReadMany(OpKernelContext* ctx, const std::vector<int32>& indices,
                  std::vector<PersistentTensor>* values) {
    mutex_lock l(mu_);
    values->clear();
    values->resize(indices.size());
    int32 i = 0;
    for (const int32 ix : indices) {
      Status s = LockedRead<Device, T>(ctx, ix, &(*values)[i]);
      ++i;
      if (!s.ok()) return s;
    }
    return Status::OK();
  }

  DataType ElemType() const { return dtype_; }

  PartialTensorShape ElemShape() {
    mutex_lock l(mu_);
    return element_shape_;
  }

  Status SetElemShape(const PartialTensorShape& candidate) {
    mutex_lock l(mu_);
    PartialTensorShape new_element_shape;
    Status s = element_shape_.MergeWith(candidate, &new_element_shape);
    if (!s.ok()) {
      return s;
    }
    element_shape_ = new_element_shape;
    return Status::OK();
  }
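  // For example (a sketch of the MergeWith call above): merging an
  // element_shape_ of [?, 2] with a candidate of [3, ?] yields [3, 2];
  // merging [?, 2] with [?, 4] fails with a non-OK Status and leaves
  // element_shape_ unchanged.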

  string DebugString() override {
    mutex_lock l(mu_);
    CHECK(!closed_);
    return strings::StrCat("TensorArray[", tensors_.size(), "]");
  }

  bool IsClosed() {
    mutex_lock l(mu_);
    return closed_;
  }

  // Return the size of the TensorArray.
  Status Size(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = tensors_.size();
    return Status::OK();
  }

  // Record the size of the TensorArray after an unpack or split.
  Status SetMarkedSize(int32 size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    if (!is_grad_) {
      marked_size_ = size;
    }
    return Status::OK();
  }

  // Return the marked size of the TensorArray.
  Status MarkedSize(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = marked_size_;
    return Status::OK();
  }

  // Return the size that should be used by pack or concat op.
  Status PackOrConcatSize(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = is_grad_ ? marked_size_ : tensors_.size();
    return Status::OK();
  }

  // Once a TensorArray is being used for gradient calculations, it
  // should be marked as no longer resizeable.
  void DisableDynamicSize() {
    mutex_lock l(mu_);
    dynamic_size_ = false;
  }

  bool HasDynamicSize() {
    mutex_lock l(mu_);
    return dynamic_size_;
  }

  bool GradientsAllowed() {
    mutex_lock l(mu_);
    return !gradients_disallowed_;
  }

  bool HasIdenticalElementShapes() const { return identical_element_shapes_; }

  // Copy the TensorShapes from another TensorArray into this one.
  // The sizes of the two TensorArrays must match and this one
  // may not have any entries filled in.  This performs a "soft copy",
  // essentially filling the current TensorArray with virtual
  // zero-tensors, which will be replaced by future aggregate writes,
  // or instantiated by future reads.  Requires a non-const pointer
  // to the rhs to access its mutex.
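  //
  // For example (a sketch): a gradient TensorArray is typically initialized
  // with CopyShapesFrom(forward_array), so that a later Read of an index that
  // the gradient computation never wrote returns an all-zeros tensor of the
  // forward entry's shape (via the zero-fill path in LockedRead below).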
  Status CopyShapesFrom(TensorArray* rhs);

  // Clear the TensorArray, including any Tensor references, and mark as
  // closed.
  void ClearAndMarkClosed() {
    mutex_lock l(mu_);
    tensors_.clear();
    closed_ = true;
  }

  mutex* mu() { return &mu_; }
  Tensor* handle() { return &handle_; }

  ResourceHandle resource_handle(OpKernelContext* ctx) {
    return MakePerStepResourceHandle<TensorArray>(ctx, key_);
  }

 private:
  Status LockedWrite(OpKernelContext* ctx, const int32 index,
                     PersistentTensor* value) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  template <typename Device, typename T>
  Status LockedWriteOrAggregate(OpKernelContext* ctx, const int32 index,
                                PersistentTensor* value)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  template <typename Device, typename T>
  Status LockedRead(OpKernelContext* ctx, const int32 index,
                    PersistentTensor* value) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status LockedReturnIfClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (closed_) {
      return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                     " has already been closed.");
    }
    return Status::OK();
  }

  const string key_;

  const DataType dtype_;
  Tensor handle_;

  mutex mu_;

  // Marks that the tensors_ have been cleared.
  bool closed_ GUARDED_BY(mu_);

  // Writes are allowed to grow the array.
  bool dynamic_size_;

  // Multiple writes to the same index will result in summation of the
  // values (used by backprop).
  const bool multiple_writes_aggregate_;

  // If multiple writes were attempted (e.g. via the attribute
  // multiple_writes_aggregate), then gradients are disallowed.
  bool gradients_disallowed_ GUARDED_BY(mu_);

  // After a read at an index, clear away its PersistentTensor to
  // release memory.
  const bool clear_after_read_;

  // True iff this is a gradient tensor array.
  const bool is_grad_;

  // The size of the TensorArray after a (legacy) unpack or split is performed.
  // -1 if there has been no unpack or split performed on the TensorArray.
  int32 marked_size_;

  // The shape of each element in the TensorArray; may be partially known or
  // not known at all.
  PartialTensorShape element_shape_ GUARDED_BY(mu_);

  // Whether all elements in the TensorArray have identical shapes.
  // This allows certain behaviors, like dynamically checking for
  // consistent shapes on write, and being able to fill in properly
  // shaped zero tensors on stack -- even if the initial element_shape
  // was not fully defined.
  const bool identical_element_shapes_;

  // TensorAndState is used to keep track of the PersistentTensors
  // stored in the TensorArray, along with their shapes, and a boolean
  // that determines whether they have already been read or not.
  struct TensorAndState {
    TensorAndState()
        : written(false), read(false), cleared(false), local_copy(false) {}
    PersistentTensor tensor;
    TensorShape shape;
    bool written;  // True if a Tensor has been written to the index.
    bool read;  // True if a Tensor has been written to and read from the index.
    bool cleared;  // True if a tensor has been read with
                   // clear_after_read = true.

    // Used by writes when multiple_writes_aggregate is true.  In this
    // case, the first time a value is written, it is a shallow copy.
    // The second time a value is written, it is aggregated.  However,
    // in this case a new Tensor must be constructed to hold the
    // aggregated value.  This flag marks that such a Tensor is being
    // used.  All future writes will aggregate to the existing local Tensor.
    bool local_copy;
  };
  // The list of underlying PersistentTensors and states.
  std::vector<TensorAndState> tensors_ GUARDED_BY(mu_);
};

template <typename Device, typename T>
Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx,
                                           const int32 index,
                                           PersistentTensor* value) {
  TF_RETURN_IF_ERROR(LockedReturnIfClosed());
  size_t index_size = static_cast<size_t>(index);
  if (index < 0 || (!dynamic_size_ && index_size >= tensors_.size())) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<string>()(1), ": Tried to write to index ",
        index, " but array is not resizeable and size is: ", tensors_.size());
  }
  if (dynamic_size_) {
    // We must grow the internal TensorArray.  Reserving ahead of the resize
    // (with headroom) keeps repeated appends from reallocating every time.
    if (index_size >= tensors_.capacity()) {
      tensors_.reserve(2 * (index_size + 1));
    }
    if (index_size >= tensors_.size()) {
      tensors_.resize(index_size + 1);
    }
  }
  TensorAndState& t = tensors_[index];

  Tensor* value_t = value->AccessTensor(ctx);
  if (value_t->dtype() != dtype_) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<string>()(1),
        ": Could not write to TensorArray index ", index,
        " because the value dtype is ", DataTypeString(value_t->dtype()),
        " but TensorArray dtype is ", DataTypeString(dtype_), ".");
  }
  if (!element_shape_.IsCompatibleWith(value_t->shape())) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<string>()(1),
        ": Could not write to TensorArray index ", index,
        " because the value shape is ", value_t->shape().DebugString(),
        " which is incompatible with the TensorArray's inferred element "
        "shape: ",
        element_shape_.DebugString(), " (consider setting infer_shape=False).");
  } else if (identical_element_shapes_ && !element_shape_.IsFullyDefined()) {
    element_shape_ = PartialTensorShape(value_t->shape().dim_sizes());
  }

  if (t.read) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                   ": Could not write to TensorArray index ",
                                   index, " because it has already been read.");
  }

  if (!multiple_writes_aggregate_ && t.written) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                   ": Could not write to TensorArray index ",
                                   index,
                                   " because it has already been written to.");
  }

  if (t.written) {
    DCHECK(multiple_writes_aggregate_);

    // Check that value_t shape matches t.shape
    if (value_t->shape() != t.shape) {
      return errors::InvalidArgument(
          "TensorArray ", handle_.vec<string>()(1),
          ": Could not aggregate to TensorArray index ", index,
          " because the existing shape is ", t.shape.DebugString(),
          " but the new input shape is ", value_t->shape().DebugString(), ".");
    }

    if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
      // If the stored tensor is uninitialized but written == true, then what
      // was stored was just a shape, which means zeros.  So all we must do in
      // this case is copy the reference over and return early.
      t.tensor = *value;
      return Status::OK();
    }

    Tensor* existing_t = t.tensor.AccessTensor(ctx);

    if (t.local_copy) {
      Status s = tensor_array::AddToTensor<Device, T>(ctx, existing_t,
                                                      existing_t, value_t);
      TF_RETURN_IF_ERROR(s);
    } else {
      PersistentTensor local_tensor;
      Tensor* local_tensor_t;
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
          dtype_, existing_t->shape(), &local_tensor, &local_tensor_t));
      Status s = tensor_array::AddToTensor<Device, T>(ctx, local_tensor_t,
                                                      existing_t, value_t);
      TF_RETURN_IF_ERROR(s);
      t.tensor = local_tensor;
      t.local_copy = true;
    }

    // We've aggregated the values, so disallow backprop on this
    // TensorArray.
    gradients_disallowed_ = true;
  } else {
    t.tensor = *value;
    t.shape = value_t->shape();
    t.written = true;
  }
  return Status::OK();
}

template <typename Device, typename T>
Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index,
                               PersistentTensor* value) {
  TF_RETURN_IF_ERROR(LockedReturnIfClosed());
  if ((index < 0) ||
      (!is_grad_ && (static_cast<size_t>(index) >= tensors_.size()))) {
    return errors::InvalidArgument("Tried to read from index ", index,
                                   " but array size is: ", tensors_.size());
  }
  size_t index_t = static_cast<size_t>(index);
  if ((is_grad_ && (index_t >= tensors_.size() || !tensors_[index].written)) ||
      (!is_grad_ && (index_t < tensors_.size() && !tensors_[index].written))) {
    // Special case returning zeros if this is a gradient read that happens
    // after a stop_gradients call with dynamic forward TensorArrays.
    // There is sometimes a race condition where the gradient is not
    // written due to stop_gradients, but is later read.
    TensorShape element_shape;
    if (is_grad_ && index_t < tensors_.size() &&
        tensors_[index].shape.dims() > 0) {
      // A gradient TensorArray has more specific gradient information
      // available for each entry.  A forward TensorArray must rely on
      // the global element_shape_ to fill in zeros on read.
      element_shape = tensors_[index].shape;
    } else if (!element_shape_.IsFullyDefined()) {
      return errors::InvalidArgument(
          "TensorArray ", handle_.vec<string>()(1),
          ": Could not read from TensorArray index ", index,
          ".  Furthermore, the element shape is not fully defined: ",
          element_shape_.DebugString(),
          ".  It is possible you are working with a resizeable TensorArray and "
          "stop_gradients is not allowing the gradients to be written.  If you "
          "set the full element_shape property on the forward TensorArray, the "
          "proper all-zeros tensor will be returned instead of incurring this "
          "error.");
    } else {
      element_shape_.AsTensorShape(&element_shape);  // Always succeeds.
    }
    if (index_t >= tensors_.size()) {
      // Fill in tensors_ up to index to have known shape.
      size_t old_tensors_size = tensors_.size();
      tensors_.resize(index_t + 1);
      for (size_t i = old_tensors_size; i < index_t + 1; ++i) {
        tensors_[i].shape = element_shape;
        tensors_[i].written = true;
      }
    } else {
      tensors_[index].shape = element_shape;
      tensors_[index].written = true;
    }
  }

  TensorAndState& t = tensors_[index];

  if (t.cleared) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                   ": Could not read index ", index,
                                   " twice because it was cleared after a "
                                   "previous read (perhaps try setting "
                                   "clear_after_read = false?).");
  }

  if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
    // We stored just a shape, but no value.  This means create and
    // return zeros of the appropriate shape.
    Tensor* tensor_t;
    TF_RETURN_IF_ERROR(
        ctx->allocate_persistent(dtype_, t.shape, &t.tensor, &tensor_t));
    if (t.shape.num_elements() > 0) {
      Status s = tensor_array::TensorSetZero<Device, T>(ctx, tensor_t);
      if (!s.ok()) return s;
    }
  }

  // Data is available inside the tensor, copy the reference over.
  *value = t.tensor;

  if (clear_after_read_) {
    t.tensor = PersistentTensor();
    t.cleared = true;
  }
  t.read = true;
  return Status::OK();
}

}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_TENSOR_ARRAY_H_