/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA TensorArray operators.
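//
// An XLA TensorArray of size N with element shape E is represented as a
// single dense value of shape [N] + E: reads lower to DynamicSlice, writes
// to DynamicUpdateSlice, and writes accumulate (add) into the existing
// contents rather than overwrite them.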

#include <limits>
#include <vector>

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {

// Since the element shape is not always provided to the TensorArrayV3
// operator, we must support lazy initialization of the TensorArray at the
// time of the first write.
// If the TensorArray `resource` has not been initialized, constructs storage
// for it with elements of `elem_shape`. For both initialized and
// uninitialized TensorArrays, checks that the array's type matches `dtype`
// and that its element shape is compatible with `elem_shape`.
Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
                                  XlaResource* resource, DataType dtype,
                                  const TensorShape& elem_shape) {
  if (resource->kind() != XlaResource::kTensorArray) {
    return errors::InvalidArgument("Unexpected non-TensorArray resource");
  }

  if (resource->type() != dtype) {
    return errors::InvalidArgument(
        "TensorArray dtype is ", DataTypeString(resource->type()),
        " but op has dtype ", DataTypeString(dtype), ".");
  }

  TF_RET_CHECK(resource->tensor_array_size() >= 0)
      << resource->name() << " size " << resource->tensor_array_size();

  if (!resource->initialized()) {
    TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
    TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
  } else {
    // Checks that elem_shape matches the TensorArray's element shape.
    auto shape_or_status = builder->GetShape(resource->value());
    if (!shape_or_status.ok()) {
      return shape_or_status.status();
    }
    TensorShape shape;
    TF_RETURN_IF_ERROR(
        XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));

    TensorShape ta_shape;
    ta_shape.AddDim(resource->tensor_array_size());
    ta_shape.AppendShape(elem_shape);
    if (ta_shape != shape) {
      return errors::InvalidArgument(
          "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
          shape.DebugString());
    }
  }
  return Status::OK();
}
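
// For illustration (shape values assumed): a TensorArray of size 4 holding
// float elements of shape [2, 3] is materialized as a zero-filled XLA value
// of shape [4, 2, 3]; a later access with any other element shape fails the
// ta_shape != shape check above.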

// Checks that the TensorArray 'resource' has been initialized and has type
// 'dtype'.
Status CheckTensorArrayIsInitialized(const string& op_name,
                                     const XlaResource* resource,
                                     DataType dtype) {
  if (resource->kind() != XlaResource::kTensorArray) {
    return errors::InvalidArgument(
        "Unexpected non-TensorArray resource passed to ", op_name);
  }
  if (!resource->initialized()) {
    return errors::InvalidArgument("Uninitialized TensorArray passed to ",
                                   op_name);
  }
  if (resource->type() != dtype) {
    return errors::InvalidArgument(
        "TensorArray dtype is ", DataTypeString(resource->type()),
        " but op has dtype ", DataTypeString(dtype), ".");
  }

  return Status::OK();
}

Status GetTensorArrayShape(const XlaResource* resource,
                           xla::ComputationBuilder* builder,
                           TensorShape* shape) {
  *shape = resource->shape();
  shape->InsertDim(0, resource->tensor_array_size());
  return Status::OK();
}

// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the
// relevant slice of 'operand'.
xla::ComputationDataHandle DynamicAddSlice(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand,
    const xla::ComputationDataHandle& update,
    const gtl::ArraySlice<int64>& update_dims,
    const xla::ComputationDataHandle& start_indices) {
  xla::ComputationDataHandle current =
      builder->DynamicSlice(operand, start_indices, update_dims);
  xla::ComputationDataHandle sum = builder->Add(current, update);
  return builder->DynamicUpdateSlice(operand, sum, start_indices);
}
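//
// For example (illustrative, scalar elements assumed): with 'operand' of
// shape [5], 'update' of shape [1] holding v, and 'start_indices' = [3],
//   current = DynamicSlice(operand, [3], {1})   // the existing element
//   result  = DynamicUpdateSlice(operand, current + v, [3])
// so element 3 becomes operand[3] + v. Accumulating rather than overwriting
// means repeated writes to one index sum, matching TensorArray gradient
// accumulation semantics.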

class TensorArrayOp : public XlaOpKernel {
 public:
  explicit TensorArrayOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_shape", &element_shape_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    bool dynamic_size;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dynamic_size", &dynamic_size));
    OP_REQUIRES(
        ctx, !dynamic_size,
        errors::Unimplemented(
            "TensorArrays with dynamic size are not supported by XLA."));

    OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_array_name", &tensor_array_name_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    int64 size;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size));
    OP_REQUIRES(ctx, size >= 0,
                errors::InvalidArgument("TensorArray size must be >= 0"));

    xla::ComputationBuilder* b = ctx->builder();

    // Initializes the TensorArray value if the element shape is known;
    // otherwise, initialization is deferred to the first write.
    xla::ComputationDataHandle value;
    TensorShape shape;
    if (element_shape_.IsFullyDefined()) {
      CHECK(element_shape_.AsTensorShape(&shape));
      TensorShape ta_shape;
      ta_shape.AddDim(size);
      ta_shape.AppendShape(shape);
      xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_);
      value = b->Broadcast(zero, ta_shape.dim_sizes());
    }

    XlaContext& xc = XlaContext::Get(ctx);
    XlaResource* var;
    string name = strings::StrCat("TensorArray: ", tensor_array_name_);
    OP_REQUIRES_OK(
        ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
                               dtype_, shape, value, /*tensor_array_size=*/size,
                               /*tensor_array_gradients=*/{}, &var));
    ctx->SetResourceOutput(0, var);

    Tensor flow(DT_FLOAT, TensorShape({}));
    flow.scalar<float>()() = 0.0f;
    ctx->SetConstantOutput(1, flow);
  }

 private:
  PartialTensorShape element_shape_;
  DataType dtype_;
  string tensor_array_name_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp);
};

REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"),
                TensorArrayOp);
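// Note: "size" must be a compile-time constant because the backing value's
// shape, [size] + element_shape, has to be static under XLA; this is also
// why dynamic_size TensorArrays are rejected above.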

class TensorArrayWriteOp : public XlaOpKernel {
 public:
  explicit TensorArrayWriteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();

    TensorShape elem_shape = ctx->InputShape(2);

    // Initializes the TensorArray, if the element shape was not known at
    // construction time.
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
    OP_REQUIRES_OK(ctx,
                   MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));

    xla::ComputationDataHandle ta = resource->value();
    xla::ComputationDataHandle index = ctx->Input(1);
    xla::ComputationDataHandle value = ctx->Input(2);
    xla::ComputationDataHandle flow = ctx->Input(3);

    // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
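    // The scalar index is reshaped to a length-1 vector and then padded with
    // elem_shape.dims() trailing zeros; e.g. (values assumed) a rank-2
    // element and index 3 yield start_indices [3, 0, 0].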
    auto start_indices =
        b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
               xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));

    TensorShape slice_shape = elem_shape;
    slice_shape.InsertDim(0, 1LL);
    auto update = b->Reshape(value, slice_shape.dim_sizes());

    xla::ComputationDataHandle written =
        DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);

    OP_REQUIRES_OK(ctx, resource->SetValue(written));
    ctx->SetOutput(0, flow);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayWriteOp);
};

REGISTER_XLA_OP(Name("TensorArrayWriteV3"), TensorArrayWriteOp);

class TensorArrayReadOp : public XlaOpKernel {
 public:
  explicit TensorArrayReadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    OP_REQUIRES_OK(ctx,
                   CheckTensorArrayIsInitialized(name(), resource, dtype_));
    TensorShape ta_shape;
    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));

    xla::ComputationDataHandle ta = resource->value();
    xla::ComputationDataHandle index = ctx->Input(1);

    // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
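    // ta_shape includes the leading size dimension, so ta_shape.dims() - 1
    // is the element rank (the elem_shape.dims() pad count used in
    // TensorArrayWriteOp).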
    auto start_indices =
        b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
               xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}}));

    auto slice_shape = ta_shape.dim_sizes();
    slice_shape[0] = 1LL;

    xla::ComputationDataHandle read =
        b->DynamicSlice(ta, start_indices, slice_shape);

    // Remove the leading '1' dimension.
    std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
    ctx->SetOutput(0, b->Reshape(read, value_shape));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayReadOp);
};

REGISTER_XLA_OP(Name("TensorArrayReadV3"), TensorArrayReadOp);

class TensorArrayGatherOp : public XlaOpKernel {
 public:
  explicit TensorArrayGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    OP_REQUIRES_OK(ctx,
                   CheckTensorArrayIsInitialized(name(), resource, dtype_));
    TensorShape ta_shape;
    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));

    const TensorShape indices_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, indices_shape.dims() == 1,
                errors::InvalidArgument("indices must be rank 1"));
    auto indices = ctx->Input(1);
    DataType index_type = ctx->input_type(1);

    xla::ComputationDataHandle ta = resource->value();

    // Look for the case where the gather takes a simple slice from the
    // tensor array: indices (0, 1, 2, ..., N).
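    // e.g. (illustrative) constant indices [0, 1, 2] against a TensorArray
    // of size >= 3 lower to one contiguous Slice instead of a general
    // gather.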
    std::vector<int64> const_indices;
    Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
    if (status.ok()) {
      bool gather_is_dense_slice = true;
      for (int64 i = 0; i < static_cast<int64>(const_indices.size()); i++) {
        if (const_indices[i] != i) {
          gather_is_dense_slice = false;
          break;
        }
      }

      if (gather_is_dense_slice) {
        std::vector<int64> begin(ta_shape.dims(), 0);
        std::vector<int64> strides(ta_shape.dims(), 1);
        std::vector<int64> end(ta_shape.dims(), 1);
        end[0] = const_indices.size();
        for (int i = 1; i < ta_shape.dims(); i++) {
          end[i] = ta_shape.dim_size(i);
        }
        ctx->SetOutput(0, b->Slice(ta, begin, end, strides));
        return;
      }
    }

    xla::ComputationDataHandle gather;
    OP_REQUIRES_OK(
        ctx,
        XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0,
                  /*indices_are_nd=*/false, dtype_, index_type, b, &gather));
    ctx->SetOutput(0, gather);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGatherOp);
};

REGISTER_XLA_OP(Name("TensorArrayGatherV3"), TensorArrayGatherOp);

class TensorArrayScatterOp : public XlaOpKernel {
 public:
  explicit TensorArrayScatterOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();

    const TensorShape value_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, value_shape.dims() >= 1,
                errors::InvalidArgument("value must have rank >= 1, got ",
                                        value_shape.DebugString()));

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
    TensorShape elem_shape = value_shape;
    elem_shape.RemoveDim(0);
    OP_REQUIRES_OK(ctx,
                   MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));

    const TensorShape indices_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, indices_shape.dims() == 1,
                errors::InvalidArgument("indices must be rank 1"));
    const int num_indices = indices_shape.dim_size(0);
    const xla::ComputationDataHandle indices = ctx->Input(1);

    xla::ComputationDataHandle ta = resource->value();
    const xla::ComputationDataHandle value = ctx->Input(2);
    const xla::ComputationDataHandle flow = ctx->Input(3);

    // Look for the case where the scatter is for each sub-tensor in order.
    // The tensor array implementation allows for this to be a straight
    // addition.
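    // e.g. (illustrative) indices [0, 1, ..., num_indices - 1] covering the
    // whole array reduce to a single elementwise Add of 'value' onto the
    // backing tensor, instead of one DynamicUpdateSlice per row.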
    bool scatter_all_elements_in_order = false;
    std::vector<int64> const_indices;
    Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
    if (status.ok() && num_indices == value_shape.dim_size(0)) {
      scatter_all_elements_in_order = true;
      for (int i = 0; i < num_indices; i++) {
        if (const_indices[i] != i) {
          scatter_all_elements_in_order = false;
          break;
        }
      }
    }

    if (scatter_all_elements_in_order) {
      ta = b->Add(ta, value);
    } else {
      auto slice_dims = value_shape.dim_sizes();
      slice_dims[0] = 1LL;

      std::vector<int64> value_starts(value_shape.dims(), 0);
      auto value_ends = value_shape.dim_sizes();

      std::vector<int64> value_strides(value_shape.dims(), 1);

      // For every (index, value) pair, update the corresponding TensorArray
      // storage.
      for (int i = 0; i < num_indices; ++i) {
        // Slice out part of the value.
        value_starts[0] = i;
        value_ends[0] = i + 1;
        auto slice = b->Slice(value, value_starts, value_ends, value_strides);

        // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
        auto index = b->Slice(indices, {i}, {i + 1}, {1});
        auto start_indices =
            b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
                   xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
        ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
      }
    }

    OP_REQUIRES_OK(ctx, resource->SetValue(ta));
    ctx->SetOutput(0, flow);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayScatterOp);
};

REGISTER_XLA_OP(Name("TensorArrayScatterV3"), TensorArrayScatterOp);

class TensorArrayConcatOp : public XlaOpKernel {
 public:
  explicit TensorArrayConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    OP_REQUIRES_OK(ctx,
                   CheckTensorArrayIsInitialized(name(), resource, dtype_));
    TensorShape ta_shape;
    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));

    xla::ComputationDataHandle ta = resource->value();

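    // Concatenation along the elements' leading dimension reinterprets the
    // backing value of shape [size, d0, d1, ...] as [size * d0, d1, ...];
    // since this just merges the two leading dimensions, a Reshape suffices
    // and no explicit Concat is needed.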
    auto ta_dims = ta_shape.dim_sizes();
    std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
    shape[0] *= ta_shape.dim_size(0);
    ctx->SetOutput(0, b->Reshape(ta, shape));

    Tensor lengths(DT_INT64, {ta_dims[0]});
    auto lengths_vec = lengths.vec<int64>();
    for (int64 i = 0; i < ta_dims[0]; ++i) {
      lengths_vec(i) = ta_dims[1];
    }
    ctx->SetConstantOutput(1, lengths);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayConcatOp);
};

REGISTER_XLA_OP(Name("TensorArrayConcatV3"), TensorArrayConcatOp);

class TensorArraySplitOp : public XlaOpKernel {
 public:
  explicit TensorArraySplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    std::vector<int64> lengths;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths));

    int64 length = 0;
    if (!lengths.empty()) {
      length = lengths[0];
      for (size_t i = 1; i < lengths.size(); ++i) {
        OP_REQUIRES(ctx, lengths[i] == length,
                    errors::InvalidArgument("lengths must be equal: ", length,
                                            " vs. ", lengths[i]));
      }
    }

    TensorShape value_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, value_shape.dims() >= 1,
                errors::InvalidArgument("value must have rank >= 1, got ",
                                        value_shape.DebugString()));
    TensorShape elem_shape = value_shape;
    elem_shape.set_dim(0, length);

    xla::ComputationBuilder* b = ctx->builder();
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
    OP_REQUIRES_OK(ctx,
                   MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
    xla::ComputationDataHandle ta = resource->value();

    TensorShape ta_shape;
    ta_shape.AddDim(resource->tensor_array_size());
    ta_shape.AppendShape(elem_shape);

    OP_REQUIRES(
        ctx,
        static_cast<int64>(lengths.size()) == resource->tensor_array_size(),
        errors::InvalidArgument(
            "TensorArray's size is not equal to the size of lengths (",
            lengths.size(), " vs. ", resource->tensor_array_size(), ")"));

    const xla::ComputationDataHandle value = ctx->Input(1);
    const xla::ComputationDataHandle flow = ctx->Input(3);

    OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(),
                errors::InvalidArgument("mismatched element count ",
                                        value_shape.DebugString(), " vs. ",
                                        ta_shape.DebugString()));

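    // Since all splits have equal length, the value can be reshaped to
    // [size] + elem_shape and added into the existing contents in one step;
    // for a freshly zero-initialized array this behaves as a plain write,
    // and otherwise it accumulates like the other write paths in this file.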
    OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
                            ta, b->Reshape(value, ta_shape.dim_sizes()))));

    ctx->SetOutput(0, flow);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp);
};

REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"),
                TensorArraySplitOp);

class TensorArraySizeOp : public XlaOpKernel {
 public:
  explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    XlaResource* var;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var));
    Tensor size_tensor(DT_INT32, {});
    size_tensor.scalar<int32>()() =
        static_cast<int32>(var->tensor_array_size());
    ctx->SetConstantOutput(0, size_tensor);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySizeOp);
};

REGISTER_XLA_OP(Name("TensorArraySizeV3"), TensorArraySizeOp);

class TensorArrayGradOp : public XlaOpKernel {
 public:
  explicit TensorArrayGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("source", &source_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    OP_REQUIRES_OK(
        ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type()));
    TensorShape ta_shape;
    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));

    // Looks up or creates the corresponding gradient TensorArray, which
    // stores gradients computed during backpropagation.
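    // Gradient arrays are keyed by 'source_', so different gradient sources
    // accumulate into distinct arrays for the same forward TensorArray.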
    XlaResource* gradient;
    OP_REQUIRES_OK(
        ctx, resource->GetOrCreateTensorArrayGradient(source_, b, &gradient));

    ctx->SetResourceOutput(0, gradient);
    ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
  }

 private:
  string source_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp);
};

REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp);

class TensorArrayCloseOp : public XlaOpKernel {
 public:
  explicit TensorArrayCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Do nothing; XLA handles resource management.
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayCloseOp);
};

REGISTER_XLA_OP(Name("TensorArrayCloseV3"), TensorArrayCloseOp);

}  // anonymous namespace
}  // namespace tensorflow