1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA TensorArray operators. 17 18 #include <limits> 19 #include <vector> 20 21 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" 22 #include "tensorflow/compiler/tf2xla/shape_util.h" 23 #include "tensorflow/compiler/tf2xla/type_util.h" 24 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 25 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 26 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 27 #include "tensorflow/compiler/tf2xla/xla_resource.h" 28 #include "tensorflow/compiler/xla/literal_util.h" 29 #include "tensorflow/core/framework/op_kernel.h" 30 #include "tensorflow/core/framework/partial_tensor_shape.h" 31 #include "tensorflow/core/framework/register_types.h" 32 #include "tensorflow/core/framework/tensor.h" 33 #include "tensorflow/core/framework/tensor_types.h" 34 #include "tensorflow/core/framework/types.h" 35 #include "tensorflow/core/kernels/bounds_check.h" 36 #include "tensorflow/core/kernels/concat_lib.h" 37 #include "tensorflow/core/lib/core/status.h" 38 #include "tensorflow/core/platform/types.h" 39 40 namespace tensorflow { 41 namespace { 42 43 // Since the element shape is not always provided to the TensorArrayV3 operator, 44 // we must support lazily initialization of the TensorArray at the time of the 45 // first write. 46 // If a TensorArray `resource` has not been initialized, constructs storage for 47 // the TensorArray with elements of `elem_shape`. For both initialized and 48 // uninitialized TensorArrays, checks that the tensor has a type compatible with 49 // 'dtype' and shape compatible with 'elem_shape'. 50 Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, 51 XlaResource* resource, DataType dtype, 52 const TensorShape& elem_shape) { 53 if (resource->kind() != XlaResource::kTensorArray) { 54 return errors::InvalidArgument("Unexpected non-TensorArray resource"); 55 } 56 57 if (resource->type() != dtype) { 58 return errors::InvalidArgument( 59 "TensorArray dtype is ", DataTypeString(resource->type()), 60 " but op has dtype ", DataTypeString(dtype), "."); 61 } 62 63 TF_RET_CHECK(resource->tensor_array_size() >= 0) 64 << resource->name() << " size " << resource->tensor_array_size(); 65 66 if (!resource->initialized()) { 67 xla::ComputationDataHandle zero = 68 XlaHelpers::Zero(builder, resource->type()); 69 70 TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); 71 TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); 72 } else { 73 // Checks the elem_shape matches the TensorArray shape. 74 auto shape_or_status = builder->GetShape(resource->value()); 75 if (!shape_or_status.ok()) { 76 return shape_or_status.status(); 77 } 78 TensorShape shape; 79 TF_RETURN_IF_ERROR( 80 XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); 81 82 TensorShape ta_shape; 83 ta_shape.AddDim(resource->tensor_array_size()); 84 ta_shape.AppendShape(elem_shape); 85 if (ta_shape != shape) { 86 return errors::InvalidArgument( 87 "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ", 88 shape.DebugString()); 89 } 90 } 91 return Status::OK(); 92 } 93 94 // Checks that the TensorArray 'resource' has been initialized, and has type 95 // 'dtype'. Sets 'shape' to the shape 96 Status CheckTensorArrayIsInitialized(const string& op_name, 97 const XlaResource* resource, 98 DataType dtype) { 99 if (resource->kind() != XlaResource::kTensorArray) { 100 return errors::InvalidArgument( 101 "Unexpected non-TensorArray resource passed to ", op_name); 102 } 103 if (!resource->initialized()) { 104 return errors::InvalidArgument("Uninitialized TensorArray passed to ", 105 op_name); 106 } 107 if (resource->type() != dtype) { 108 return errors::InvalidArgument( 109 "TensorArray dtype is ", DataTypeString(resource->type()), 110 " but op has dtype ", DataTypeString(dtype), "."); 111 } 112 113 return Status::OK(); 114 } 115 116 Status GetTensorArrayShape(const XlaResource* resource, 117 xla::ComputationBuilder* builder, 118 TensorShape* shape) { 119 *shape = resource->shape(); 120 shape->InsertDim(0, resource->tensor_array_size()); 121 return Status::OK(); 122 } 123 124 // Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the 125 // relevant slice of 'operand'. 126 xla::ComputationDataHandle DynamicAddSlice( 127 xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand, 128 const xla::ComputationDataHandle& update, 129 const gtl::ArraySlice<int64>& update_dims, 130 const xla::ComputationDataHandle& start_indices) { 131 xla::ComputationDataHandle current = 132 builder->DynamicSlice(operand, start_indices, update_dims); 133 xla::ComputationDataHandle sum = builder->Add(current, update); 134 return builder->DynamicUpdateSlice(operand, sum, start_indices); 135 } 136 137 class TensorArrayOp : public XlaOpKernel { 138 public: 139 explicit TensorArrayOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 140 OP_REQUIRES_OK(ctx, ctx->GetAttr("element_shape", &element_shape_)); 141 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); 142 bool dynamic_size; 143 OP_REQUIRES_OK(ctx, ctx->GetAttr("dynamic_size", &dynamic_size)); 144 OP_REQUIRES( 145 ctx, !dynamic_size, 146 errors::Unimplemented( 147 "TensorArrays with dynamic size are not supported by XLA.")); 148 149 OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_array_name", &tensor_array_name_)); 150 } 151 152 void Compile(XlaOpKernelContext* ctx) override { 153 int64 size; 154 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size)); 155 OP_REQUIRES(ctx, size >= 0, 156 errors::InvalidArgument("TensorArray size must be >= 0")); 157 158 xla::ComputationBuilder* b = ctx->builder(); 159 160 // Initializes the TensorArray value if we know the element shape. 161 // Otherwise, defer initialization to the first write. 162 xla::ComputationDataHandle value; 163 TensorShape shape; 164 if (element_shape_.IsFullyDefined()) { 165 CHECK(element_shape_.AsTensorShape(&shape)); 166 TensorShape ta_shape; 167 ta_shape.AddDim(size); 168 ta_shape.AppendShape(shape); 169 xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_); 170 value = b->Broadcast(zero, ta_shape.dim_sizes()); 171 } 172 173 XlaContext& xc = XlaContext::Get(ctx); 174 XlaResource* var; 175 string name = strings::StrCat("TensorArray: ", tensor_array_name_); 176 OP_REQUIRES_OK( 177 ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), 178 dtype_, shape, value, /*tensor_array_size=*/size, 179 /*tensor_array_gradients=*/{}, &var)); 180 ctx->SetResourceOutput(0, var); 181 182 Tensor flow(DT_FLOAT, TensorShape({})); 183 flow.scalar<float>()() = 0.0f; 184 ctx->SetConstantOutput(1, flow); 185 } 186 187 private: 188 PartialTensorShape element_shape_; 189 DataType dtype_; 190 string tensor_array_name_; 191 192 TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp); 193 }; 194 195 REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"), 196 TensorArrayOp); 197 198 class TensorArrayWriteOp : public XlaOpKernel { 199 public: 200 explicit TensorArrayWriteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 201 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); 202 } 203 204 void Compile(XlaOpKernelContext* ctx) override { 205 xla::ComputationBuilder* b = ctx->builder(); 206 207 TensorShape elem_shape = ctx->InputShape(2); 208 209 // Initializes the TensorArray, if the element shape was not known at 210 // construction time. 211 XlaResource* resource; 212 OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); 213 OP_REQUIRES_OK(ctx, 214 MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); 215 216 xla::ComputationDataHandle ta = resource->value(); 217 xla::ComputationDataHandle index = ctx->Input(1); 218 xla::ComputationDataHandle value = ctx->Input(2); 219 xla::ComputationDataHandle flow = ctx->Input(3); 220 221 // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 222 auto start_indices = 223 b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), 224 xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); 225 226 TensorShape slice_shape = elem_shape; 227 slice_shape.InsertDim(0, 1LL); 228 auto update = b->Reshape(value, slice_shape.dim_sizes()); 229 230 xla::ComputationDataHandle written = 231 DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); 232 233 OP_REQUIRES_OK(ctx, resource->SetValue(written)); 234 ctx->SetOutput(0, flow); 235 } 236 237 private: 238 DataType dtype_; 239 240 TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayWriteOp); 241 }; 242 243 REGISTER_XLA_OP(Name("TensorArrayWriteV3"), TensorArrayWriteOp); 244 245 class TensorArrayReadOp : public XlaOpKernel { 246 public: 247 explicit TensorArrayReadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 248 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); 249 } 250 251 void Compile(XlaOpKernelContext* ctx) override { 252 xla::ComputationBuilder* b = ctx->builder(); 253 254 XlaResource* resource; 255 OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); 256 257 OP_REQUIRES_OK(ctx, 258 CheckTensorArrayIsInitialized(name(), resource, dtype_)); 259 TensorShape ta_shape; 260 OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); 261 262 xla::ComputationDataHandle ta = resource->value(); 263 xla::ComputationDataHandle index = ctx->Input(1); 264 265 // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 266 auto start_indices = 267 b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), 268 xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); 269 270 auto slice_shape = ta_shape.dim_sizes(); 271 slice_shape[0] = 1LL; 272 273 xla::ComputationDataHandle read = 274 b->DynamicSlice(ta, start_indices, slice_shape); 275 276 // Remove the leading '1' dimension. 277 std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end()); 278 ctx->SetOutput(0, b->Reshape(read, value_shape)); 279 } 280 281 private: 282 DataType dtype_; 283 284 TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayReadOp); 285 }; 286 287 REGISTER_XLA_OP(Name("TensorArrayReadV3"), TensorArrayReadOp); 288 289 class TensorArrayGatherOp : public XlaOpKernel { 290 public: 291 explicit TensorArrayGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 292 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); 293 } 294 295 void Compile(XlaOpKernelContext* ctx) override { 296 xla::ComputationBuilder* b = ctx->builder(); 297 298 XlaResource* resource; 299 OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); 300 301 OP_REQUIRES_OK(ctx, 302 CheckTensorArrayIsInitialized(name(), resource, dtype_)); 303 TensorShape ta_shape; 304 OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); 305 306 const TensorShape indices_shape = ctx->InputShape(1); 307 OP_REQUIRES(ctx, indices_shape.dims() == 1, 308 errors::InvalidArgument("indices must be rank 1")); 309 auto indices = ctx->Input(1); 310 DataType index_type = ctx->input_type(1); 311 312 xla::ComputationDataHandle ta = resource->value(); 313 314 // Look for the case where the gather takes a simple slice from the 315 // tensor array (0, 1, 2, 3, 4, ..., N) 316 std::vector<int64> const_indices; 317 Status status = ctx->ConstantInputAsIntVector(1, &const_indices); 318 if (status.ok()) { 319 bool gather_is_dense_slice = true; 320 for (auto i = 0; i < const_indices.size(); i++) { 321 if (const_indices[i] != i) { 322 gather_is_dense_slice = false; 323 break; 324 } 325 } 326 327 if (gather_is_dense_slice) { 328 std::vector<int64> begin(ta_shape.dims(), 0); 329 std::vector<int64> strides(ta_shape.dims(), 1); 330 std::vector<int64> end(ta_shape.dims(), 1); 331 end[0] = const_indices.size(); 332 for (auto i = 1; i < ta_shape.dims(); i++) { 333 end[i] = ta_shape.dim_size(i); 334 } 335 ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); 336 return; 337 } 338 } 339 340 xla::ComputationDataHandle gather; 341 OP_REQUIRES_OK( 342 ctx, 343 XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0, 344 /*indices_are_nd=*/false, dtype_, index_type, b, &gather)); 345 ctx->SetOutput(0, gather); 346 } 347 348 private: 349 DataType dtype_; 350 351 TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGatherOp); 352 }; 353 354 REGISTER_XLA_OP(Name("TensorArrayGatherV3"), TensorArrayGatherOp); 355 356 class TensorArrayScatterOp : public XlaOpKernel { 357 public: 358 explicit TensorArrayScatterOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 359 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); 360 } 361 362 void Compile(XlaOpKernelContext* ctx) override { 363 xla::ComputationBuilder* b = ctx->builder(); 364 365 const TensorShape value_shape = ctx->InputShape(2); 366 367 XlaResource* resource; 368 OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); 369 TensorShape elem_shape = value_shape; 370 elem_shape.RemoveDim(0); 371 OP_REQUIRES_OK(ctx, 372 MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); 373 374 const TensorShape indices_shape = ctx->InputShape(1); 375 OP_REQUIRES(ctx, indices_shape.dims() >= 1, 376 errors::InvalidArgument("indices must be rank 1")); 377 const int num_indices = indices_shape.dim_size(0); 378 const xla::ComputationDataHandle indices = ctx->Input(1); 379 380 xla::ComputationDataHandle ta = resource->value(); 381 const xla::ComputationDataHandle value = ctx->Input(2); 382 const xla::ComputationDataHandle flow = ctx->Input(3); 383 384 // Look for the case where the scatter is for each sub-tensor in order. The 385 // tensor array implementation allows for this to be a straight addition. 386 bool scatter_all_elements_in_order = false; 387 std::vector<int64> const_indices; 388 Status status = ctx->ConstantInputAsIntVector(1, &const_indices); 389 if (status.ok() && num_indices == value_shape.dim_size(0)) { 390 scatter_all_elements_in_order = true; 391 for (auto i = 0; i < num_indices; i++) { 392 if (const_indices[i] != i) { 393 scatter_all_elements_in_order = false; 394 break; 395 } 396 } 397 } 398 399 if (scatter_all_elements_in_order) { 400 ta = b->Add(ta, value); 401 } else { 402 auto slice_dims = value_shape.dim_sizes(); 403 slice_dims[0] = 1LL; 404 405 std::vector<int64> value_starts(value_shape.dims(), 0); 406 auto value_ends = value_shape.dim_sizes(); 407 408 std::vector<int64> value_strides(value_shape.dims(), 1); 409 410 // For every (index, value) pair, update the corresponding TensorArray 411 // storage. 412 for (int i = 0; i < num_indices; ++i) { 413 // Slice out part of the value. 414 value_starts[0] = i; 415 value_ends[0] = i + 1; 416 auto slice = b->Slice(value, value_starts, value_ends, value_strides); 417 418 // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 419 auto index = b->Slice(indices, {i}, {i + 1}, {1}); 420 auto start_indices = 421 b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), 422 xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); 423 ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); 424 } 425 } 426 427 OP_REQUIRES_OK(ctx, resource->SetValue(ta)); 428 ctx->SetOutput(0, flow); 429 } 430 431 private: 432 DataType dtype_; 433 434 TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayScatterOp); 435 }; 436 437 REGISTER_XLA_OP(Name("TensorArrayScatterV3"), TensorArrayScatterOp); 438 439 class TensorArrayConcatOp : public XlaOpKernel { 440 public: 441 explicit TensorArrayConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 442 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); 443 } 444 445 void Compile(XlaOpKernelContext* ctx) override { 446 xla::ComputationBuilder* b = ctx->builder(); 447 448 XlaResource* resource; 449 OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); 450 451 OP_REQUIRES_OK(ctx, 452 CheckTensorArrayIsInitialized(name(), resource, dtype_)); 453 TensorShape ta_shape; 454 OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); 455 456 xla::ComputationDataHandle ta = resource->value(); 457 458 auto ta_dims = ta_shape.dim_sizes(); 459 std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end()); 460 shape[0] *= ta_shape.dim_size(0); 461 ctx->SetOutput(0, b->Reshape(ta, shape)); 462 463 Tensor lengths(DT_INT64, {ta_dims[0]}); 464 auto lengths_vec = lengths.vec<int64>(); 465 for (int i = 0; i < ta_dims[0]; ++i) { 466 lengths_vec(i) = ta_dims[1]; 467 } 468 ctx->SetConstantOutput(1, lengths); 469 } 470 471 private: 472 DataType dtype_; 473 474 TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayConcatOp); 475 }; 476 477 REGISTER_XLA_OP(Name("TensorArrayConcatV3"), TensorArrayConcatOp); 478 479 class TensorArraySplitOp : public XlaOpKernel { 480 public: 481 explicit TensorArraySplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 482 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); 483 } 484 485 void Compile(XlaOpKernelContext* ctx) override { 486 std::vector<int64> lengths; 487 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths)); 488 489 int64 length = 0; 490 if (!lengths.empty()) { 491 length = lengths[0]; 492 for (int i = 1; i < lengths.size(); ++i) { 493 OP_REQUIRES(ctx, lengths[i] == length, 494 errors::InvalidArgument("lengths must be equal: ", length, 495 " vs. ", lengths[i])); 496 } 497 } 498 499 TensorShape value_shape = ctx->InputShape(1); 500 OP_REQUIRES(ctx, value_shape.dims() >= 1, 501 errors::InvalidArgument("value must have rank >= 1, got ", 502 value_shape.DebugString())); 503 TensorShape elem_shape = value_shape; 504 elem_shape.set_dim(0, length); 505 506 xla::ComputationBuilder* b = ctx->builder(); 507 XlaResource* resource; 508 OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); 509 OP_REQUIRES_OK(ctx, 510 MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); 511 xla::ComputationDataHandle ta = resource->value(); 512 513 TensorShape ta_shape; 514 ta_shape.AddDim(resource->tensor_array_size()); 515 ta_shape.AppendShape(elem_shape); 516 517 OP_REQUIRES( 518 ctx, lengths.size() == resource->tensor_array_size(), 519 errors::InvalidArgument( 520 "TensorArray's size is not equal to the size of lengths (", 521 lengths.size(), " vs. ", resource->tensor_array_size(), ")")); 522 523 const xla::ComputationDataHandle value = ctx->Input(1); 524 const xla::ComputationDataHandle flow = ctx->Input(3); 525 526 OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(), 527 errors::InvalidArgument("mismatched element count ", 528 value_shape.DebugString(), " vs. ", 529 ta_shape.DebugString())); 530 531 OP_REQUIRES_OK(ctx, resource->SetValue(b->Add( 532 ta, b->Reshape(value, ta_shape.dim_sizes())))); 533 534 ctx->SetOutput(0, flow); 535 } 536 537 private: 538 DataType dtype_; 539 540 TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp); 541 }; 542 543 REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"), 544 TensorArraySplitOp); 545 546 class TensorArraySizeOp : public XlaOpKernel { 547 public: 548 explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 549 550 void Compile(XlaOpKernelContext* ctx) override { 551 XlaResource* var; 552 OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); 553 Tensor size_tensor(DT_INT32, {}); 554 size_tensor.scalar<int32>()() = 555 static_cast<int32>(var->tensor_array_size()); 556 ctx->SetConstantOutput(0, size_tensor); 557 } 558 559 private: 560 TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySizeOp); 561 }; 562 563 REGISTER_XLA_OP(Name("TensorArraySizeV3"), TensorArraySizeOp); 564 565 class TensorArrayGradOp : public XlaOpKernel { 566 public: 567 explicit TensorArrayGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 568 OP_REQUIRES_OK(ctx, ctx->GetAttr("source", &source_)); 569 } 570 571 void Compile(XlaOpKernelContext* ctx) override { 572 xla::ComputationBuilder* b = ctx->builder(); 573 574 XlaResource* resource; 575 OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); 576 577 OP_REQUIRES_OK( 578 ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type())); 579 TensorShape ta_shape; 580 OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); 581 582 // Finds or looks up the corresponding gradient TensorArray, which stores 583 // gradients computed during backpropagation. 584 XlaResource* gradient; 585 OP_REQUIRES_OK( 586 ctx, resource->GetOrCreateTensorArrayGradient(source_, b, &gradient)); 587 588 ctx->SetResourceOutput(0, gradient); 589 ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); 590 } 591 592 private: 593 string source_; 594 595 TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp); 596 }; 597 598 REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp); 599 600 class TensorArrayCloseOp : public XlaOpKernel { 601 public: 602 explicit TensorArrayCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 603 604 void Compile(XlaOpKernelContext* ctx) override { 605 // Do nothing; XLA handles resource management. 606 } 607 608 private: 609 TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayCloseOp); 610 }; 611 612 REGISTER_XLA_OP(Name("TensorArrayCloseV3"), TensorArrayCloseOp); 613 614 } // anonymous namespace 615 } // namespace tensorflow 616