1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """TensorArray: a dynamically sized array of Tensors. 16 17 @@TensorArray 18 """ 19 # Mixture of pep8 and non-pep8 names, so disable pylint bad-name 20 # pylint: disable=g-bad-name 21 from __future__ import absolute_import 22 from __future__ import division 23 from __future__ import print_function 24 25 import contextlib 26 27 from tensorflow.python.eager import context 28 from tensorflow.python.framework import constant_op 29 from tensorflow.python.framework import dtypes 30 from tensorflow.python.framework import errors_impl 31 from tensorflow.python.framework import ops 32 from tensorflow.python.framework import tensor_shape 33 from tensorflow.python.framework import tensor_util 34 from tensorflow.python.ops import array_ops 35 from tensorflow.python.ops import gen_data_flow_ops 36 from tensorflow.python.ops import math_ops 37 from tensorflow.python.util import tf_should_use 38 from tensorflow.python.util.tf_export import tf_export 39 40 41 # _GraphTensorArray accesses many of the hidden generated ops, but is in 42 # fact built to wrap these methods. 43 # pylint: disable=protected-access 44 class _GraphTensorArray(object): 45 """Graph-mode implementation of TensorArray. 
46 """ 47 48 def __init__(self, 49 dtype, 50 size=None, 51 dynamic_size=None, 52 clear_after_read=None, 53 tensor_array_name=None, 54 handle=None, 55 flow=None, 56 infer_shape=True, 57 element_shape=None, 58 colocate_with_first_write_call=True, 59 name=None): 60 """Constructs a graph mode TensorArray. 61 62 Args: 63 dtype: (required) data type of the TensorArray. 64 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 65 Required if handle is not provided. 66 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 67 can grow the TensorArray past its initial size. Default: False. 68 clear_after_read: Boolean (optional, default: True). If True, clear 69 TensorArray values after reading them. This disables read-many 70 semantics, but allows early release of memory. 71 tensor_array_name: (optional) Python string: the name of the TensorArray. 72 This is used when creating the TensorArray handle. If this value is 73 set, handle should be None. 74 handle: (optional) A `Tensor` handle to an existing TensorArray. If this 75 is set, tensor_array_name should be None. Only supported in graph mode. 76 flow: (optional) A float `Tensor` scalar coming from an existing 77 `TensorArray.flow`. Only supported in graph mode. 78 infer_shape: (optional, default: True) If True, shape inference 79 is enabled. In this case, all elements must have the same shape. 80 element_shape: (optional, default: None) A `TensorShape` object specifying 81 the shape constraints of each of the elements of the TensorArray. 82 Need not be fully defined. 83 colocate_with_first_write_call: If `True`, the TensorArray will be 84 colocated on the same device as the Tensor used on its first write 85 (write operations include `write`, `unstack`, and `split`). If `False`, 86 the TensorArray will be placed on the device determined by the 87 device context available during its initialization. 88 name: A name for the operation (optional). 
89 90 Raises: 91 ValueError: if both handle and tensor_array_name are provided. 92 TypeError: if handle is provided but is not a Tensor. 93 """ 94 if handle is not None and tensor_array_name: 95 raise ValueError( 96 "Cannot construct with both handle and tensor_array_name") 97 if handle is not None and not isinstance(handle, ops.Tensor): 98 raise TypeError("Handle must be a Tensor") 99 if handle is None and size is None: 100 raise ValueError("Size must be provided if handle is not provided") 101 if handle is not None and size is not None: 102 raise ValueError("Cannot provide both a handle and size " 103 "at the same time") 104 if handle is not None and element_shape is not None: 105 raise ValueError("Cannot provide both a handle and element_shape " 106 "at the same time") 107 if handle is not None and dynamic_size is not None: 108 raise ValueError("Cannot provide both a handle and dynamic_size " 109 "at the same time") 110 if handle is not None and clear_after_read is not None: 111 raise ValueError("Cannot provide both a handle and clear_after_read " 112 "at the same time") 113 114 if clear_after_read is None: 115 clear_after_read = True 116 dynamic_size = dynamic_size or False 117 118 self._dtype = dtype 119 120 # Used to keep track of what tensors the TensorArray should be 121 # colocated with. We choose to colocate the TensorArray with the 122 # first tensor written to it. 123 self._colocate_with_first_write_call = colocate_with_first_write_call 124 if colocate_with_first_write_call: 125 self._colocate_with = [] 126 else: 127 self._colocate_with = None 128 129 # Record the current static shape for the array elements. The element 130 # shape is defined either by `element_shape` or the shape of the tensor 131 # of the first write. If `infer_shape` is true, all writes checks for 132 # shape equality. 
133 if element_shape is None: 134 self._infer_shape = infer_shape 135 self._element_shape = [] 136 else: 137 self._infer_shape = True 138 self._element_shape = [tensor_shape.TensorShape(element_shape)] 139 with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope: 140 if handle is not None: 141 self._handle = handle 142 if flow is None: 143 raise ValueError("flow must not be None if handle is not None.") 144 self._flow = flow 145 else: 146 # Construct the TensorArray with an empty device. The first 147 # write into the TensorArray from a Tensor with a set device 148 # will retroactively set the device value of this op. 149 def create(): 150 """Create the TensorArray op.""" 151 return gen_data_flow_ops._tensor_array_v3( 152 dtype=dtype, 153 size=size, 154 element_shape=element_shape, 155 identical_element_shapes=infer_shape, 156 dynamic_size=dynamic_size, 157 clear_after_read=clear_after_read, 158 tensor_array_name=tensor_array_name, 159 name=scope) 160 if colocate_with_first_write_call: 161 with ops.device(None), ops.colocate_with(None, ignore_existing=True): 162 self._handle, self._flow = create() 163 else: 164 self._handle, self._flow = create() 165 166 @property 167 def flow(self): 168 return self._flow 169 170 @property 171 def dtype(self): 172 return self._dtype 173 174 @property 175 def handle(self): 176 return self._handle 177 178 def _merge_element_shape(self, shape): 179 """Changes the element shape of the array given a shape to merge with. 180 181 Args: 182 shape: A `TensorShape` object to merge with. 183 184 Raises: 185 ValueError: if the provided shape is incompatible with the current 186 element shape of the `TensorArray`. 
187 """ 188 189 if self._element_shape: 190 if not shape.is_compatible_with(self._element_shape[0]): 191 raise ValueError( 192 "Inconsistent shapes: saw %s but expected %s " 193 "(and infer_shape=True)" % (shape, self._element_shape[0])) 194 self._element_shape[0] = self._element_shape[0].merge_with(shape) 195 else: 196 self._element_shape.append(shape) 197 198 @contextlib.contextmanager 199 def _maybe_colocate_with(self, value): 200 """Colocate operations with an internal colocation group or `value`. 201 202 Args: 203 value: `Tensor`, the tensor to try to colocate with. 204 205 Yields: 206 Does not yield anything, but the new context is a colocation context. 207 208 If no internal colocation group is set, colocate with `value` and set 209 the internal colocation group to be value. 210 """ 211 if not self._colocate_with_first_write_call: 212 yield 213 else: 214 if not self._colocate_with: 215 self._colocate_with.append(value) 216 with ops.colocate_with(self._colocate_with[0]): 217 yield 218 219 def identity(self): 220 """See TensorArray.""" 221 flow = array_ops.identity(self._flow) 222 ta = TensorArray( 223 dtype=self._dtype, handle=self._handle, flow=flow, 224 infer_shape=self._infer_shape, 225 colocate_with_first_write_call=self._colocate_with_first_write_call) 226 ta._element_shape = self._element_shape 227 ta._colocate_with = self._colocate_with 228 return ta 229 230 def grad(self, source, flow=None, name=None): 231 """See TensorArray.""" 232 # tensor_array_grad requires a flow input when forward 233 # TensorArrays are dynamically sized. This forces the creation 234 # of the grad TensorArray only once the final forward array's size 235 # is fixed. 
236 if flow is None: 237 flow = self.flow 238 with ops.name_scope(name, "TensorArrayGrad", [self._handle]): 239 with ops.colocate_with(self._handle): 240 g_handle, unused_flow = gen_data_flow_ops._tensor_array_grad_v3( 241 handle=self._handle, source=source, flow_in=flow, name=name) 242 with ops.control_dependencies([g_handle]): 243 flow = array_ops.identity(flow, name="gradient_flow") 244 g = TensorArray( 245 dtype=self._dtype, 246 handle=g_handle, 247 flow=flow, 248 infer_shape=self._infer_shape, 249 colocate_with_first_write_call=False) 250 g._element_shape = self._element_shape 251 return g 252 253 def read(self, index, name=None): 254 """See TensorArray.""" 255 value = gen_data_flow_ops._tensor_array_read_v3( 256 handle=self._handle, 257 index=index, 258 flow_in=self._flow, 259 dtype=self._dtype, 260 name=name) 261 if self._element_shape: 262 value.set_shape(self._element_shape[0].dims) 263 return value 264 265 @tf_should_use.should_use_result 266 def write(self, index, value, name=None): 267 """See TensorArray.""" 268 with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]): 269 value = ops.convert_to_tensor(value, name="value") 270 if self._infer_shape: 271 self._merge_element_shape(value.shape) 272 with self._maybe_colocate_with(value): 273 flow_out = gen_data_flow_ops._tensor_array_write_v3( 274 handle=self._handle, 275 index=index, 276 value=value, 277 flow_in=self._flow, 278 name=name) 279 ta = TensorArray( 280 dtype=self._dtype, handle=self._handle, flow=flow_out, 281 colocate_with_first_write_call=self._colocate_with_first_write_call) 282 ta._infer_shape = self._infer_shape 283 ta._element_shape = self._element_shape 284 ta._colocate_with = self._colocate_with 285 return ta 286 287 def stack(self, name=None): 288 """See TensorArray.""" 289 with ops.colocate_with(self._handle): 290 with ops.name_scope(name, "TensorArrayStack", [self._handle]): 291 return self.gather(math_ops.range(0, self.size()), name=name) 292 293 def gather(self, 
indices, name=None): 294 """See TensorArray.""" 295 if self._element_shape: 296 element_shape = self._element_shape[0] 297 else: 298 element_shape = tensor_shape.TensorShape(None) 299 value = gen_data_flow_ops._tensor_array_gather_v3( 300 handle=self._handle, 301 indices=indices, 302 flow_in=self._flow, 303 dtype=self._dtype, 304 name=name, 305 element_shape=element_shape) 306 if self._element_shape and self._element_shape[0].dims is not None: 307 value.set_shape([None] + self._element_shape[0].dims) 308 return value 309 310 def concat(self, name=None): 311 """See TensorArray.""" 312 if self._element_shape and self._element_shape[0].dims is not None: 313 element_shape_except0 = ( 314 tensor_shape.TensorShape(self._element_shape[0].dims[1:])) 315 else: 316 element_shape_except0 = tensor_shape.TensorShape(None) 317 value, _ = gen_data_flow_ops._tensor_array_concat_v3( 318 handle=self._handle, 319 flow_in=self._flow, 320 dtype=self._dtype, 321 name=name, 322 element_shape_except0=element_shape_except0) 323 if self._element_shape and self._element_shape[0].dims is not None: 324 value.set_shape([None] + self._element_shape[0].dims[1:]) 325 return value 326 327 @tf_should_use.should_use_result 328 def unstack(self, value, name=None): 329 """See TensorArray.""" 330 with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]): 331 num_elements = array_ops.shape(value)[0] 332 return self.scatter( 333 indices=math_ops.range(0, num_elements), value=value, name=name) 334 335 @tf_should_use.should_use_result 336 def scatter(self, indices, value, name=None): 337 """See TensorArray.""" 338 with ops.name_scope(name, "TensorArrayScatter", 339 [self._handle, value, indices]): 340 value = ops.convert_to_tensor(value, name="value") 341 if self._infer_shape and context.in_graph_mode(): 342 self._merge_element_shape(value.shape[1:]) 343 with self._maybe_colocate_with(value): 344 flow_out = gen_data_flow_ops._tensor_array_scatter_v3( 345 handle=self._handle, 346 
indices=indices, 347 value=value, 348 flow_in=self._flow, 349 name=name) 350 ta = TensorArray( 351 dtype=self._dtype, handle=self._handle, flow=flow_out, 352 colocate_with_first_write_call=self._colocate_with_first_write_call) 353 ta._infer_shape = self._infer_shape 354 ta._element_shape = self._element_shape 355 ta._colocate_with = self._colocate_with 356 return ta 357 358 @tf_should_use.should_use_result 359 def split(self, value, lengths, name=None): 360 """See TensorArray.""" 361 with ops.name_scope(name, "TensorArraySplit", 362 [self._handle, value, lengths]): 363 value = ops.convert_to_tensor(value, name="value") 364 with self._maybe_colocate_with(value): 365 lengths_64 = math_ops.to_int64(lengths) 366 if self._infer_shape and context.in_graph_mode(): 367 clengths = tensor_util.constant_value(lengths_64) 368 if value.shape.dims is not None: 369 if clengths is not None and clengths.max() == clengths.min(): 370 self._merge_element_shape( 371 tensor_shape.TensorShape([clengths[0]]).concatenate( 372 value.shape[1:])) 373 flow_out = gen_data_flow_ops._tensor_array_split_v3( 374 handle=self._handle, 375 value=value, 376 lengths=lengths_64, 377 flow_in=self._flow, 378 name=name) 379 ta = TensorArray( 380 dtype=self._dtype, handle=self._handle, flow=flow_out, 381 colocate_with_first_write_call=self._colocate_with_first_write_call) 382 ta._infer_shape = self._infer_shape 383 ta._element_shape = self._element_shape 384 ta._colocate_with = self._colocate_with 385 return ta 386 387 def size(self, name=None): 388 """See TensorArray.""" 389 return gen_data_flow_ops._tensor_array_size_v3( 390 handle=self._handle, flow_in=self.flow, name=name) 391 392 @tf_should_use.should_use_result 393 def close(self, name=None): 394 """See TensorArray.""" 395 return gen_data_flow_ops._tensor_array_close_v3( 396 handle=self._handle, name=name) 397 398 # pylint: enable=protected-access 399 400 401 # pylint: disable=protected-access 402 def _eager_write_no_copy(ta, index, value): 403 
"""Writes value into an _EagerTensorArray without creating a new TensorArray. 404 405 Args: 406 ta: _EagerTensorArray into which to write value. 407 index: 0-D. int32 scalar with the index to write to. 408 value: N-D. Tensor of type `dtype`. The Tensor to write to this index. 409 410 Raises: 411 errors_impl.AlreadyExistsError: attempting to overwrite an entry. 412 errors_impl.InvalidArgumentError: value dtype does not match `ta`'s dtype. 413 errors_impl.OutOfRangeError: `index` is out of bounds. 414 ValueError: shape of `value` is not consistent with inferred shape. 415 """ 416 417 if isinstance(index, ops.EagerTensor): 418 index = index.numpy() 419 420 if index < 0: 421 raise errors_impl.OutOfRangeError( 422 None, None, 423 "Writing to negative indices (index %d) is not allowed." % index) 424 425 tensor_array = ta._tensor_array 426 size = len(tensor_array) 427 if index >= size: 428 if not ta._dynamic_size: 429 raise errors_impl.OutOfRangeError( 430 None, None, 431 "Tried to write to index %d but array is not resizeable and size " 432 "is: %d" % (index, size)) 433 tensor_array.extend([None for _ in range(index - size + 1)]) 434 435 if not isinstance(value, ops.EagerTensor): 436 value = constant_op.constant(value) 437 438 if ta._infer_shape: 439 if ta._element_shape is None: 440 ta._element_shape = value.shape 441 elif ta._element_shape != value.shape: 442 raise ValueError("Incompatible shape for value (%s), expected (%s)" % 443 (value.shape.as_list(), ta._element_shape.as_list())) 444 445 if ta._dtype != value.dtype: 446 raise errors_impl.InvalidArgumentError( 447 None, None, 448 "TensorArray dtype is %s but Op is trying to write dtype %s" % 449 (ta._dtype.name, value.dtype.name)) 450 451 if ta._tensor_array[index] is not None: 452 raise errors_impl.AlreadyExistsError( 453 None, None, 454 "Could not write to TensorArray index %d because it has already been " 455 "written to." 
% index) 456 457 tensor_array[index] = value 458 459 # pylint: enable=protected-access 460 461 462 class _EagerTensorArray(object): 463 """Eager-mode implementation of TensorArray. 464 """ 465 466 def __init__(self, 467 dtype, 468 size=None, 469 dynamic_size=None, 470 clear_after_read=None, 471 tensor_array_name=None, 472 handle=None, 473 flow=None, 474 infer_shape=True, 475 element_shape=None, 476 colocate_with_first_write_call=True, 477 name=None): 478 """Constructs an Eager mode TensorArray. 479 480 Args: 481 dtype: (required) data type of the TensorArray. 482 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 483 Required if handle is not provided. 484 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 485 can grow the TensorArray past its initial size. Default: False. 486 clear_after_read: Boolean (optional, default: True). If True, clear 487 TensorArray values after reading them. This disables read-many 488 semantics, but allows early release of memory. 489 tensor_array_name: unused. 490 handle: unsupported. 491 flow: unsupported. 492 infer_shape: used for error checking, same semantics as TensorArray. 493 element_shape: used for error checking, same semantics as TensorArray. 494 colocate_with_first_write_call: unsupported. 495 name: unsupported. 496 497 Raises: 498 ValueError: handle or flow are supplied, or if size is not supplied. 499 """ 500 501 del (flow, tensor_array_name, name) # not meaningful in Eager 502 503 if handle is not None: 504 raise ValueError("TensorArray handles are not supported in Eager mode.") 505 if size is None: 506 raise ValueError("Size must be declared for TensorArrays in Eager mode.") 507 508 # These attributes are not meaningful in Eager, but some library functions 509 # (e.g., those in control_flow_ops.py) access them to create new tensor 510 # arrays; as such, we define them for the sake of compatibility. 
511 self._handle = None 512 # we assign a dummy value to _flow in case other code assumes it to be 513 # a Tensor 514 self._flow = constant_op.constant(0, dtype=dtypes.int32) 515 self._infer_shape = infer_shape 516 self._element_shape = element_shape 517 self._colocate_with_first_write_call = colocate_with_first_write_call 518 519 self._dtype = dtype 520 self._dynamic_size = dynamic_size or False 521 self._clear_after_read = ( 522 True if clear_after_read is None else clear_after_read) 523 self._previously_read_indices = [] 524 525 if isinstance(size, ops.EagerTensor): 526 size = size.numpy() 527 self._tensor_array = [None for _ in range(size)] 528 529 @property 530 def flow(self): 531 """Flows are not meaningful in Eager; this exists for compatibility.""" 532 return self._flow 533 534 @property 535 def dtype(self): 536 return self._dtype 537 538 @property 539 def handle(self): 540 """Handles are not meaningful in Eager; this exists for compatibility.""" 541 return self._handle 542 543 def _identity_without_array(self): 544 """Returns a new TensorArray with the same properties as this Eager one. 545 546 NB: Does not set the underlying _tensor_array attribute. 
547 """ 548 ta = TensorArray( 549 dtype=self._dtype, 550 size=len(self._tensor_array), 551 dynamic_size=self._dynamic_size, 552 clear_after_read=self._clear_after_read, 553 handle=self._handle, 554 flow=self._flow, 555 infer_shape=self._infer_shape, 556 element_shape=self._element_shape, 557 colocate_with_first_write_call=self._colocate_with_first_write_call) 558 ta._implementation._previously_read_indices = self._previously_read_indices # pylint: disable=protected-access 559 return ta 560 561 def identity(self): 562 """See TensorArray.""" 563 ta = self._identity_without_array() 564 ta._implementation._tensor_array = [t for t in self._tensor_array] # pylint: disable=protected-access 565 return ta 566 567 def grad(self, source, flow=None, name=None): 568 raise NotImplementedError( 569 "TensorArray.grad is not supported in Eager mode; Eager's gradient " 570 "implementation does not use/need this function to compute gradients " 571 "of operations that use TensorArrays.") 572 573 def read(self, index, name=None): 574 """See TensorArray.""" 575 del name # not meaningful in Eager mode 576 577 if isinstance(index, ops.EagerTensor): 578 index = index.numpy() 579 580 if index < 0: 581 raise errors_impl.OutOfRangeError( 582 None, None, 583 "Reading from negative indices (index %d) is not allowed." 
% index) 584 585 if index >= len(self._tensor_array): 586 raise errors_impl.OutOfRangeError( 587 None, None, "Tried to read from index %d but array size is: %d" % 588 (index, len(self._tensor_array))) 589 590 tensor = self._tensor_array[index] 591 if tensor is None: 592 if index in self._previously_read_indices: 593 raise errors_impl.InvalidArgumentError( 594 None, None, 595 "Could not read index %d twice because it was cleared after " 596 "a previous read (perhaps try setting clear_after_read = false?)" % 597 index) 598 else: 599 tensor = self._maybe_zero(index) 600 601 if self._clear_after_read: 602 self._tensor_array[index] = None 603 self._previously_read_indices.append(index) 604 return tensor 605 606 def write(self, index, value, name=None): 607 """See TensorArray.""" 608 del name # not meaningful in Eager mode 609 ta = self.identity() 610 _eager_write_no_copy(ta._implementation, index, value) # pylint: disable=protected-access 611 return ta 612 613 def _maybe_zero(self, ix): 614 val = self._tensor_array[ix] 615 if val is None: 616 val = self._tensor_array[ix] = array_ops.zeros( 617 shape=self._element_shape, dtype=self._dtype) 618 return val 619 620 def stack(self, name=None): 621 """See TensorArray.""" 622 if self._tensor_array: 623 for ix in range(len(self._tensor_array)): 624 self._maybe_zero(ix) 625 return array_ops.stack(self._tensor_array, name=name) 626 627 def gather(self, indices, name=None): 628 """See TensorArray.""" 629 del name # not meaningful in Eager mode 630 return array_ops.stack([self._maybe_zero(i) for i in indices.numpy()]) 631 632 def concat(self, name=None): 633 """See TensorArray.""" 634 try: 635 return array_ops.concat( 636 [self._maybe_zero(ix) for ix in range(len(self._tensor_array))], 637 0, name=name) 638 except errors_impl.OpError: 639 # Reproduce a subset of the error-handling for graph-mode TensorArrays. 
640 shapes = [t.shape for t in self._tensor_array] 641 ndims = [s.ndims for s in shapes] 642 if 0 in ndims: 643 idx = ndims.index(0) 644 raise errors_impl.InvalidArgumentError( 645 None, None, "Concat saw a scalar shape at index %d but requires " 646 "at least vectors." % idx) 647 else: 648 raise 649 650 def unstack(self, value, name=None): 651 """See TensorArray.""" 652 tensors = array_ops.unstack(value, name=name) 653 if len(tensors) > len(self._tensor_array) and not self._dynamic_size: 654 raise ValueError( 655 "Cannot unstack %d tensors into a TensorArray of static size %d" % 656 (len(tensors), len(self._tensor_array))) 657 ta = self._identity_without_array() 658 ta._implementation._tensor_array = tensors # pylint: disable=protected-access 659 return ta 660 661 def scatter(self, indices, value, name=None): 662 """See TensorArray.""" 663 del name # unused in Eager 664 ta = self.identity() 665 for index, val in zip(indices.numpy(), array_ops.unstack(value)): 666 _eager_write_no_copy(ta._implementation, index, val) # pylint: disable=protected-access 667 return ta 668 669 def split(self, value, lengths, name=None): 670 """See TensorArray.""" 671 # error checking to match graph-mode errors 672 value = constant_op.constant(value) 673 lengths = constant_op.constant(lengths) 674 sum_lengths = math_ops.reduce_sum(lengths) 675 if lengths.shape.ndims != 1: 676 raise errors_impl.InvalidArgumentError( 677 None, None, "Expected lengths to be a vector, received shape: %s" % 678 lengths.shape.as_list()) 679 elif value.shape.ndims == 0: 680 raise errors_impl.InvalidArgumentError( 681 None, None, "Expected value to be at least a vector, " 682 "but received shape: %s" % value.shape.as_list()) 683 elif sum_lengths.numpy() != value.shape.as_list()[0]: 684 raise errors_impl.InvalidArgumentError( 685 None, None, "Expected sum of lengths to be equal to " 686 "values.shape[0], but sum of lengths is %d and " 687 "value's shape is: %s " % (sum_lengths.numpy(), 688 value.shape.as_list())) 
689 elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array): 690 raise errors_impl.InvalidArgumentError( 691 None, None, "TensorArray's size is not equal to the size of " 692 "lengths (%d vs. %d), and the TensorArray is not marked as " 693 "dynamically resizeable" % (len(self._tensor_array), 694 lengths.shape[0])) 695 else: 696 ta = self._identity_without_array() 697 tensor_array = array_ops.split(value, lengths, name=name) 698 ta._implementation._tensor_array = tensor_array # pylint: disable=protected-access 699 return ta 700 701 def size(self, name=None): 702 """See TensorArray.""" 703 del name # not meaningful in Eager mode 704 return constant_op.constant(len(self._tensor_array)) 705 706 def close(self, name=None): 707 del name # not meaningful in Eager mode 708 del self._tensor_array[:] 709 return 710 711 712 # TensorArray is designed to hide an underlying implementation object 713 # and as such accesses many of that object's hidden fields. 714 # pylint: disable=protected-access 715 @tf_export("TensorArray") 716 class TensorArray(object): 717 """Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays. 718 719 This class is meant to be used with dynamic iteration primitives such as 720 `while_loop` and `map_fn`. It supports gradient back-propagation via special 721 "flow" control flow dependencies. 722 """ 723 724 def __init__(self, 725 dtype, 726 size=None, 727 dynamic_size=None, 728 clear_after_read=None, 729 tensor_array_name=None, 730 handle=None, 731 flow=None, 732 infer_shape=True, 733 element_shape=None, 734 colocate_with_first_write_call=True, 735 name=None): 736 """Construct a new TensorArray or wrap an existing TensorArray handle. 737 738 A note about the parameter `name`: 739 740 The name of the `TensorArray` (even if passed in) is uniquified: each time 741 a new `TensorArray` is created at runtime it is assigned its own name for 742 the duration of the run. 
This avoids name collisions if a `TensorArray` 743 is created within a `while_loop`. 744 745 Args: 746 dtype: (required) data type of the TensorArray. 747 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 748 Required if handle is not provided. 749 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 750 can grow the TensorArray past its initial size. Default: False. 751 clear_after_read: Boolean (optional, default: True). If True, clear 752 TensorArray values after reading them. This disables read-many 753 semantics, but allows early release of memory. 754 tensor_array_name: (optional) Python string: the name of the TensorArray. 755 This is used when creating the TensorArray handle. If this value is 756 set, handle should be None. 757 handle: (optional) A `Tensor` handle to an existing TensorArray. If this 758 is set, tensor_array_name should be None. Only supported in graph mode. 759 flow: (optional) A float `Tensor` scalar coming from an existing 760 `TensorArray.flow`. Only supported in graph mode. 761 infer_shape: (optional, default: True) If True, shape inference 762 is enabled. In this case, all elements must have the same shape. 763 element_shape: (optional, default: None) A `TensorShape` object specifying 764 the shape constraints of each of the elements of the TensorArray. 765 Need not be fully defined. 766 colocate_with_first_write_call: If `True`, the TensorArray will be 767 colocated on the same device as the Tensor used on its first write 768 (write operations include `write`, `unstack`, and `split`). If `False`, 769 the TensorArray will be placed on the device determined by the 770 device context available during its initialization. 771 name: A name for the operation (optional). 772 773 Raises: 774 ValueError: if both handle and tensor_array_name are provided. 775 TypeError: if handle is provided but is not a Tensor. 
    """
    # Pick the backing implementation from the current execution mode:
    # graph mode uses `_GraphTensorArray` (the graph-mode implementation),
    # anything else uses `_EagerTensorArray`. Every constructor argument is
    # forwarded unchanged.
    if context.in_graph_mode():
      implementation = _GraphTensorArray
    else:
      implementation = _EagerTensorArray

    # All state lives on `self._implementation`; the properties and methods
    # below read from / delegate to it.
    self._implementation = implementation(
        dtype,
        size=size,
        dynamic_size=dynamic_size,
        clear_after_read=clear_after_read,
        tensor_array_name=tensor_array_name,
        handle=handle,
        flow=flow,
        infer_shape=infer_shape,
        element_shape=element_shape,
        colocate_with_first_write_call=colocate_with_first_write_call,
        name=name)

  @property
  def flow(self):
    """The flow `Tensor` forcing ops leading to this TensorArray state."""
    return self._implementation._flow

  @property
  def dtype(self):
    """The data type of this TensorArray."""
    return self._implementation._dtype

  @property
  def handle(self):
    """The reference to the TensorArray."""
    return self._implementation._handle

  # Private pass-through accessors for attributes stored on the backing
  # implementation. Reads (and writes, where a setter is defined) go straight
  # to `self._implementation`.

  @property
  def _infer_shape(self):
    """Pass-through for the implementation's `_infer_shape` attribute."""
    return self._implementation._infer_shape

  @_infer_shape.setter
  def _infer_shape(self, infer_shape):
    self._implementation._infer_shape = infer_shape

  @property
  def _element_shape(self):
    """Pass-through for the implementation's `_element_shape` attribute."""
    return self._implementation._element_shape

  @_element_shape.setter
  def _element_shape(self, element_shape):
    self._implementation._element_shape = element_shape

  @property
  def _colocate_with_first_write_call(self):
    """Read-only pass-through for `_colocate_with_first_write_call`."""
    # No setter is defined here, so this flag cannot be reassigned through
    # the wrapper.
    return self._implementation._colocate_with_first_write_call

  @property
  def _colocate_with(self):
    """Pass-through for the implementation's `_colocate_with` attribute."""
    return self._implementation._colocate_with

  @_colocate_with.setter
  def _colocate_with(self, colocate_with):
    self._implementation._colocate_with = colocate_with

  def identity(self):
    """Returns a TensorArray with the same content and properties.

    Returns:
      A new TensorArray object with flow that ensures the control dependencies
      from the contexts will become control dependencies for writes, reads, etc.
      Use this object for all subsequent operations.
    """
    return self._implementation.identity()

  def grad(self, source, flow=None, name=None):
    """Delegates to the backing implementation's `grad`.

    NOTE(review): presumably builds the gradient TensorArray associated with
    `source`; confirm against `_GraphTensorArray.grad` /
    `_EagerTensorArray.grad`.

    Args:
      source: Forwarded unchanged to the implementation's `grad`.
      flow: (optional) Forwarded unchanged; defaults to `None`.
      name: A name for the operation (optional).

    Returns:
      Whatever the backing implementation's `grad` returns.
    """
    return self._implementation.grad(source, flow=flow, name=name)

  def read(self, index, name=None):
    """Read the value at location `index` in the TensorArray.

    Args:
      index: 0-D. int32 tensor with the index to read from.
      name: A name for the operation (optional).

    Returns:
      The tensor at index `index`.
    """
    return self._implementation.read(index, name=name)

  @tf_should_use.should_use_result
  def write(self, index, value, name=None):
    """Write `value` into index `index` of the TensorArray.

    Args:
      index: 0-D. int32 scalar with the index to write to.
      value: N-D. Tensor of type `dtype`. The Tensor to write to this index.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the write occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if there are more writers than specified.
    """
    return self._implementation.write(index, value, name=name)

  def stack(self, name=None):
    """Return the values in the TensorArray as a stacked `Tensor`.

    All of the values must have been written and their shapes must all match.
    If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray stacked into one tensor.
    """
    return self._implementation.stack(name=name)

  def gather(self, indices, name=None):
    """Return selected values in the TensorArray as a packed `Tensor`.

    All of the selected values must have been written and their shapes
    must all match.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If
        the `TensorArray` is not dynamic, `max_value=size()`.
      name: A name for the operation (optional).

    Returns:
      The tensors in the `TensorArray` selected by `indices`, packed into one
      tensor.
    """
    return self._implementation.gather(indices, name=name)

  def concat(self, name=None):
    """Return the values in the TensorArray as a concatenated `Tensor`.

    All of the values must have been written, their ranks must match, and
    their shapes must all match for all dimensions except the first.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray concatenated into one tensor.
    """
    return self._implementation.concat(name=name)

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """Unstack the values of a `Tensor` in the TensorArray.

    If input value shapes have rank-`R`, then the output TensorArray will
    contain elements whose shapes are rank-`(R-1)`.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unstack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the unstack occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.unstack(value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """Scatter the values of a `Tensor` in specific indices of a `TensorArray`.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If
        the `TensorArray` is not dynamic, `max_value=size()`.
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to scatter.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the scatter occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.scatter(indices, value, name=name)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """Split the values of a `Tensor` into the TensorArray.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to split.
      lengths: 1-D.  int32 vector with the lengths to use when splitting
        `value` along its first dimension.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the split occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.split(value, lengths, name=name)

  def size(self, name=None):
    """Return the size of the TensorArray.

    Args:
      name: A name for the operation (optional).

    Returns:
      The size as reported by the backing implementation's `size`.
    """
    return self._implementation.size(name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """Close the current TensorArray.

    Args:
      name: A name for the operation (optional).

    Returns:
      Whatever the backing implementation's `close` returns.
    """
    return self._implementation.close(name=name)

# pylint: enable=protected-access