      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/data_flow_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #include <limits>
     21 #include <vector>
     22 // TODO(b/31496047): Fix non-standard include order.
     23 #include <numeric>  // clang-format off
     24 
     25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     26 #include "tensorflow/core/framework/op_kernel.h"
     27 #include "tensorflow/core/framework/register_types.h"
     28 #include "tensorflow/core/framework/resource_mgr.h"
     29 #include "tensorflow/core/framework/tensor.h"
     30 #include "tensorflow/core/framework/tensor_shape.h"
     31 #include "tensorflow/core/framework/types.h"
     32 #include "tensorflow/core/kernels/bounds_check.h"
     33 #include "tensorflow/core/kernels/concat_lib.h"
     34 #include "tensorflow/core/kernels/split_lib.h"
     35 #include "tensorflow/core/kernels/tensor_array.h"
     36 #include "tensorflow/core/lib/core/errors.h"
     37 #include "tensorflow/core/lib/core/refcount.h"
     38 #include "tensorflow/core/lib/strings/strcat.h"
     39 #include "tensorflow/core/platform/dynamic_annotations.h"
     40 #include "tensorflow/core/platform/logging.h"
     41 #include "tensorflow/core/platform/thread_annotations.h"
     42 #include "tensorflow/core/platform/types.h"
     43 
     44 typedef Eigen::ThreadPoolDevice CPUDevice;
     45 #if GOOGLE_CUDA
     46 typedef Eigen::GpuDevice GPUDevice;
     47 #endif  // GOOGLE_CUDA
     48 
     49 // clang-format on
     50 
     51 namespace tensorflow {
     52 
     53 Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) {
     54   {
     55     Tensor tensor;
     56     // Assuming that handle is the input at index 0.
     57     if (IsRefType(ctx->input_dtype(0))) {
     58       tensor = ctx->mutable_input(0, false);
     59     } else {
     60       tensor = ctx->input(0);
     61     }
     62     if (tensor.NumElements() != 2) {
     63       return errors::InvalidArgument(
     64           "Tensor array handle must be 2-element vector, but had shape: ",
     65           tensor.shape().DebugString());
     66     }
     67     auto h = tensor.flat<string>();
     68     *container = h(0);
     69     *ta_handle = h(1);
     70   }
     71   return Status::OK();
     72 }
     73 
     74 Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) {
     75   string container;
     76   string ta_handle;
     77   if (ctx->input_dtype(0) != DT_RESOURCE) {
     78     TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle));
     79     ResourceMgr* rm = ctx->resource_manager();
     80     if (rm == nullptr) return errors::Internal("No resource manager.");
     81     TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
     82                                   container + ta_handle, tensor_array));
     83     return Status::OK();
     84   } else {
     85     return LookupResource(ctx, HandleFromInput(ctx, 0), tensor_array);
     86   }
     87 }
     88 
     89 Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) {
     90   const Tensor* flow_control;
     91   TF_RETURN_IF_ERROR(ctx->input("flow_in", &flow_control));
     92   if (set_output) {
     93     TF_RETURN_IF_ERROR(ctx->set_output("flow_out", *flow_control));
     94   }
     95   return Status::OK();
     96 }
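
         // The float "flow" value carries no data of its own; it only threads a
         // data dependency through reads and writes on the same TensorArray so
         // they execute in the order expressed in the graph. The kernels below
         // all start with roughly the same preamble (a sketch only, with
         // set_output true just for ops that produce a flow_out):
         //
         //   void Compute(OpKernelContext* ctx) override {
         //     OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, set_output));
         //     TensorArray* tensor_array = nullptr;
         //     OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array));
         //     core::ScopedUnref unref(tensor_array);
         //     ...
         //   }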
     97 
     98 // CREATION *******************************************************************
     99 
     100 // Abstract base class for behavior shared between TensorArrayOp and
     101 // TensorArrayGradOp.
    102 class TensorArrayCreationOp : public OpKernel {
    103  public:
    104   explicit TensorArrayCreationOp(OpKernelConstruction* context)
    105       : OpKernel(context), device_type_(context->device_type()) {}
    106 
    107   void Compute(OpKernelContext* ctx) override {
    108     Tensor tensor_array_output_handle;
    109 
    110     AllocatorAttributes alloc_attr;
    111     alloc_attr.set_on_host(true);
    112     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
    113                             tensorflow::DT_STRING, tensorflow::TensorShape({2}),
    114                             &tensor_array_output_handle, alloc_attr));
    115     // Store the handle in a per-step container of the RM.
    116     ResourceMgr* rm = ctx->resource_manager();
    117     OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager."));
    118 
    119     TensorArray* output_tensor_array;
    120     OP_REQUIRES_OK(ctx, CreateTensorArray(ctx, rm, &tensor_array_output_handle,
    121                                           &output_tensor_array));
    122     if (IsRefType(ctx->expected_output_dtype(0))) {
    123       ctx->set_output_ref(0, output_tensor_array->mu(),
    124                           output_tensor_array->handle());
    125     } else if (ctx->expected_output_dtype(0) == DT_STRING) {
    126       ctx->set_output(0, *output_tensor_array->handle());
    127     } else {
    128       Tensor* handle;
    129       OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
    130       handle->flat<ResourceHandle>()(0) =
    131           output_tensor_array->resource_handle(ctx);
    132     }
    133     if (ctx->num_outputs() == 2) {
    134       // Create the flow output.
    135       Tensor* flow;
    136       OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &flow));
    137       if (device_type_ == DEVICE_CPU) {
     138         // Value doesn't matter, but this makes msan not complain about
    139         // copying an uninitialized value. To do this on GPU would require
    140         // a kernel launch or a host->device memcpy, so we avoid that.
    141         flow->flat<float>()(0) = 0;
    142       }
    143     }
    144   }
    145 
    146  protected:
    147   virtual Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm,
    148                                    Tensor* tensor_array_output_handle,
    149                                    TensorArray** output_tensor_array) = 0;
    150 
    151  private:
    152   const DeviceType device_type_;
    153 };
    154 
     155 // A per-run local tensor array. The tensor array uses a "per-step" resource
     156 // manager, which ensures correct garbage collection on error or
     157 // successful completion.
    158 class TensorArrayOp : public TensorArrayCreationOp {
    159  public:
    160   explicit TensorArrayOp(OpKernelConstruction* context)
    161       : TensorArrayCreationOp(context) {
    162     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
    163     OP_REQUIRES_OK(context, context->GetAttr("element_shape", &element_shape_));
    164     OP_REQUIRES_OK(context, context->GetAttr("dynamic_size", &dynamic_size_));
    165     // The HasAttr check is for backwards compatibility with older op
    166     // versions which do not have this attribute.
    167     if (context->HasAttr("identical_element_shapes")) {
    168       OP_REQUIRES_OK(context, context->GetAttr("identical_element_shapes",
    169                                                &identical_element_shapes_));
    170     } else {
    171       identical_element_shapes_ = false;
    172     }
    173     OP_REQUIRES_OK(context,
    174                    context->GetAttr("clear_after_read", &clear_after_read_));
    175     OP_REQUIRES_OK(context,
    176                    context->GetAttr("tensor_array_name", &tensor_array_name_));
    177     if (tensor_array_name_.empty()) tensor_array_name_ = name();
    178   }
    179 
    180   Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm,
    181                            Tensor* tensor_array_output_handle,
    182                            TensorArray** output_tensor_array) override {
    183     const Tensor* tensor_size;
    184     TF_RETURN_IF_ERROR(ctx->input("size", &tensor_size));
    185 
    186     if (!TensorShapeUtils::IsScalar(tensor_size->shape())) {
    187       return errors::InvalidArgument(
    188           "TensorArray size must be scalar, but had shape: ",
    189           tensor_size->shape().DebugString());
    190     }
    191     const int32 size = tensor_size->scalar<int32>()();
    192     if (size < 0) {
    193       return errors::InvalidArgument("Size should be >= 0.");
    194     }
    195 
    196     auto handle = tensor_array_output_handle->flat<string>();
    197     string unique_tensor_array_name =
    198         strings::StrCat(tensor_array_name_, "_",
    199                         TensorArray::tensor_array_counter.fetch_add(1));
    200     handle(0) = "_tensor_arrays";
    201     handle(1) = unique_tensor_array_name;
    202 
    203     auto key = strings::StrCat(handle(0), unique_tensor_array_name);
    204 
    205     TensorArray* tensor_array = new TensorArray(
    206         key, dtype_, *tensor_array_output_handle, size, element_shape_,
    207         identical_element_shapes_, dynamic_size_,
    208         false /* multiple_writes_aggregate */, false /* is_grad */,
    209         -1 /* marked_size */, clear_after_read_);
    210 
    211     TF_RETURN_IF_ERROR(
    212         rm->Create(ctx->step_container()->name(), key, tensor_array));
    213 
    214     *output_tensor_array = tensor_array;
    215 
    216     return Status::OK();
    217   }
    218 
    219  private:
    220   DataType dtype_;
    221   PartialTensorShape element_shape_;
    222   bool identical_element_shapes_;
    223   bool dynamic_size_;
    224   bool clear_after_read_;
    225   string tensor_array_name_;  // The name used to create the TensorArray.
    226 
    227   TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp);
    228 };
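
         // Illustrative example (names are hypothetical): a TensorArrayOp whose
         // node name is "ta" and whose tensor_array_name attr is empty would,
         // with the global counter at 0, hand out
         //
         //   handle == {"_tensor_arrays", "ta_0"}
         //
         // and register the TensorArray in the per-step container under the
         // concatenated key "_tensor_arraysta_0".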
    229 
    230 REGISTER_KERNEL_BUILDER(Name("TensorArray").Device(DEVICE_CPU), TensorArrayOp);
    231 REGISTER_KERNEL_BUILDER(Name("TensorArrayV2").Device(DEVICE_CPU),
    232                         TensorArrayOp);
    233 REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU),
    234                         TensorArrayOp);
    235 
    236 #if GOOGLE_CUDA
    237 
    238 #define REGISTER_GPU(type)                                   \
    239   REGISTER_KERNEL_BUILDER(Name("TensorArray")                \
    240                               .Device(DEVICE_GPU)            \
    241                               .TypeConstraint<type>("dtype") \
    242                               .HostMemory("size")            \
    243                               .HostMemory("handle"),         \
    244                           TensorArrayOp);                    \
    245   REGISTER_KERNEL_BUILDER(Name("TensorArrayV2")              \
    246                               .Device(DEVICE_GPU)            \
    247                               .TypeConstraint<type>("dtype") \
    248                               .HostMemory("size")            \
    249                               .HostMemory("handle"),         \
    250                           TensorArrayOp);                    \
    251   REGISTER_KERNEL_BUILDER(Name("TensorArrayV3")              \
    252                               .Device(DEVICE_GPU)            \
    253                               .TypeConstraint<type>("dtype") \
    254                               .HostMemory("size")            \
    255                               .HostMemory("handle"),         \
    256                           TensorArrayOp);
    257 
    258 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
    259 TF_CALL_complex64(REGISTER_GPU);
    260 TF_CALL_complex128(REGISTER_GPU);
    261 REGISTER_GPU(bfloat16);
    262 #undef REGISTER_GPU
    263 
    264 #endif  // GOOGLE_CUDA
    265 
    266 // GRADIENT *******************************************************************
    267 
    268 class TensorArrayGradOp : public TensorArrayCreationOp {
    269  public:
    270   explicit TensorArrayGradOp(OpKernelConstruction* context)
    271       : TensorArrayCreationOp(context) {
    272     OP_REQUIRES_OK(context, context->GetAttr("source", &source_));
    273   }
    274 
    275   Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm,
    276                            Tensor* tensor_array_output_handle,
    277                            TensorArray** output_tensor_array) override {
    278     string container;
    279     string tensor_array_name;
    280     if (ctx->input_dtype(0) != DT_RESOURCE) {
    281       TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &tensor_array_name));
    282       if (container != "_tensor_arrays") {
    283         return errors::InvalidArgument(
    284             "Input container should be '_tensor_arrays',  but received '",
    285             container, "'");
    286       }
    287     } else {
    288       container = "_tensor_arrays";
    289       auto resource = ctx->input(0).flat<ResourceHandle>()(0);
    290       if (StringPiece(resource.name()).substr(0, container.size()) !=
    291           container) {
    292         return errors::InvalidArgument("Wrong input container. ",
    293                                        resource.name());
    294       }
    295       tensor_array_name =
    296           StringPiece(resource.name()).substr(container.size()).ToString();
    297     }
    298 
    299     auto output_handle = tensor_array_output_handle->flat<string>();
    300     output_handle(0) = "_tensor_array_grads";
    301     output_handle(1) = strings::StrCat(tensor_array_name, "@", source_);
    302 
    303     TensorArray* tensor_array;
    304     TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
    305                                   strings::StrCat(container, tensor_array_name),
    306                                   &tensor_array));
    307     core::ScopedUnref unref(tensor_array);
    308 
    309     // Once gradients are being calculated, the forward TensorArray
    310     // may no longer be resized by new Writes.
    311     tensor_array->DisableDynamicSize();
    312 
    313     int32 array_size = 0;
    314     int32 marked_size = 0;
    315     TF_RETURN_IF_ERROR(tensor_array->Size(&array_size));
    316     TF_RETURN_IF_ERROR(tensor_array->MarkedSize(&marked_size));
    317 
    318     if (array_size < 0) {
    319       return errors::InvalidArgument("ArraySize should be >= 0.");
    320     }
    321     if (!tensor_array->GradientsAllowed()) {
    322       return errors::InvalidArgument(
    323           "Unable to create a gradients TensorArray for ", tensor_array_name,
    324           ".  Perhaps you used the multiple_writes_aggregate flag on a "
    325           "previous write?  Gradient calculation is impossible when multiple "
    326           "writes are performed to the same index.");
    327     }
    328 
    329     const auto key = strings::StrCat(output_handle(0), output_handle(1));
    330     auto creator = [this, key, tensor_array, array_size, marked_size,
    331                     tensor_array_output_handle,
    332                     output_handle](TensorArray** ret) -> Status {
    333       *ret = new TensorArray(
    334           key, tensor_array->ElemType(), *tensor_array_output_handle,
    335           array_size, tensor_array->ElemShape(),
    336           tensor_array->HasIdenticalElementShapes(), false /* dynamic_size */,
    337           true /* multiple_writes_aggregate */, true /* is_grad */,
    338           marked_size /* marked_size */, true /* close_after_read */);
     339           marked_size /* marked_size */, true /* clear_after_read */);
    340     };
    341 
    342     Status s = rm->LookupOrCreate<TensorArray>(
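             // LookupOrCreate runs the creator above only if no gradient
             // TensorArray exists yet under this key, so every TensorArrayGrad op
             // that names the same forward array and the same "source" shares one
             // gradient TensorArray.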
    343         ctx->step_container()->name(), key, output_tensor_array, creator);
    344     (*output_tensor_array)->Unref();
    345 
    346     return s;
    347   }
    348 
    349  private:
    350   // The gradient source for creating the given
    351   // gradient TensorArray.  This should be unique to each gradients
    352   // call.  Typical values look like "gradients", "gradients_1", ...
    353   string source_;
    354 
    355   TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp);
    356 };
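
         // Continuing the hypothetical example above: the gradient of "ta_0" for
         // source "gradients" is handed out as
         //
         //   grad_handle == {"_tensor_array_grads", "ta_0@gradients"}
         //
         // and stored under the key "_tensor_array_gradsta_0@gradients".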
    357 
    358 REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad").Device(DEVICE_CPU),
    359                         TensorArrayGradOp);
    360 REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2").Device(DEVICE_CPU),
    361                         TensorArrayGradOp);
    362 REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3").Device(DEVICE_CPU),
    363                         TensorArrayGradOp);
    364 
    365 REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad")
    366                             .Device(DEVICE_GPU)
    367                             .HostMemory("handle")
    368                             .HostMemory("grad_handle"),
    369                         TensorArrayGradOp);
    370 REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2")
    371                             .Device(DEVICE_GPU)
    372                             .HostMemory("handle")
    373                             .HostMemory("grad_handle"),
    374                         TensorArrayGradOp);
    375 REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3")
    376                             .Device(DEVICE_GPU)
    377                             .HostMemory("handle")
    378                             .HostMemory("grad_handle"),
    379                         TensorArrayGradOp);
    380 
    381 // WRITE **********************************************************************
    382 
    383 template <typename Device, typename T>
    384 class TensorArrayWriteOp : public OpKernel {
    385  public:
    386   explicit TensorArrayWriteOp(OpKernelConstruction* context)
    387       : OpKernel(context) {}
    388 
    389   void Compute(OpKernelContext* ctx) override {
    390     OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, true));
    391 
    392     const Tensor* tensor_index;
    393     const Tensor* tensor_value;
    394     OP_REQUIRES_OK(ctx, ctx->input("index", &tensor_index));
    395     OP_REQUIRES_OK(ctx, ctx->input("value", &tensor_value));
    396 
    397     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_index->shape()),
    398                 errors::InvalidArgument(
    399                     "TensorArray index must be scalar, but had shape: ",
    400                     tensor_index->shape().DebugString()));
    401 
    402     TensorArray* tensor_array = nullptr;
    403     OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array));
    404     core::ScopedUnref unref(tensor_array);
    405     const int32 index = tensor_index->scalar<int32>()();
    406     OP_REQUIRES(
    407         ctx, tensor_value->dtype() == tensor_array->ElemType(),
    408         errors::InvalidArgument("TensorArray dtype is ",
    409                                 DataTypeString(tensor_array->ElemType()),
    410                                 " but Op is trying to write dtype ",
    411                                 DataTypeString(tensor_value->dtype()), "."));
    412     PersistentTensor persistent_tensor(*tensor_value);
    413     Status s = tensor_array->WriteOrAggregate<Device, T>(ctx, index,
    414                                                          &persistent_tensor);
    415     OP_REQUIRES_OK(ctx, s);
    416   }
    417 };
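
         // Note that forward TensorArrays are created with
         // multiple_writes_aggregate == false, so a second write to an index that
         // is already set fails; gradient TensorArrays (created above with the
         // flag set) aggregate repeated writes to the same index instead.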
    418 
    419 #define REGISTER_WRITE(type)                                                   \
    420   REGISTER_KERNEL_BUILDER(                                                     \
    421       Name("TensorArrayWrite").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
    422       TensorArrayWriteOp<CPUDevice, type>);                                    \
    423   REGISTER_KERNEL_BUILDER(                                                     \
    424       Name("TensorArrayWriteV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
    425       TensorArrayWriteOp<CPUDevice, type>);                                    \
    426   REGISTER_KERNEL_BUILDER(                                                     \
    427       Name("TensorArrayWriteV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
    428       TensorArrayWriteOp<CPUDevice, type>);
    429 
    430 TF_CALL_ALL_TYPES(REGISTER_WRITE);
    431 
    432 #undef REGISTER_WRITE
    433 
    434 #if GOOGLE_CUDA
    435 
    436 #define REGISTER_GPU(type)                                      \
    437   REGISTER_KERNEL_BUILDER(Name("TensorArrayWrite")              \
    438                               .Device(DEVICE_GPU)               \
    439                               .TypeConstraint<type>("T")        \
    440                               .HostMemory("handle")             \
    441                               .HostMemory("index"),             \
    442                           TensorArrayWriteOp<GPUDevice, type>); \
    443   REGISTER_KERNEL_BUILDER(Name("TensorArrayWriteV2")            \
    444                               .Device(DEVICE_GPU)               \
    445                               .TypeConstraint<type>("T")        \
    446                               .HostMemory("handle")             \
    447                               .HostMemory("index"),             \
    448                           TensorArrayWriteOp<GPUDevice, type>); \
    449   REGISTER_KERNEL_BUILDER(Name("TensorArrayWriteV3")            \
    450                               .Device(DEVICE_GPU)               \
    451                               .TypeConstraint<type>("T")        \
    452                               .HostMemory("handle")             \
    453                               .HostMemory("index"),             \
    454                           TensorArrayWriteOp<GPUDevice, type>);
    455 
    456 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
    457 TF_CALL_complex64(REGISTER_GPU);
    458 TF_CALL_complex128(REGISTER_GPU);
    459 REGISTER_GPU(bfloat16);
    460 #undef REGISTER_GPU
    461 
    462 #endif  // GOOGLE_CUDA
    463 
    464 // READ ***********************************************************************
    465 
    466 template <typename Device, typename T>
    467 class TensorArrayReadOp : public OpKernel {
    468  public:
    469   explicit TensorArrayReadOp(OpKernelConstruction* context)
    470       : OpKernel(context) {
    471     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
    472   }
    473 
    474   void Compute(OpKernelContext* ctx) override {
    475     OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, false));
    476 
    477     const Tensor* tensor_index;
    478     OP_REQUIRES_OK(ctx, ctx->input("index", &tensor_index));
    479 
    480     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_index->shape()),
    481                 errors::InvalidArgument(
    482                     "TensorArray index must be scalar, but had shape: ",
    483                     tensor_index->shape().DebugString()));
    484 
    485     TensorArray* tensor_array = nullptr;
    486     OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array));
    487     core::ScopedUnref unref(tensor_array);
    488 
    489     const int32 index = tensor_index->scalar<int32>()();
    490     OP_REQUIRES(
    491         ctx, dtype_ == tensor_array->ElemType(),
    492         errors::InvalidArgument(
    493             "TensorArray dtype is ", DataTypeString(tensor_array->ElemType()),
    494             " but Op requested dtype ", DataTypeString(dtype_), "."));
    495     PersistentTensor value;
    496     Status s = tensor_array->Read<Device, T>(ctx, index, &value);
    497     OP_REQUIRES_OK(ctx, s);
    498     ctx->set_output(0, *value.AccessTensor(ctx));
    499   }
    500 
    501  private:
    502   DataType dtype_;
    503 };
    504 
    505 #define REGISTER_READ(type)                                    \
    506   REGISTER_KERNEL_BUILDER(Name("TensorArrayRead")              \
    507                               .Device(DEVICE_CPU)              \
    508                               .TypeConstraint<type>("dtype"),  \
    509                           TensorArrayReadOp<CPUDevice, type>); \
    510   REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV2")            \
    511                               .Device(DEVICE_CPU)              \
    512                               .TypeConstraint<type>("dtype"),  \
    513                           TensorArrayReadOp<CPUDevice, type>); \
    514   REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV3")            \
    515                               .Device(DEVICE_CPU)              \
    516                               .TypeConstraint<type>("dtype"),  \
    517                           TensorArrayReadOp<CPUDevice, type>);
    518 
    519 TF_CALL_ALL_TYPES(REGISTER_READ)
    520 
    521 #undef REGISTER_READ
    522 
    523 #if GOOGLE_CUDA
    524 
    525 #define REGISTER_GPU(type)                                     \
    526   REGISTER_KERNEL_BUILDER(Name("TensorArrayRead")              \
    527                               .Device(DEVICE_GPU)              \
    528                               .TypeConstraint<type>("dtype")   \
    529                               .HostMemory("handle")            \
    530                               .HostMemory("index"),            \
    531                           TensorArrayReadOp<GPUDevice, type>); \
    532   REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV2")            \
    533                               .Device(DEVICE_GPU)              \
    534                               .TypeConstraint<type>("dtype")   \
    535                               .HostMemory("handle")            \
    536                               .HostMemory("index"),            \
    537                           TensorArrayReadOp<GPUDevice, type>); \
    538   REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV3")            \
    539                               .Device(DEVICE_GPU)              \
    540                               .TypeConstraint<type>("dtype")   \
    541                               .HostMemory("handle")            \
    542                               .HostMemory("index"),            \
    543                           TensorArrayReadOp<GPUDevice, type>);
    544 
    545 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
    546 TF_CALL_complex64(REGISTER_GPU);
    547 TF_CALL_complex128(REGISTER_GPU);
    548 REGISTER_GPU(bfloat16);
    549 #undef REGISTER_GPU
    550 
    551 #endif  // GOOGLE_CUDA
    552 
    553 // PACK and GATHER ************************************************************
    554 
     555 // Stack the elements of a TensorArray along a new leading dimension.  All
     556 // elements must be defined and have the same shape.
    557 template <typename Device, typename T, bool LEGACY_PACK>
    558 class TensorArrayPackOrGatherOp : public OpKernel {
    559  public:
    560   typedef typename TTypes<T, 2>::ConstMatrix ConstMatrix;
    561   typedef std::vector<std::unique_ptr<ConstMatrix> > ConstMatrixVector;
    562 
    563   explicit TensorArrayPackOrGatherOp(OpKernelConstruction* context)
    564       : OpKernel(context) {
    565     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
    566     OP_REQUIRES_OK(context, context->GetAttr("element_shape", &element_shape_));
    567   }
    568 
    569   void Compute(OpKernelContext* ctx) override {
    570     OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, false));
    571 
    572     TensorArray* tensor_array = nullptr;
    573     OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array));
    574 
    575     core::ScopedUnref unref(tensor_array);
    576     OP_REQUIRES(
    577         ctx, dtype_ == tensor_array->ElemType(),
    578         errors::InvalidArgument(
    579             "TensorArray dtype is ", DataTypeString(tensor_array->ElemType()),
    580             " but Op requested dtype ", DataTypeString(dtype_), "."));
    581 
    582     // Ensure new element shape is compatible with the one stored in the
    583     // TensorArray.
    584     OP_REQUIRES_OK(ctx, tensor_array->SetElemShape(element_shape_));
    585 
    586     int32 num_indices;
    587     std::vector<PersistentTensor> values;
    588     std::vector<int32> indices;
    589     if (LEGACY_PACK) {
    590       OP_REQUIRES_OK(ctx, tensor_array->PackOrConcatSize(&num_indices));
    591       indices.resize(num_indices);
    592       std::iota(indices.begin(), indices.end(), 0);
    593     } else {
    594       const Tensor* tensor_indices;
    595       OP_REQUIRES_OK(ctx, ctx->input("indices", &tensor_indices));
    596       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(tensor_indices->shape()),
    597                   errors::InvalidArgument(
    598                       "Expected indices to be a vector, but received shape: ",
    599                       tensor_indices->shape().DebugString()));
    600       const auto indices_t = tensor_indices->vec<int32>();
    601       num_indices = tensor_indices->NumElements();
    602       indices.resize(num_indices);
    603       std::copy(indices_t.data(), indices_t.data() + num_indices,
    604                 indices.begin());
    605     }
    606 
    607     // If there are no elements to return, return a zero-element Tensor with
    608     // shape [0] + element_shape_
    609     if (num_indices == 0) {
    610       OP_REQUIRES(ctx, element_shape_.IsFullyDefined(),
    611                   errors::Unimplemented(
    612                       "TensorArray has size zero, but element shape ",
    613                       element_shape_.DebugString(),
    614                       " is not fully defined. "
    615                       "Currently only static shapes are supported when packing "
    616                       "zero-size TensorArrays."));
    617       TensorShape empty_shape;
    618       element_shape_.AsTensorShape(&empty_shape);
    619       empty_shape.InsertDim(0, 0);
    620       Tensor* empty_unused;
    621       OP_REQUIRES_OK(ctx, ctx->allocate_output(0, empty_shape, &empty_unused));
    622       return;
    623     }
    624 
    625     // Read all the PersistentTensors into a vector to keep track of
    626     // their memory.
    627     Status s = tensor_array->ReadMany<Device, T>(ctx, indices, &values);
    628     OP_REQUIRES_OK(ctx, s);
    629 
    630     const Tensor* value_0_t = values[0].AccessTensor(ctx);
    631 
    632     OP_REQUIRES(
    633         ctx, element_shape_.IsCompatibleWith(value_0_t->shape()),
    634         errors::InvalidArgument("TensorArray was passed element_shape ",
    635                                 element_shape_.DebugString(),
    636                                 " which does not match the Tensor at index 0: ",
    637                                 value_0_t->shape().DebugString()));
    638 
    639     TensorShape output_shape(value_0_t->shape());
    640     output_shape.InsertDim(0, num_indices);
    641 
    642     Tensor* output_tensor = nullptr;
    643     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor));
    644 
     645     // If output_tensor is empty, there is nothing to concatenate, so return early.
    646     if (output_shape.num_elements() == 0) {
    647       return;
    648     }
    649 
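             // The stack/gather is implemented as a concat of flattened pieces:
             // each element is viewed as a [1, num_elements] matrix and the
             // pieces are concatenated along the second axis into output_tensor,
             // which was allocated above with shape [num_indices] + element shape.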
    650     ConstMatrixVector input_tensors_flat;
    651     input_tensors_flat.reserve(num_indices);
    652     auto output_flat =
    653         output_tensor->shaped<T, 2>({1, output_shape.num_elements()});
    654 
    655     // Insert the first value
    656     input_tensors_flat.emplace_back(new ConstMatrix(
    657         value_0_t->shaped<T, 2>({1, value_0_t->NumElements()})));
    658 
    659     for (int i = 1; i < num_indices; ++i) {
    660       const Tensor* value_t = values[i].AccessTensor(ctx);
    661       OP_REQUIRES(
    662           ctx, value_0_t->shape() == value_t->shape(),
    663           errors::InvalidArgument(
    664               "TensorArray has inconsistent shapes.  Index 0 has shape: ",
    665               value_0_t->shape().DebugString(), " but index ", i,
    666               " has shape: ", value_t->shape().DebugString()));
    667       input_tensors_flat.emplace_back(
    668           new ConstMatrix(value_t->shaped<T, 2>({1, value_t->NumElements()})));
    669     }
    670 
    671 #if GOOGLE_CUDA
    672     if (std::is_same<Device, GPUDevice>::value) {
    673       ConcatGPU<T>(ctx, input_tensors_flat, output_tensor, &output_flat);
    674       return;
    675     }
    676 #endif  // GOOGLE_CUDA
    677     ConcatCPU<T>(ctx->device(), input_tensors_flat, &output_flat);
    678   }
    679 
    680  private:
    681   DataType dtype_;
    682   PartialTensorShape element_shape_;
    683 };
    684 
    685 #define REGISTER_GATHER_AND_PACK(type)                                      \
    686   REGISTER_KERNEL_BUILDER(                                                  \
    687       Name("TensorArrayPack")                                               \
    688           .Device(DEVICE_CPU)                                               \
    689           .TypeConstraint<type>("dtype"),                                   \
    690       TensorArrayPackOrGatherOp<CPUDevice, type, true /* LEGACY_PACK */>);  \
    691   REGISTER_KERNEL_BUILDER(                                                  \
    692       Name("TensorArrayGather")                                             \
    693           .Device(DEVICE_CPU)                                               \
    694           .TypeConstraint<type>("dtype"),                                   \
    695       TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); \
    696   REGISTER_KERNEL_BUILDER(                                                  \
    697       Name("TensorArrayGatherV2")                                           \
    698           .Device(DEVICE_CPU)                                               \
    699           .TypeConstraint<type>("dtype"),                                   \
    700       TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); \
    701   REGISTER_KERNEL_BUILDER(                                                  \
    702       Name("TensorArrayGatherV3")                                           \
    703           .Device(DEVICE_CPU)                                               \
    704           .TypeConstraint<type>("dtype"),                                   \
    705       TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>);
    706 
    707 TF_CALL_POD_STRING_TYPES(REGISTER_GATHER_AND_PACK);
    708 REGISTER_GATHER_AND_PACK(quint8);
    709 REGISTER_GATHER_AND_PACK(qint8);
    710 REGISTER_GATHER_AND_PACK(qint32);
    711 
    712 #undef REGISTER_GATHER_AND_PACK
    713 
    714 #if GOOGLE_CUDA
    715 
    716 #define REGISTER_GPU(type)                                                  \
    717   REGISTER_KERNEL_BUILDER(                                                  \
    718       Name("TensorArrayPack")                                               \
    719           .Device(DEVICE_GPU)                                               \
    720           .TypeConstraint<type>("dtype")                                    \
    721           .HostMemory("handle"),                                            \
    722       TensorArrayPackOrGatherOp<GPUDevice, type, true /* LEGACY_PACK */>);  \
    723   REGISTER_KERNEL_BUILDER(                                                  \
    724       Name("TensorArrayGather")                                             \
    725           .Device(DEVICE_GPU)                                               \
    726           .TypeConstraint<type>("dtype")                                    \
    727           .HostMemory("indices")                                            \
    728           .HostMemory("handle"),                                            \
    729       TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>); \
    730   REGISTER_KERNEL_BUILDER(                                                  \
    731       Name("TensorArrayGatherV2")                                           \
    732           .Device(DEVICE_GPU)                                               \
    733           .TypeConstraint<type>("dtype")                                    \
    734           .HostMemory("indices")                                            \
    735           .HostMemory("handle"),                                            \
    736       TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>); \
    737   REGISTER_KERNEL_BUILDER(                                                  \
    738       Name("TensorArrayGatherV3")                                           \
    739           .Device(DEVICE_GPU)                                               \
    740           .TypeConstraint<type>("dtype")                                    \
    741           .HostMemory("indices")                                            \
    742           .HostMemory("handle"),                                            \
    743       TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>);
    744 
    745 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
    746 TF_CALL_complex64(REGISTER_GPU);
    747 TF_CALL_complex128(REGISTER_GPU);
    748 REGISTER_GPU(bfloat16);
    749 #undef REGISTER_GPU
    750 
    751 // A special GPU kernel for int32.
    752 // TODO(b/25387198): Also enable int32 in device memory. This kernel
    753 // registration requires all int32 inputs and outputs to be in host memory.
    754 REGISTER_KERNEL_BUILDER(
    755     Name("TensorArrayGather")
    756         .Device(DEVICE_GPU)
    757         .TypeConstraint<int32>("dtype")
    758         .HostMemory("indices")
    759         .HostMemory("handle"),
    760     TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>);
    761 REGISTER_KERNEL_BUILDER(
    762     Name("TensorArrayGatherV2")
    763         .Device(DEVICE_GPU)
    764         .TypeConstraint<int32>("dtype")
    765         .HostMemory("indices")
    766         .HostMemory("handle"),
    767     TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>);
    768 REGISTER_KERNEL_BUILDER(
    769     Name("TensorArrayGatherV3")
    770         .Device(DEVICE_GPU)
    771         .TypeConstraint<int32>("dtype")
    772         .HostMemory("indices")
    773         .HostMemory("handle"),
    774     TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>);
    775 
    776 #endif  // GOOGLE_CUDA
    777 
    778 // CONCAT *********************************************************************
    779 
    780 // Concatenate the elements in a TensorArray.  All elements must be
    781 // defined and (excepting the first dimension) have the same shape.
    782 template <typename Device, typename T>
    783 class TensorArrayConcatOp : public OpKernel {
    784  public:
    785   typedef typename TTypes<T, 2>::ConstMatrix ConstMatrix;
    786   typedef std::vector<std::unique_ptr<ConstMatrix> > ConstMatrixVector;
    787 
    788   explicit TensorArrayConcatOp(OpKernelConstruction* context)
    789       : OpKernel(context) {
    790     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
    791     OP_REQUIRES_OK(context, context->GetAttr("element_shape_except0",
    792                                              &element_shape_except0_));
    793   }
    794 
    795   void Compute(OpKernelContext* ctx) override {
    796     OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, false));
    797 
    798     TensorArray* tensor_array = nullptr;
    799     OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array));
    800     core::ScopedUnref unref(tensor_array);
    801     OP_REQUIRES(
    802         ctx, dtype_ == tensor_array->ElemType(),
    803         errors::InvalidArgument(
    804             "TensorArray dtype is ", DataTypeString(tensor_array->ElemType()),
    805             " but Op requested dtype ", DataTypeString(dtype_), "."));
    806 
    807     int32 array_size;
    808     OP_REQUIRES_OK(ctx, tensor_array->PackOrConcatSize(&array_size));
    809 
    810     // If there are no elements, return a zero-element Tensor with
    811     // shape [0] + element_shape_except0_
    812     if (array_size == 0) {
    813       OP_REQUIRES(
    814           ctx, element_shape_except0_.IsFullyDefined(),
    815           errors::Unimplemented(
    816               "TensorArray has size zero, but element_shape_except0 ",
    817               element_shape_except0_.DebugString(),
    818               " is not fully defined. "
    819               "Currently only static shapes are supported when concatenating "
    820               "zero-size TensorArrays."));
    821       TensorShape empty_shape;
    822       element_shape_except0_.AsTensorShape(&empty_shape);
    823       empty_shape.InsertDim(0, 0);
    824       Tensor* empty_unused;
    825       OP_REQUIRES_OK(ctx, ctx->allocate_output(0, empty_shape, &empty_unused));
    826       OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {0}, &empty_unused));
    827       return;
    828     }
    829 
    830     // Read all the PersistentTensors into a vector to keep track of
    831     // their memory.
    832     std::vector<PersistentTensor> values;
    833     std::vector<int32> indices(array_size);
    834     std::iota(indices.begin(), indices.end(), 0);
    835     Status s = tensor_array->ReadMany<Device, T>(ctx, indices, &values);
    836     OP_REQUIRES_OK(ctx, s);
    837 
    838     std::vector<const Tensor*> value_tensors;
    839     value_tensors.resize(values.size());
    840 
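             // The lengths output records each element's size along dimension 0
             // so the concatenated result can later be split back into its parts
             // (e.g. by TensorArraySplit).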
    841     Tensor* lengths_tensor = nullptr;
    842     OP_REQUIRES_OK(ctx, ctx->allocate_output(
    843                             1, TensorShape({static_cast<int64>(values.size())}),
    844                             &lengths_tensor));
    845     auto lengths_tensor_t = lengths_tensor->vec<int64>();
    846 
    847     TensorShape output_shape;
    848     TensorShape output_shape_except0;
    849     for (std::size_t i = 0; i < values.size(); ++i) {
    850       value_tensors[i] = values[i].AccessTensor(ctx);
    851       TensorShape value_shape_t = value_tensors[i]->shape();
    852 
    853       OP_REQUIRES(
    854           ctx, TensorShapeUtils::IsVectorOrHigher(value_shape_t),
    855           errors::InvalidArgument(
    856               "Concat saw a scalar shape at index ", i,
    857               " but requires at least vectors.  Did you mean to call pack?"));
    858 
    859       lengths_tensor_t(i) = value_shape_t.dim_size(0);
    860 
    861       TensorShape value_shape_t_except0 = value_shape_t;
    862       value_shape_t_except0.RemoveDim(0);
    863       if (i == 0) {
    864         output_shape = value_shape_t;
    865         output_shape_except0 = value_shape_t_except0;
    866         OP_REQUIRES(
    867             ctx, element_shape_except0_.IsCompatibleWith(output_shape_except0),
    868             errors::InvalidArgument(
    869                 "TensorArray was passed element_shape_except0 ",
    870                 element_shape_except0_.DebugString(),
    871                 " but index 0 has (excepting dimension 0) shape: ",
    872                 value_shape_t_except0.DebugString(), " which does not match."));
    873       } else {
    874         OP_REQUIRES(ctx, output_shape_except0 == value_shape_t_except0,
    875                     errors::InvalidArgument(
    876                         "TensorArray has inconsistent shapes.  Index 0 has "
    877                         "(excepting dimension 0) shape: ",
    878                         output_shape_except0.DebugString(), " but index ", i,
    879                         " has (excepting dimension 0) shape: ",
    880                         value_shape_t_except0.DebugString()));
     881         // Accumulate the total length of the output along dimension 0.
    882         output_shape.set_dim(
    883             0, output_shape.dim_size(0) + value_shape_t.dim_size(0));
    884       }
    885     }
    886 
    887     Tensor* output_tensor = nullptr;
    888     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor));
    889     ConstMatrixVector input_tensors_flat;
    890     input_tensors_flat.reserve(values.size());
    891     for (size_t i = 0; i < values.size(); ++i) {
    892       const Tensor* value_t = value_tensors[i];
    893       if (value_t->NumElements() > 0) {
    894         input_tensors_flat.emplace_back(new ConstMatrix(
    895             value_t->shaped<T, 2>({1, value_t->NumElements()})));
    896       }
    897     }
    898 
    899     if (output_shape.num_elements() > 0) {
    900       auto output_flat =
    901           output_tensor->shaped<T, 2>({1, output_shape.num_elements()});
    902 #if GOOGLE_CUDA
    903       if (std::is_same<Device, GPUDevice>::value) {
    904         ConcatGPU<T>(ctx, input_tensors_flat, output_tensor, &output_flat);
    905         return;
    906       }
    907 #endif  // GOOGLE_CUDA
    908       ConcatCPU<T>(ctx->device(), input_tensors_flat, &output_flat);
    909     }
    910   }
    911 
    912  private:
    913   DataType dtype_;
    914   PartialTensorShape element_shape_except0_;
    915 };
    916 
    917 #define REGISTER_CONCAT(type)                                    \
    918   REGISTER_KERNEL_BUILDER(Name("TensorArrayConcat")              \
    919                               .Device(DEVICE_CPU)                \
    920                               .TypeConstraint<type>("dtype")     \
    921                               .HostMemory("lengths")             \
    922                               .HostMemory("handle"),             \
    923                           TensorArrayConcatOp<CPUDevice, type>); \
    924   REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV2")            \
    925                               .Device(DEVICE_CPU)                \
    926                               .TypeConstraint<type>("dtype")     \
    927                               .HostMemory("lengths")             \
    928                               .HostMemory("handle"),             \
    929                           TensorArrayConcatOp<CPUDevice, type>)  \
    930   REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3")            \
    931                               .Device(DEVICE_CPU)                \
    932                               .TypeConstraint<type>("dtype")     \
    933                               .HostMemory("lengths")             \
    934                               .HostMemory("handle"),             \
    935                           TensorArrayConcatOp<CPUDevice, type>)
    936 
    937 TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT);
    938 REGISTER_CONCAT(quint8);
    939 REGISTER_CONCAT(qint8);
    940 REGISTER_CONCAT(qint32);
    941 
    942 #undef REGISTER_CONCAT
    943 
    944 #if GOOGLE_CUDA
    945 
    946 #define REGISTER_GPU(type)                                       \
    947   REGISTER_KERNEL_BUILDER(Name("TensorArrayConcat")              \
    948                               .Device(DEVICE_GPU)                \
    949                               .TypeConstraint<type>("dtype")     \
    950                               .HostMemory("lengths")             \
    951                               .HostMemory("handle"),             \
    952                           TensorArrayConcatOp<GPUDevice, type>); \
    953   REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV2")            \
    954                               .Device(DEVICE_GPU)                \
    955                               .TypeConstraint<type>("dtype")     \
    956                               .HostMemory("lengths")             \
    957                               .HostMemory("handle"),             \
    958                           TensorArrayConcatOp<GPUDevice, type>)  \
    959   REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3")            \
    960                               .Device(DEVICE_GPU)                \
    961                               .TypeConstraint<type>("dtype")     \
    962                               .HostMemory("lengths")             \
    963                               .HostMemory("handle"),             \
    964                           TensorArrayConcatOp<GPUDevice, type>)
    965 
    966 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
    967 TF_CALL_complex64(REGISTER_GPU);
    968 TF_CALL_complex128(REGISTER_GPU);
    969 REGISTER_GPU(bfloat16);
    970 #undef REGISTER_GPU
    971 
    972 // A special GPU kernel for int32.
    973 // TODO(b/25387198): Also enable int32 in device memory. This kernel
    974 // registration requires all int32 inputs and outputs to be in host memory.
    975 REGISTER_KERNEL_BUILDER(Name("TensorArrayConcat")
    976                             .Device(DEVICE_GPU)
    977                             .TypeConstraint<int32>("dtype")
    978                             .HostMemory("lengths")
    979                             .HostMemory("handle"),
    980                         TensorArrayConcatOp<CPUDevice, int32>);
    981 REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV2")
    982                             .Device(DEVICE_GPU)
    983                             .TypeConstraint<int32>("dtype")
    984                             .HostMemory("lengths")
    985                             .HostMemory("handle"),
    986                         TensorArrayConcatOp<CPUDevice, int32>);
    987 REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3")
    988                             .Device(DEVICE_GPU)
    989                             .TypeConstraint<int32>("dtype")
    990                             .HostMemory("lengths")
    991                             .HostMemory("handle"),
    992                         TensorArrayConcatOp<CPUDevice, int32>);
    993 
    994 #endif  // GOOGLE_CUDA
    995 
    996 // UNPACK and SCATTER *********************************************************
    997 
    998 template <typename Device, typename T, bool LEGACY_UNPACK>
    999 class TensorArrayUnpackOrScatterOp : public OpKernel {
   1000  public:
   1001   explicit TensorArrayUnpackOrScatterOp(OpKernelConstruction* context)
   1002       : OpKernel(context) {}
   1003 
   1004   void Compute(OpKernelContext* ctx) override {
   1005     OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, true));
   1006 
   1007     TensorArray* tensor_array = nullptr;
   1008     OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array));
   1009     core::ScopedUnref unref(tensor_array);
   1010     const Tensor* tensor_value;
   1011     OP_REQUIRES_OK(ctx, ctx->input("value", &tensor_value));
   1012     TensorShape element_shape(tensor_value->shape());
   1013 
   1014     OP_REQUIRES(ctx,
   1015                 FastBoundsCheck(element_shape.dim_size(0),
   1016                                 std::numeric_limits<int32>::max()),
   1017                 errors::InvalidArgument("tensor dim0 too large to unpack"));
   1018 
   1019     OP_REQUIRES(
   1020         ctx, tensor_value->dtype() == tensor_array->ElemType(),
   1021         errors::InvalidArgument("TensorArray dtype is ",
   1022                                 DataTypeString(tensor_array->ElemType()),
   1023                                 " but Op is trying to write dtype ",
   1024                                 DataTypeString(tensor_value->dtype()), "."));
   1025     OP_REQUIRES(ctx, element_shape.dims() > 0,
   1026                 errors::InvalidArgument("Input value for unpack must be at "
   1027                                         "least a vector but received shape: ",
   1028                                         element_shape.DebugString()));
   1029     int32 array_size;
   1030     OP_REQUIRES_OK(ctx, tensor_array->Size(&array_size));
   1031 
   1032     int32 max_index;
   1033     int32 num_values;
   1034     std::vector<int32> write_indices;
   1035     if (LEGACY_UNPACK) {
   1036       num_values = element_shape.dim_size(0);
   1037       max_index = num_values - 1;
   1038       write_indices.resize(num_values);
   1039       std::iota(write_indices.begin(), write_indices.end(), 0);
   1040     } else {
   1041       const Tensor* tensor_indices;
   1042       OP_REQUIRES_OK(ctx, ctx->input("indices", &tensor_indices));
   1043       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(tensor_indices->shape()),
   1044                   errors::InvalidArgument(
   1045                       "Expected indices to be a vector, but received shape: ",
   1046                       tensor_indices->shape().DebugString()));
   1047       OP_REQUIRES(ctx,
   1048                   tensor_indices->NumElements() == element_shape.dim_size(0),
   1049                   errors::InvalidArgument(
   1050                       "Expected len(indices) == values.shape[0], but saw: ",
   1051                       tensor_indices->NumElements(), " vs. ",
   1052                       element_shape.dim_size(0)));
   1053       const auto indices_t = tensor_indices->vec<int32>();
   1054       num_values = tensor_indices->NumElements();
   1055       max_index = (num_values == 0)
   1056                       ? -1
   1057                       : *std::max_element(indices_t.data(),
   1058                                           indices_t.data() + num_values);
   1059       write_indices.resize(num_values);
   1060       // Copy into write_indices.
   1061       std::copy(indices_t.data(), indices_t.data() + num_values,
   1062                 write_indices.begin());
   1063     }
   1064 
   1065     bool dynamic_size = tensor_array->HasDynamicSize();
   1066 
   1067     // If dynamic size, we may have to resize the TensorArray to fit.
   1068     if (dynamic_size && array_size < max_index + 1) {
   1069       array_size = static_cast<int32>(max_index + 1);
   1070     }
   1071 
   1072     if (LEGACY_UNPACK) {
   1073       OP_REQUIRES(
   1074           ctx, element_shape.dim_size(0) == array_size,
   1075           errors::InvalidArgument(
   1076               "Input value must have first dimension equal to the array size (",
   1077               element_shape.dim_size(0), " vs. ", array_size, ")"));
   1078     } else {
   1079       OP_REQUIRES(
   1080           ctx, max_index < array_size,
   1081           errors::InvalidArgument("Max scatter index must be < array size (",
   1082                                   max_index, " vs. ", array_size, ")"));
   1083     }
   1084     element_shape.RemoveDim(0);
   1085 
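             // View the input as a [1, num_values, elem_size] block; each element
             // i is then carved out as the [1, 1, elem_size] slice starting at
             // {0, i, 0} by the Split functor in the loop below.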
   1086     auto tensor_value_t = tensor_value->shaped<T, 3>(
   1087         {1, num_values, element_shape.num_elements()});
   1088 
   1089     Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
   1090     Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, 1,
   1091                                               element_shape.num_elements()};
   1092 
   1093     std::vector<PersistentTensor> write_values;
   1094     write_values.reserve(num_values);
   1095 
   1096     for (int i = 0; i < num_values; ++i) {
   1097       Tensor* tensor_value_i;
   1098       PersistentTensor persistent_tensor;
   1099       OP_REQUIRES_OK(
   1100           ctx, ctx->allocate_persistent(tensor_array->ElemType(), element_shape,
   1101                                         &persistent_tensor, &tensor_value_i));
   1102       auto tensor_value_i_t =
   1103           tensor_value_i->shaped<T, 3>({1, 1, element_shape.num_elements()});
   1104       indices[1] = i;
   1105 
   1106       if (element_shape.num_elements() > 0) {
   1107         functor::Split<Device, T>()(ctx->eigen_device<Device>(),
   1108                                     tensor_value_i_t, tensor_value_t, indices,
   1109                                     sizes);
   1110       }
   1111 
   1112       write_values.push_back(persistent_tensor);
   1113     }
   1114 
   1115     // Record the pack size of the TensorArray.
   1116     if (LEGACY_UNPACK) {
   1117       OP_REQUIRES_OK(ctx, tensor_array->SetMarkedSize(array_size));
   1118     }
   1119 
   1120     Status s = tensor_array->WriteOrAggregateMany<Device, T>(ctx, write_indices,
   1121                                                              &write_values);
   1122     OP_REQUIRES_OK(ctx, s);
   1123   }
   1124 };
   1125 
   1126 #define REGISTER_SCATTER_AND_UNPACK(type)                                      \
   1127   REGISTER_KERNEL_BUILDER(                                                     \
   1128       Name("TensorArrayUnpack").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
   1129       TensorArrayUnpackOrScatterOp<CPUDevice, type,                            \
   1130                                    true /* LEGACY_UNPACK */>);                 \
   1131   REGISTER_KERNEL_BUILDER(                                                     \
   1132       Name("TensorArrayScatter").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
   1133       TensorArrayUnpackOrScatterOp<CPUDevice, type,                            \
   1134                                    false /* LEGACY_UNPACK */>);                \
   1135   REGISTER_KERNEL_BUILDER(                                                     \
   1136       Name("TensorArrayScatterV2")                                             \
   1137           .Device(DEVICE_CPU)                                                  \
   1138           .TypeConstraint<type>("T"),                                          \
   1139       TensorArrayUnpackOrScatterOp<CPUDevice, type,                            \
   1140                                    false /* LEGACY_UNPACK */>);                \
   1141   REGISTER_KERNEL_BUILDER(                                                     \
   1142       Name("TensorArrayScatterV3")                                             \
   1143           .Device(DEVICE_CPU)                                                  \
   1144           .TypeConstraint<type>("T"),                                          \
   1145       TensorArrayUnpackOrScatterOp<CPUDevice, type,                            \
   1146                                    false /* LEGACY_UNPACK */>);
   1147 
   1148 TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK);
   1149 #undef REGISTER_SCATTER_AND_UNPACK
   1150 
   1151 #if GOOGLE_CUDA
   1152 
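         // The handle and indices inputs stay in host memory: Compute() reads
         // them directly on the host (GetHandle() and the vec<int32>() access).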
   1153 #define REGISTER_GPU(type)                                      \
   1154   REGISTER_KERNEL_BUILDER(                                      \
   1155       Name("TensorArrayUnpack")                                 \
   1156           .Device(DEVICE_GPU)                                   \
   1157           .TypeConstraint<type>("T")                            \
   1158           .HostMemory("handle"),                                \
   1159       TensorArrayUnpackOrScatterOp<GPUDevice, type,             \
   1160                                    true /* LEGACY_UNPACK */>);  \
   1161   REGISTER_KERNEL_BUILDER(                                      \
   1162       Name("TensorArrayScatter")                                \
   1163           .Device(DEVICE_GPU)                                   \
   1164           .TypeConstraint<type>("T")                            \
   1165           .HostMemory("indices")                                \
   1166           .HostMemory("handle"),                                \
   1167       TensorArrayUnpackOrScatterOp<GPUDevice, type,             \
   1168                                    false /* LEGACY_UNPACK */>); \
   1169   REGISTER_KERNEL_BUILDER(                                      \
   1170       Name("TensorArrayScatterV2")                              \
   1171           .Device(DEVICE_GPU)                                   \
   1172           .TypeConstraint<type>("T")                            \
   1173           .HostMemory("indices")                                \
   1174           .HostMemory("handle"),                                \
   1175       TensorArrayUnpackOrScatterOp<GPUDevice, type,             \
   1176                                    false /* LEGACY_UNPACK */>); \
   1177   REGISTER_KERNEL_BUILDER(                                      \
   1178       Name("TensorArrayScatterV3")                              \
   1179           .Device(DEVICE_GPU)                                   \
   1180           .TypeConstraint<type>("T")                            \
   1181           .HostMemory("indices")                                \
   1182           .HostMemory("handle"),                                \
   1183       TensorArrayUnpackOrScatterOp<GPUDevice, type,             \
   1184                                    false /* LEGACY_UNPACK */>);
   1185 
   1186 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
   1187 TF_CALL_complex64(REGISTER_GPU);
   1188 TF_CALL_complex128(REGISTER_GPU);
   1189 #undef REGISTER_GPU
   1190 
   1191 #endif  // GOOGLE_CUDA
   1192 
   1193 // SPLIT *********************************************************************
   1194 
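         // Splits the `value` input along its first dimension into pieces whose
         // row counts are given by the `lengths` vector, writing piece i to index
         // i of the TensorArray.  For example, a value of shape [6, 2] with
         // lengths [3, 2, 1] yields elements of shapes [3, 2], [2, 2], and [1, 2]
         // at indices 0, 1, and 2.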
   1195 template <typename Device, typename T>
   1196 class TensorArraySplitOp : public OpKernel {
   1197  public:
   1198   explicit TensorArraySplitOp(OpKernelConstruction* context)
   1199       : OpKernel(context) {}
   1200 
   1201   void Compute(OpKernelContext* ctx) override {
   1202     OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, true));
   1203 
   1204     TensorArray* tensor_array = nullptr;
   1205     OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array));
   1206     core::ScopedUnref unref(tensor_array);
   1207     const Tensor* tensor_value;
   1208     OP_REQUIRES_OK(ctx, ctx->input("value", &tensor_value));
   1209     const Tensor* tensor_lengths;
   1210     OP_REQUIRES_OK(ctx, ctx->input("lengths", &tensor_lengths));
   1211 
   1212     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(tensor_lengths->shape()),
   1213                 errors::InvalidArgument(
   1214                     "Expected lengths to be a vector, received shape: ",
   1215                     tensor_lengths->shape().DebugString()));
   1216     OP_REQUIRES(ctx,
   1217                 FastBoundsCheck(tensor_lengths->NumElements(),
   1218                                 std::numeric_limits<int32>::max()),
   1219                 errors::InvalidArgument(
   1220                     "Expected lengths to have < max int32 entries"));
   1221 
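             // Accumulate prefix sums of the lengths: cumulative_lengths[i] is the
             // number of leading rows covered by elements 0..i, and total_length
             // must equal value.shape[0] (checked below).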
   1222     int32 num_tensors = static_cast<int32>(tensor_lengths->NumElements());
   1223     auto tensor_lengths_t = tensor_lengths->vec<int64>();
   1224     std::vector<int64> cumulative_lengths;
   1225     cumulative_lengths.reserve(num_tensors);
   1226     int64 total_length = 0;
   1227     for (int i = 0; i < num_tensors; ++i) {
   1228       total_length += tensor_lengths_t(i);
   1229       cumulative_lengths.push_back(total_length);
   1230     }
   1231 
   1232     OP_REQUIRES(
   1233         ctx, TensorShapeUtils::IsVectorOrHigher(tensor_value->shape()),
   1234         errors::InvalidArgument(
   1235             "Expected value to be at least a vector, but received shape: ",
   1236             tensor_value->shape().DebugString()));
   1237 
   1238     OP_REQUIRES(
   1239         ctx, total_length == tensor_value->shape().dim_size(0),
   1240         errors::InvalidArgument("Expected sum of lengths to be equal to "
   1241                                 "values.shape[0], but sum of lengths is ",
   1242                                 total_length, " and value's shape is: ",
   1243                                 tensor_value->shape().DebugString()));
   1244     int64 elements_per_row =
   1245         (total_length == 0) ? 0 : (tensor_value->NumElements() / total_length);
   1246 
   1247     int32 array_size;
   1248     OP_REQUIRES_OK(ctx, tensor_array->Size(&array_size));
   1249     bool dynamic_size = tensor_array->HasDynamicSize();
   1250 
   1251     std::vector<TensorShape> element_shapes(num_tensors, tensor_value->shape());
   1252     for (int32 i = 0; i < num_tensors; ++i) {
   1253       element_shapes[i].set_dim(0, tensor_lengths_t(i));
   1254     }
   1255 
   1256     // If dynamic size, we may have to resize the TensorArray to fit.
   1257     if (dynamic_size && array_size < num_tensors) {
   1258       array_size = num_tensors;
   1259     }
   1260 
   1261     OP_REQUIRES(
   1262         ctx, array_size == num_tensors,
   1263         errors::InvalidArgument(
   1264             "TensorArray's size is not equal to the size of lengths (",
   1265             array_size, " vs. ", num_tensors, "), and the TensorArray is not ",
   1266             "marked as dynamically resizeable"));
   1267 
   1268     OP_REQUIRES(
   1269         ctx, tensor_value->dtype() == tensor_array->ElemType(),
   1270         errors::InvalidArgument("TensorArray dtype is ",
   1271                                 DataTypeString(tensor_array->ElemType()),
   1272                                 " but Op is trying to write dtype ",
   1273                                 DataTypeString(tensor_value->dtype()), "."));
   1274 
   1275     auto tensor_value_t =
   1276         tensor_value->shaped<T, 3>({1, total_length, elements_per_row});
   1277 
   1278     std::vector<PersistentTensor> write_values;
   1279     write_values.reserve(array_size);
   1280 
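             // Element i occupies rows [cumulative_lengths[i - 1], cumulative_lengths[i])
             // of the value (with the lower bound taken as 0 for i == 0); slice that
             // block into its own persistent tensor using the Split functor.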
   1281     for (int i = 0; i < array_size; ++i) {
   1282       Tensor* tensor_value_i;
   1283       PersistentTensor persistent_tensor;
   1284 
   1285       int64 previous_length = (i == 0) ? 0 : cumulative_lengths[i - 1];
   1286       Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, previous_length, 0};
   1287       Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, tensor_lengths_t(i),
   1288                                                 elements_per_row};
   1289 
   1290       OP_REQUIRES_OK(ctx, ctx->allocate_persistent(
   1291                               tensor_array->ElemType(), element_shapes[i],
   1292                               &persistent_tensor, &tensor_value_i));
   1293 
   1294       if (tensor_lengths_t(i) > 0) {
   1295         auto tensor_value_i_t = tensor_value_i->shaped<T, 3>(
   1296             {1, tensor_lengths_t(i), elements_per_row});
   1297 
   1298         functor::Split<Device, T>()(ctx->eigen_device<Device>(),
   1299                                     tensor_value_i_t, tensor_value_t, indices,
   1300                                     sizes);
   1301       }
   1302 
   1303       write_values.push_back(persistent_tensor);
   1304     }
   1305 
   1306     // Record the concat size of the TensorArray.
   1307     OP_REQUIRES_OK(ctx, tensor_array->SetMarkedSize(array_size));
   1308 
   1309     std::vector<int32> indices(array_size);
   1310     std::iota(indices.begin(), indices.end(), 0);
   1311 
   1312     Status s = tensor_array->WriteOrAggregateMany<Device, T>(ctx, indices,
   1313                                                              &write_values);
   1314     OP_REQUIRES_OK(ctx, s);
   1315   }
   1316 };
   1317 
   1318 #define REGISTER_SPLIT(type)                                                   \
   1319   REGISTER_KERNEL_BUILDER(                                                     \
   1320       Name("TensorArraySplit").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
   1321       TensorArraySplitOp<CPUDevice, type>);                                    \
   1322   REGISTER_KERNEL_BUILDER(                                                     \
   1323       Name("TensorArraySplitV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
   1324       TensorArraySplitOp<CPUDevice, type>);                                    \
   1325   REGISTER_KERNEL_BUILDER(                                                     \
   1326       Name("TensorArraySplitV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
   1327       TensorArraySplitOp<CPUDevice, type>);
   1328 
   1329 TF_CALL_ALL_TYPES(REGISTER_SPLIT);
   1330 #undef REGISTER_SPLIT
   1331 
   1332 #if GOOGLE_CUDA
   1333 
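         // As with scatter above, the lengths and handle inputs are read on the
         // host, so they are pinned to host memory.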
   1334 #define REGISTER_GPU(type)                                      \
   1335   REGISTER_KERNEL_BUILDER(Name("TensorArraySplit")              \
   1336                               .Device(DEVICE_GPU)               \
   1337                               .TypeConstraint<type>("T")        \
   1338                               .HostMemory("lengths")            \
   1339                               .HostMemory("handle"),            \
   1340                           TensorArraySplitOp<GPUDevice, type>); \
   1341   REGISTER_KERNEL_BUILDER(Name("TensorArraySplitV2")            \
   1342                               .Device(DEVICE_GPU)               \
   1343                               .TypeConstraint<type>("T")        \
   1344                               .HostMemory("lengths")            \
   1345                               .HostMemory("handle"),            \
   1346                           TensorArraySplitOp<GPUDevice, type>); \
   1347   REGISTER_KERNEL_BUILDER(Name("TensorArraySplitV3")            \
   1348                               .Device(DEVICE_GPU)               \
   1349                               .TypeConstraint<type>("T")        \
   1350                               .HostMemory("lengths")            \
   1351                               .HostMemory("handle"),            \
   1352                           TensorArraySplitOp<GPUDevice, type>);
   1353 
   1354 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
   1355 TF_CALL_complex64(REGISTER_GPU);
   1356 TF_CALL_complex128(REGISTER_GPU);
   1357 #undef REGISTER_GPU
   1358 
   1359 #endif  // GOOGLE_CUDA
   1360 
   1361 // SIZE ***********************************************************************
   1362 
    1363 // Get the current size of the TensorArray and emit it as a scalar int32.
   1364 class TensorArraySizeOp : public OpKernel {
   1365  public:
   1366   explicit TensorArraySizeOp(OpKernelConstruction* context)
   1367       : OpKernel(context) {}
   1368 
   1369   void Compute(OpKernelContext* ctx) override {
   1370     TensorArray* tensor_array;
   1371     OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array));
   1372     core::ScopedUnref unref(tensor_array);
   1373     Tensor* output = nullptr;
   1374     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
   1375     OP_REQUIRES_OK(ctx, tensor_array->Size(&(output->scalar<int32>()())));
   1376   }
   1377 };
   1378 
   1379 REGISTER_KERNEL_BUILDER(Name("TensorArraySize").Device(DEVICE_CPU),
   1380                         TensorArraySizeOp);
   1381 REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV2").Device(DEVICE_CPU),
   1382                         TensorArraySizeOp);
   1383 REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV3").Device(DEVICE_CPU),
   1384                         TensorArraySizeOp);
   1385 
   1386 REGISTER_KERNEL_BUILDER(Name("TensorArraySize")
   1387                             .Device(DEVICE_GPU)
   1388                             .HostMemory("handle")
   1389                             .HostMemory("size"),
   1390                         TensorArraySizeOp);
   1391 REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV2")
   1392                             .Device(DEVICE_GPU)
   1393                             .HostMemory("handle")
   1394                             .HostMemory("size"),
   1395                         TensorArraySizeOp);
   1396 REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV3")
   1397                             .Device(DEVICE_GPU)
   1398                             .HostMemory("handle")
   1399                             .HostMemory("size"),
   1400                         TensorArraySizeOp);
   1401 
    1402 // CLOSE **********************************************************************
   1404 
   1405 // Delete the TensorArray from its resource container.  This enables
   1406 // the user to close and release the resource in the middle of a step/run.
   1407 // TODO(ebrevdo): decide whether closing the grad op should happen
   1408 // here or on the python side.
   1409 class TensorArrayCloseOp : public OpKernel {
   1410  public:
   1411   explicit TensorArrayCloseOp(OpKernelConstruction* context)
   1412       : OpKernel(context) {}
   1413 
   1414   void Compute(OpKernelContext* ctx) override {
   1415     TensorArray* tensor_array;
   1416     OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array));
   1417     core::ScopedUnref unref(tensor_array);
   1418     // Instead of deleting this TA from the ResourceManager, we just
    1419     // clear it away and mark it as closed.  The only remaining memory
    1420     // consumed is for its mutex and handle Tensor.  This will be
   1421     // cleared out at the end of the step anyway, so it's fine to keep
   1422     // it around until the end of the step.  Further calls to the
   1423     // TensorArray will fail because TensorArray checks internally to
   1424     // see if it is closed or not.
   1425     tensor_array->ClearAndMarkClosed();
   1426   }
   1427 };
   1428 
   1429 REGISTER_KERNEL_BUILDER(Name("TensorArrayClose").Device(DEVICE_CPU),
   1430                         TensorArrayCloseOp);
   1431 REGISTER_KERNEL_BUILDER(Name("TensorArrayCloseV2").Device(DEVICE_CPU),
   1432                         TensorArrayCloseOp);
   1433 REGISTER_KERNEL_BUILDER(Name("TensorArrayCloseV3").Device(DEVICE_CPU),
   1434                         TensorArrayCloseOp);
   1435 
   1436 REGISTER_KERNEL_BUILDER(
   1437     Name("TensorArrayClose").Device(DEVICE_GPU).HostMemory("handle"),
   1438     TensorArrayCloseOp);
   1439 REGISTER_KERNEL_BUILDER(
   1440     Name("TensorArrayCloseV2").Device(DEVICE_GPU).HostMemory("handle"),
   1441     TensorArrayCloseOp);
   1442 REGISTER_KERNEL_BUILDER(
   1443     Name("TensorArrayCloseV3").Device(DEVICE_GPU).HostMemory("handle"),
   1444     TensorArrayCloseOp);
   1445 
   1446 }  // namespace tensorflow
   1447