# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control Flow Operations.

See the @{$python/control_flow_ops} guide.

@@identity
@@identity_n
@@tuple
@@group
@@no_op
@@count_up_to
@@cond
@@smart_cond
@@case
@@while_loop
@@logical_and
@@logical_not
@@logical_or
@@logical_xor
@@equal
@@not_equal
@@less
@@less_equal
@@greater
@@greater_equal
@@where
@@is_finite
@@is_inf
@@is_nan
@@verify_tensor_all_finite
@@check_numerics
@@add_check_numerics_ops
@@Assert
@@Print
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import functools

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export

# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple


def _summarize_eager(tensor, summarize=None):
  """Returns a summarized string representation of eager `tensor`.

  Args:
    tensor: EagerTensor to summarize
    summarize: Include this many of the first elements of the tensor.
  """
  # reshape((-1,)) is the fastest way to get a flat array view
  if tensor._rank():  # pylint: disable=protected-access
    flat = tensor.numpy().reshape((-1,))
    lst = [str(x) for x in flat[:summarize]]
    if len(lst) < flat.size:
      lst.append("...")
  else:
    # tensor.numpy() returns a scalar for zero dimensional arrays
    if summarize != 0:
      lst = [str(tensor.numpy())]
    else:
      lst = []

  return ", ".join(lst)


# pylint: disable=protected-access


# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("Assert")
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.
  `summarize` determines how many entries of the tensors to print.

  NOTE: In graph mode, to ensure that Assert executes, one usually attaches
  a dependency:

  ```python
  # Ensure maximum element of x is smaller or equal to 1
  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
  with tf.control_dependencies([assert_op]):
    ... code using x ...
  ```

  Args:
    condition: The condition to evaluate.
    data: The tensors to print out when condition is false.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).

  Returns:
    assert_op: An `Operation` that, when executed, raises a
    `tf.errors.InvalidArgumentError` if `condition` is not true.
    @compatibility{eager} returns None.

  Raises:
    @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
    is not true
  """
  if context.in_eager_mode():
    if not condition:
      xs = ops.convert_n_to_tensor(data)
      data_str = [_summarize_eager(x, summarize) for x in xs]
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="Expected '%s' to be true. Summarized data: %s" %
          (condition, "\n".join(data_str)))
    return

  with ops.name_scope(name, "Assert", [condition, data]) as name:
    xs = ops.convert_n_to_tensor(data)
    if all([x.dtype in {dtypes.string, dtypes.int32} for x in xs]):
      # As a simple heuristic, we assume that string and int32 are
      # on host to avoid the need to use cond. If that is not the case,
      # we will pay the price of copying the tensor to host memory.
      return gen_logging_ops._assert(condition, data, summarize, name="Assert")
    else:
      condition = ops.convert_to_tensor(condition, name="Condition")

      def true_assert():
        return gen_logging_ops._assert(
            condition, data, summarize, name="Assert")

      guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
      return guarded_assert.op
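

# Example of the eager-mode path above (a rough sketch; assumes eager
# execution is enabled):
#
#   x = constant_op.constant([2.0, 3.0])
#   # The condition is False, so this raises InvalidArgumentError right away,
#   # with a summary of `x` in the message.
#   Assert(math_ops.reduce_all(x < 1.0), [x], summarize=2)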


def _Identity(data, name=None):
  """Return a tensor with the same shape and contents as the input tensor.

  Args:
    data: A Tensor.
    name: A name for this operation (optional).

  Returns:
    A Tensor with the same type and value as the input Tensor.
  """
  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return gen_array_ops._ref_identity(data, name=name)
    else:
      return array_ops.identity(data, name=name)
  else:
    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
      raise TypeError("Type %s not supported" % type(data))
    values = _Identity(data.values, name=name)
    indices = array_ops.identity(data.indices, name="indices")
    if isinstance(data, ops.IndexedSlices):
      dense_shape = data.dense_shape
      if dense_shape is not None:
        dense_shape = array_ops.identity(dense_shape, name="dense_shape")
      return ops.IndexedSlices(values, indices, dense_shape)
    else:
      dense_shape = array_ops.identity(data.dense_shape, name="dense_shape")
      return sparse_tensor.SparseTensor(indices, values, dense_shape)


def _NextIteration(data, name=None):
  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return ref_next_iteration(data, name=name)
    else:
      return next_iteration(data, name=name)
  else:
    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
      raise TypeError("Type %s not supported" % type(data))
    values = _NextIteration(data.values, name=name)
    indices = next_iteration(data.indices, name="indices")
    if isinstance(data, ops.IndexedSlices):
      dense_shape = data.dense_shape
      if dense_shape is not None:
        dense_shape = next_iteration(dense_shape, name="dense_shape")
      return ops.IndexedSlices(values, indices, dense_shape)
    else:
      dense_shape = next_iteration(data.dense_shape, name="dense_shape")
      return sparse_tensor.SparseTensor(indices, values, dense_shape)
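

# A rough, simplified sketch of how the primitives in this file are wired
# together to build a while loop (WhileContext, later in this file, does the
# real work; `cond_fn`, `body_fn`, `init_value` and `frame_name` below are
# placeholders):
#
#   enter = _Enter(init_value, frame_name)      # bring value into the frame
#   merge_out = merge([enter, enter])[0]        # second input patched below
#   pred = cond_fn(merge_out)                   # evaluate the loop condition
#   exit_branch, body_branch = switch(merge_out, pred)
#   next_value = _NextIteration(body_fn(body_branch))
#   _AddNextAndBackEdge(merge_out, next_value)  # back edge into the merge
#   result = exit(exit_branch)                  # leave the frame when done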


def _Enter(data,
           frame_name,
           is_constant=False,
           parallel_iterations=10,
           use_ref=True,
           use_input_shape=True,
           name=None):
  """Creates or finds a child frame, and makes `data` available to it.

  The unique `frame_name` is used by the `Executor` to identify frames. If
  `is_constant` is true, `data` is a constant in the child frame; otherwise
  it may be changed in the child frame. At most `parallel_iterations`
  iterations are run in parallel in the child frame.

  Args:
    data: The tensor to be made available to the child frame.
    frame_name: The name of the child frame.
    is_constant: If true, the output is constant within the child frame.
    parallel_iterations: The number of iterations allowed to run in parallel.
    use_ref: If true, use ref_enter if data is of ref type.
    use_input_shape: If true, set the result's shape from the input's shape.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.
  """
  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
      result = gen_control_flow_ops._ref_enter(
          data, frame_name, is_constant, parallel_iterations, name=name)
    else:
      result = gen_control_flow_ops._enter(
          data, frame_name, is_constant, parallel_iterations, name=name)
    if use_input_shape:
      result.set_shape(data.get_shape())
    return result
  else:
    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
      raise TypeError("Type %s not supported" % type(data))
    values = _Enter(
        data.values,
        frame_name,
        is_constant,
        parallel_iterations=parallel_iterations,
        use_input_shape=use_input_shape,
        name=name)
    indices = gen_control_flow_ops._enter(
        data.indices,
        frame_name,
        is_constant,
        parallel_iterations,
        name="indices")
    if use_input_shape:
      indices.set_shape(data.indices.get_shape())
    if isinstance(data, ops.IndexedSlices):
      dense_shape = data.dense_shape
      if dense_shape is not None:
        dense_shape = gen_control_flow_ops._enter(
            dense_shape,
            frame_name,
            is_constant,
            parallel_iterations,
            name="dense_shape")
        if use_input_shape:
          dense_shape.set_shape(data.dense_shape.get_shape())
      return ops.IndexedSlices(values, indices, dense_shape)
    else:
      dense_shape = gen_control_flow_ops._enter(
          data.dense_shape,
          frame_name,
          is_constant,
          parallel_iterations,
          name="dense_shape")
      if use_input_shape:
        dense_shape.set_shape(data.dense_shape.get_shape())
      return sparse_tensor.SparseTensor(indices, values, dense_shape)


def exit(data, name=None):  # pylint: disable=redefined-builtin
  """Exits the current frame to its parent frame.

  Exit makes its input `data` available to the parent frame.

  Args:
    data: The tensor to be made available to the parent frame.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.
  """
  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return gen_control_flow_ops._ref_exit(data, name)
    else:
      return gen_control_flow_ops._exit(data, name)
  else:
    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
      raise TypeError("Type %s not supported" % type(data))
    values = exit(data.values, name=name)
    indices = gen_control_flow_ops._exit(data.indices, name="indices")
    if isinstance(data, ops.IndexedSlices):
      dense_shape = data.dense_shape
      if dense_shape is not None:
        dense_shape = gen_control_flow_ops._exit(dense_shape, name)
      return ops.IndexedSlices(values, indices, dense_shape)
    else:
      dense_shape = gen_control_flow_ops._exit(data.dense_shape, name)
      return sparse_tensor.SparseTensor(indices, values, dense_shape)


def switch(data, pred, dtype=None, name=None):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  This op handles `Tensor`s and `IndexedSlices`.
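
  For example (a rough sketch of how `switch` pairs with `merge`, using the
  helpers in this module; graph mode):

  ```python
  data = constant_op.constant([1, 2, 3])
  pred = constant_op.constant(True)
  out_false, out_true = switch(data, pred)
  # Only the taken branch produces a value at runtime; `merge` returns the
  # value from whichever branch is alive, together with its index.
  merged, index = merge([out_false, out_true])
  ```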

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    dtype: Optional element type for the returned tensor. If missing,
      the type is inferred from the type of `data`.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded
    to `output_true`, otherwise it goes to `output_false`.
  """
  with ops.name_scope(name, "Switch", [data, pred]) as name:
    data = ops.internal_convert_to_tensor_or_indexed_slices(
        data, dtype=dtype, name="data", as_ref=True)
    pred = ops.convert_to_tensor(pred, name="pred")
    if isinstance(data, ops.Tensor):
      return gen_control_flow_ops._switch(data, pred, name=name)
    else:
      if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
        raise TypeError("Type %s not supported" % type(data))
      val, ind = data.values, data.indices
      val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name)
      ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices")
      if isinstance(data, ops.IndexedSlices):
        dense_shape = data.dense_shape
        if dense_shape is not None:
          dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
              dense_shape, pred, name="dense_shape")
        else:
          dense_shape_f, dense_shape_t = None, None
        return (ops.IndexedSlices(val_f, ind_f, dense_shape_f),
                ops.IndexedSlices(val_t, ind_t, dense_shape_t))
      else:
        dense_shape = data.dense_shape
        dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
            data.dense_shape, pred, name="dense_shape")
        return (sparse_tensor.SparseTensor(ind_f, val_f, dense_shape_f),
                sparse_tensor.SparseTensor(ind_t, val_t, dense_shape_t))


def _SwitchRefOrTensor(data, pred, name="Switch"):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  This op handles `Tensor`s and `IndexedSlices`.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
    `output_true`, otherwise it goes to `output_false`.

  Raises:
    TypeError: if data is not a Tensor or IndexedSlices
  """
  data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
  # addresses the following scenario.
  #
  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
  #
  # 1. The update op is created inside a `with ops.colocate_with(var):` block
  #
  # 2. Some tensor `data` is captured and a switch is created in a
  #    `with ops.colocate_with(data):` block.
  #
  # with ops.colocate_with(var):
  #   with ops.colocate_with(data):
  #     op = ...
  #
  # var and data may be pinned to different devices, so we want the ops
  # created within ops.colocate_with(data) to ignore the existing stack.
  with ops.colocate_with(data, ignore_existing=True):
    if isinstance(data, ops.Tensor):
      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
        return ref_switch(data, pred, name=name)
    return switch(data, pred, name=name)


def merge(inputs, name=None):
  """Returns the value of an available element of `inputs`.

  This op tests each of the tensors in `inputs` in turn to determine if any of
  them is available. If it finds an available tensor, it returns it and its
  index in `inputs`.

  It is an error if more than one tensor in `inputs` is available. If no tensor
  in `inputs` is available, the returned tensor and index are not set.

  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
  before merging.

  Args:
    inputs: The input tensors, at most one of which is available.
    name: A name for this operation (optional).

  Returns:
    A tuple containing the chosen input tensor and its index in `inputs`.

  Raises:
    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
      some but not all have a dense_shape property.
  """
  if any([inp is None for inp in inputs]):
    raise ValueError("At least one of the merge inputs is None: %s" % inputs)
  with ops.name_scope(name, "Merge", inputs) as name:
    inputs = [
        ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)
        for inp in inputs
    ]
    if all([isinstance(v, ops.Tensor) for v in inputs]):
      if all([v.dtype._is_ref_dtype for v in inputs]):  # pylint: disable=protected-access
        return gen_control_flow_ops._ref_merge(inputs, name)
      else:
        return gen_control_flow_ops._merge(inputs, name)
    elif all([isinstance(v, sparse_tensor.SparseTensor) for v in inputs]):
      # Only handle the case when all inputs are SparseTensor.
      values, _ = merge([inp.values for inp in inputs], name=name)
      indices, chosen_index = gen_control_flow_ops._merge(
          [inp.indices for inp in inputs], name="indices")
      dense_shape, _ = gen_control_flow_ops._merge(
          [inp.dense_shape for inp in inputs], name="dense_shape")
      return (sparse_tensor.SparseTensor(indices, values, dense_shape),
              chosen_index)
    else:
      # For now convert all the inputs as IndexedSlices.
      inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
      values, _ = merge([inp.values for inp in inputs], name=name)
      indices, chosen_index = gen_control_flow_ops._merge(
          [inp.indices for inp in inputs], name="indices")
      if any(inp.dense_shape is not None for inp in inputs):
        if any(inp.dense_shape is None for inp in inputs):
          raise ValueError("Either all merged IndexedSlices must have a "
                           "dense_shape, or none must have a dense_shape.")
        dense_shape, _ = gen_control_flow_ops._merge(
            [inp.dense_shape for inp in inputs], name="dense_shape")
      else:
        dense_shape = None
      return ops.IndexedSlices(values, indices, dense_shape), chosen_index


# pylint: enable=protected-access


def _convert_tensorarray_to_flow(tensor_or_tensor_array):
  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
    return tensor_or_tensor_array.flow
  else:
    return tensor_or_tensor_array


def _make_tensor_array(ta, t_or_flow):
  # pylint: disable=protected-access
  new_ta = tensor_array_ops.TensorArray(
      dtype=ta.dtype,
      handle=ta.handle,
      flow=t_or_flow,
      infer_shape=ta._infer_shape,
      colocate_with_first_write_call=ta._colocate_with_first_write_call)
  new_ta._colocate_with = ta._colocate_with
  new_ta._element_shape = ta._element_shape
  # pylint: enable=protected-access
  return new_ta


def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
  if len(tensors_or_tensorarrays) != len(tensors_or_flows):
    raise ValueError(
        "Lengths of original Tensor list and new list do not match: %d vs. %d" %
        (len(tensors_or_tensorarrays), len(tensors_or_flows)))
  return [
      _make_tensor_array(ta, t_or_flow)
      if isinstance(ta, tensor_array_ops.TensorArray) else t_or_flow
      for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
  ]


def _ShapeLessThanOrEqual(shape1, shape2):
  if shape2.dims is None:
    return True
  if shape1.ndims != shape2.ndims:
    return False
  for dim1, dim2 in zip(shape1.dims, shape2.dims):
    if dim2.value is not None and dim1.value != dim2.value:
      return False
  return True


def _SetShapeInvariants(input_vars, enter_vars, shapes):
  """Set the shapes of the tensors in `enter_vars` to `shapes`.

  Args:
    input_vars: A list of tensors that are inputs to `enter_vars`.
    enter_vars: A list of tensors whose shapes will be set.
    shapes: A (possibly nested) list of shapes.

  Raises:
    ValueError: If any tensor in `enter_vars` has a less specific shape
      than its corresponding shape in `shapes`.
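
  For example (a rough sketch): a loop variable that starts out as a `[0]`
  vector and grows by one element per iteration would be given the invariant
  `tensor_shape.TensorShape([None])`, which is less specific than any shape
  the variable takes on inside the loop.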
  """
  if shapes is None:
    return
  flat_shapes = nest.flatten(shapes)
  if not all([isinstance(s, tensor_shape.TensorShape) for s in flat_shapes]):
    raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
  # Check that the shapes of the inputs are less than the shape invariants,
  # and set the shapes of `enter_vars` to the shape invariants.
  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
    if isinstance(var, ops.Tensor):
      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
        raise ValueError(
            "The shape invariant specified for %s is not compatible with "
            "the initial shape of the loop variable. It enters the loop "
            "with shape %s, but the specified shape invariant is %s." %
            (inp.name, inp.get_shape(), shape))
      var.set_shape(shape)
    else:
      if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
        raise TypeError("Type %s not supported" % type(var))
      if isinstance(var, ops.IndexedSlices):
        if not _ShapeLessThanOrEqual(inp.values.get_shape(), shape):
          raise ValueError(
              "The shape invariant specified for %s is not compatible with "
              "the initial shape of the values tensor of this IndexedSlices. "
              "It enters the loop with shape %s, but the specified shape "
              "invariant is %s." % (inp.values.name, inp.values.get_shape(),
                                    shape))
        var.values.set_shape(shape)
        var.indices.set_shape(tensor_shape.TensorShape([shape[0]]))
        if var.dense_shape is not None:
          var.dense_shape.set_shape(tensor_shape.TensorShape([shape.ndims]))
      else:
        if not _ShapeLessThanOrEqual(inp.dense_shape.get_shape(), shape):
          raise ValueError(
              "The shape invariant specified for %s is not compatible with "
              "the initial shape of the shape tensor of this SparseTensor. "
              "It enters the loop with shape %s, but the specified shape "
              "invariant is %s." % (inp.dense_shape.name,
                                    inp.dense_shape.get_shape(), shape))
        var.values.set_shape(tensor_shape.TensorShape([None]))
        var.indices.set_shape(tensor_shape.TensorShape([None, shape.ndims]))
        var.dense_shape.set_shape(shape)


def _EnforceShapeInvariant(merge_var, next_var):
  """Check that the shape of a loop variable is invariant.

  Args:
    merge_var: The tensor representing the initial value of the loop
      variable.
    next_var: The tensor representing the value of the loop variable after
      one loop iteration.

  Raises:
    ValueError: If `merge_var` has a more specific shape than its
      corresponding `next_var`.
  """
  if isinstance(merge_var, ops.Tensor):
    m_shape = merge_var.get_shape()
    n_shape = next_var.get_shape()
    if not _ShapeLessThanOrEqual(n_shape, m_shape):
      # TODO(skyewm): get original loop input that caused the shape error and
      # report its name instead of the merge node's.
      raise ValueError(
          "The shape for %s is not an invariant for the loop. It enters "
          "the loop with shape %s, but has shape %s after one iteration. "
          "Provide shape invariants using either the `shape_invariants` "
          "argument of tf.while_loop or set_shape() on the loop variables." %
          (merge_var.name, m_shape, n_shape))
  else:
    if not isinstance(merge_var,
                      (ops.IndexedSlices, sparse_tensor.SparseTensor)):
      raise TypeError("Type %s not supported" % type(merge_var))
    if isinstance(merge_var, ops.IndexedSlices):
      m_values_shape = merge_var.values.get_shape()
      m_indices_shape = merge_var.indices.get_shape()
      m_shape_shape = tensor_shape.TensorShape(None)
      if merge_var.dense_shape is not None:
        m_shape_shape = merge_var.dense_shape.get_shape()
      n_values_shape = next_var.values.get_shape()
      n_indices_shape = next_var.indices.get_shape()
      n_shape_shape = tensor_shape.TensorShape(None)
      if next_var.dense_shape is not None:
        n_shape_shape = next_var.dense_shape.get_shape()
      if (not _ShapeLessThanOrEqual(n_values_shape, m_values_shape) or
          not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape)):
        if not _ShapeLessThanOrEqual(n_values_shape, m_values_shape):
          raise ValueError(
              "The shape for %s is not an invariant for the loop. It enters "
              "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
              "after one iteration. Provide shape invariants using either the "
              "`shape_invariants` argument of tf.while_loop or set_shape() "
              "on the loop variables." %
              (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
               n_values_shape, n_indices_shape, n_shape_shape))
    else:
      m_values_shape = merge_var.values.get_shape()
      m_indices_shape = merge_var.indices.get_shape()
      m_shape_shape = merge_var.dense_shape.get_shape()
      n_values_shape = next_var.values.get_shape()
      n_indices_shape = next_var.indices.get_shape()
      n_shape_shape = next_var.dense_shape.get_shape()
      if (not _ShapeLessThanOrEqual(n_values_shape, m_values_shape) or
          not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape) or
          not _ShapeLessThanOrEqual(n_shape_shape, m_shape_shape)):
        raise ValueError(
            "The shape for %s is not an invariant for the loop. It enters "
            "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
            "after one iteration. Provide shape invariants using either "
            "the `shape_invariants` argument of tf.while_loop or set_shape() "
            "on the loop variables." %
            (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
             n_values_shape, n_indices_shape, n_shape_shape))


def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
  """Add NextIteration and back edge from v to m."""
  if isinstance(m, ops.Tensor):
    v = ops.convert_to_tensor(v)
    v = _NextIteration(v)
    if enforce_shape_invariant:
      # Make sure the shapes of loop outputs are correct. We do this before
      # calling _update_input, which will raise a less-helpful error message if
      # the types don't match.
      # TODO(skyewm): call this for other cases below (needs testing)
      _EnforceShapeInvariant(m, v)
    m.op._update_input(1, v)  # pylint: disable=protected-access
  elif isinstance(m, ops.IndexedSlices):
    # pylint: disable=protected-access
    v = math_ops._as_indexed_slices(v, optimize=False)
    v = _NextIteration(v)
    m.values.op._update_input(1, v.values)
    m.indices.op._update_input(1, v.indices)
    # pylint: enable=protected-access
    if m.dense_shape is not None:
      if v.dense_shape is None:
        raise ValueError("Must have dense shape: %s" % v.name)
      m.dense_shape.op._update_input(1, v.dense_shape)
  elif isinstance(m, sparse_tensor.SparseTensor):
    if not isinstance(v, sparse_tensor.SparseTensor):
      raise ValueError("Must be a sparse tensor: %s" % v.name)
    v = _NextIteration(v)
    # pylint: disable=protected-access
    m.values.op._update_input(1, v.values)
    m.indices.op._update_input(1, v.indices)
    m.dense_shape.op._update_input(1, v.dense_shape)
    # pylint: enable=protected-access
  else:
    raise TypeError("Type %s not supported" % type(m))
  return v


def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
  """Calculate a max_size for use by stack ops inside an XLA while_loop.

  Args:
    value: The value inside the while_loop forward context. Used for printing
      error messages.
    while_ctxt: The forward context inside which value resides. This does
      not always match the value's immediate context, as `value` may be
      inside e.g. a cond context inside the while_loop.

  Returns:
    A tensor containing the `max_size` to feed to a Stack initializer.
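
    For example, if `value` is nested inside two while loops whose
    `maximum_iterations` are 8 and 32, the returned `max_size` is 8 * 32 = 256.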

  Raises:
    ValueError: If `value` is nested inside a `while_loop` that either
      lacks a `maximum_iterations` parameter, or the `maximum_iterations`
      parameter:

      - is inside a `while_loop` that is a parent of the calling context, and
      - cannot be evaluated at graph build time to a constant.
  """
  value_name = value.name
  # curr_ctxt is the context that tf.gradients was called in.
  curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access

  curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
  max_size = constant_op.constant(1)

  # Loop through all containing while contexts between value and the
  # current context, multiplying together each context's
  # max_iterations to get the maximum stack size.
  while while_ctxt not in (None, curr_ctxt):
    max_iter = while_ctxt.maximum_iterations
    if max_iter is None:
      raise ValueError(
          "Cannot create a gradient accumulator for tensor '%s' inside "
          "XLA while_loop because maximum_iterations was not passed to "
          "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))

    # pylint: disable=protected-access
    max_iter_ctxt = max_iter.op._get_control_flow_context()
    # pylint: enable=protected-access

    # If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use.
    if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
      max_size *= max_iter
    else:
      # We cannot use max_iter because it's defined in a nested while
      # or cond context, so will fail if we try to use it as input to
      # any ops in curr_ctxt (e.g. max_size or the final accumulator
      # stack). Attempt to get a constant value out to use instead.
      const_max_iter = tensor_util.constant_value(max_iter)
      if const_max_iter is None:
        raise ValueError(
            "Cannot create a gradient accumulator for tensor '%s' inside XLA "
            "while_loop. maximum_iterations tensor '%s' for while_loop context "
            "'%s' must be statically known (e.g. a constant value or known "
            "shape dimension), or be defined at or outside the while loop "
            "context '%s' (currently defined in '%s')." %
            (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
             max_iter_ctxt.name))
      max_size *= const_max_iter

    # Find the next outer WhileContext (or stop if we reach the
    # tf.gradients context).
    while_ctxt = util.GetContainingWhileContext(
        while_ctxt.outer_context, stop_ctxt=curr_ctxt)

  return max_size


class GradLoopState(object):
  """The state used for constructing the gradient graph for a while loop.

  We create a GradLoopState for each while loop in forward and its
  corresponding while loop in backprop. This gives us access to both
  the forward and the backprop WhileContexts.

  During the construction of the gradient graph, whenever we detect
  a forward value that is needed for backprop, we create a history
  accumulator and add it to `history_map`. Whenever we backprop
  a loop switch op (in _SwitchGrad), we add the grad merge op to
  `switch_map`.
  """

  def __init__(self, forward_ctxt, outer_grad_state):
    # The grad loop state for the outer while loop.
    self._outer_grad_state = None

    # The while loop context for forward.
    self._forward_context = None

    # The loop counter added by AddForwardLoopCounter. It is the value
    # of the loop counter for the next iteration.
    self._forward_index = None

    # A sync op for forward.
    self._forward_sync = None

    # The while loop context for backprop.
    self._grad_context = None

    # The loop counter added by AddBackpropLoopCounter. It is the value
    # of the loop counter for the current iteration.
    self._grad_index = None

    # A sync op for backprop.
    self._grad_sync = None

    # Information needed by backprop.
    self._history_map = {}
    self._switch_map = {}
    self._unused_exits = []
    self._deferred_exits = []
    self._forward_loop_exits = list(forward_ctxt.loop_exits)
    self._pending_exits_count = len(forward_ctxt.loop_exits)

    self._outer_grad_state = outer_grad_state
    if outer_grad_state:
      outer_forward_ctxt = outer_grad_state.forward_context
    else:
      outer_forward_ctxt = forward_ctxt.outer_context

    # Add the forward loop counter.
    if outer_forward_ctxt:
      outer_forward_ctxt.Enter()
    cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
    if outer_forward_ctxt:
      outer_forward_ctxt.Exit()
    self._forward_context = forward_ctxt
    self._forward_index = forward_index

    # Add the backprop WhileContext, and the backprop loop counter.
    if outer_grad_state:
      # This is a nested loop. Remember the iteration counts for each
      # execution of this inner loop.
      outer_forward_ctxt.AddName(cnt.name)
      history_cnt = outer_grad_state.AddForwardAccumulator(cnt)

      outer_grad_ctxt = outer_grad_state.grad_context
      outer_grad_ctxt.Enter()
      self._grad_context = WhileContext(
          maximum_iterations=forward_ctxt.maximum_iterations,
          parallel_iterations=forward_ctxt.parallel_iterations,
          back_prop=forward_ctxt.back_prop,
          swap_memory=forward_ctxt.swap_memory,
          name=forward_ctxt.name,
          grad_state=self)
      real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
      self._grad_index = self._grad_context.AddBackpropLoopCounter(
          real_cnt, outer_grad_state)
      outer_grad_ctxt.Exit()
    else:
      if outer_forward_ctxt:
        outer_forward_ctxt.Enter()
      self._grad_context = WhileContext(
          maximum_iterations=forward_ctxt.maximum_iterations,
          parallel_iterations=forward_ctxt.parallel_iterations,
          back_prop=forward_ctxt.back_prop,
          swap_memory=forward_ctxt.swap_memory,
          name=forward_ctxt.name,
          grad_state=self)
      self._grad_index = self._grad_context.AddBackpropLoopCounter(
          cnt, outer_grad_state)
      if outer_forward_ctxt:
        outer_forward_ctxt.Exit()

  @property
  def outer_grad_state(self):
    """The grad loop state for outer loop."""
    return self._outer_grad_state

  @property
  def forward_context(self):
    """The while loop context for forward."""
    return self._forward_context

  @property
  def forward_index(self):
    """The loop index of forward loop."""
    return self._forward_index

  @property
  def forward_sync(self):
    """A control trigger node for synchronization in the forward loop.

    One main use is to keep the push ops of a stack executed in the
    iteration order.
    """
    if self._forward_sync is None:
      with ops.control_dependencies(None):
        self._forward_sync = control_trigger(name="f_sync")
      self._forward_sync._set_control_flow_context(self._forward_context)
      self._forward_index.op._add_control_input(self._forward_sync)
    return self._forward_sync

  @property
  def grad_context(self):
    """The corresponding WhileContext for gradient."""
    return self._grad_context

  @property
  def grad_index(self):
    """The loop index of backprop loop."""
    return self._grad_index

  @property
  def grad_sync(self):
    """A control trigger node for synchronization in the grad loop.

    One main use is to keep the pop ops of a stack executed in the
    iteration order.
    """
    if self._grad_sync is None:
      with ops.control_dependencies(None):
        self._grad_sync = control_trigger(name="b_sync")
      self._grad_sync._set_control_flow_context(self._grad_context)
      self._grad_index.op._add_control_input(self._grad_sync)
      if self._grad_context.outer_context:
        self._grad_context.outer_context.AddInnerOp(self._grad_sync)
    return self._grad_sync

  @property
  def history_map(self):
    """The map that records all the tensors needed for backprop."""
    return self._history_map

  @property
  def switch_map(self):
    """The map that records all the Switch ops for the while loop."""
    return self._switch_map

  @property
  def unused_exits(self):
    """The list of "unused" exits."""
    return self._unused_exits

  @property
  def deferred_exits(self):
    """The list of "deferred" exits."""
    return self._deferred_exits

  @property
  def forward_loop_exits(self):
    """The list of exits of the forward loop."""
    return self._forward_loop_exits

  @property
  def pending_exits_count(self):
    """The number of exits we expect to see but haven't."""
    return self._pending_exits_count

  @pending_exits_count.setter
  def pending_exits_count(self, cnt):
    """Set the pending count to cnt."""
    self._pending_exits_count = cnt

  def AddForwardAccumulator(self, value, dead_branch=False):
    """Add an accumulator for each forward tensor that is needed in backprop.

    This is added to the forward loop the first time a tensor in the forward
    loop is used by the backprop gradient computation loop. We create an
    accumulator that accumulates the value of the tensor at each iteration.
    Called in the control flow context where gradients() is called.

    The pseudocode is:
    ```
    acc = stack();
    while (_pivot) {
      acc = stack_push(acc, value);
    }
    ```

    We make sure that the stack push op in one iteration is executed before
    the next iteration. This is achieved by adding a control edge from
    `forward_index.op.inputs[0].op` to the push op, and another control
    edge from the push op to either `forward_index.op` or `forward_sync`.

    Args:
      value: The source tensor in forward that is to be accumulated.
      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:
      The stack that contains the accumulated history of the tensor.

    Raises:
      TypeError: For internal errors involving the value condition context.
      ValueError: If `value` is inside an XLA scope and a valid max size
        for the stack can't be found.
    """
    # curr_ctxt is the context that tf.gradients was called in.
    curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    with ops.control_dependencies(None):
      if curr_ctxt:
        curr_ctxt.Enter()
      with ops.colocate_with(value):
        # We only need to pass maximum_iterations to the stack if
        # we're inside an XLA context.
        if not util.IsInXLAContext(value.op):
          max_size = constant_op.constant(-1, dtypes.int32)
        else:
          max_size = GetMaxSizeFromNestedMaximumIterations(
              value, self.forward_context)
        # pylint: disable=protected-access
        acc = gen_data_flow_ops._stack_v2(
            max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
        # pylint: enable=protected-access
      if curr_ctxt:
        curr_ctxt.Exit()

      # Make acc available in the forward context.
      enter_acc = self.forward_context.AddValue(acc)

      # Add the stack_push op in the context of value.op.
      swap_enabled = self.forward_context.swap_memory
      value_ctxt = util.GetOutputContext(value.op)
      if value_ctxt == self.forward_context:
        # value is not nested in the forward context.
        self.forward_context.Enter()
        # pylint: disable=protected-access
        push = gen_data_flow_ops._stack_push_v2(
            enter_acc, value, swap_memory=swap_enabled)
        # pylint: enable=protected-access
        self.forward_context.Exit()
        # Protect stack push and order it before forward_index.
        self.forward_index.op._add_control_input(push.op)
      else:
        # value is in a cond context within the forward context.
        if not isinstance(value_ctxt, CondContext):
          raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
        if dead_branch:
          # The special case for creating a zero tensor for a dead
          # branch of a switch. See ControlFlowState.ZerosLike().
          value_ctxt.outer_context.Enter()
          # pylint: disable=protected-access
          push = gen_data_flow_ops._stack_push_v2(
              enter_acc, value, swap_memory=swap_enabled)
          # pylint: enable=protected-access
          value_ctxt.outer_context.Exit()
          push.op._set_control_flow_context(value_ctxt)
        else:
          value_ctxt.Enter()
          # pylint: disable=protected-access
          push = gen_data_flow_ops._stack_push_v2(
              enter_acc, value, swap_memory=swap_enabled)
          # pylint: enable=protected-access
          value_ctxt.Exit()
        # Protect stack push and order it before forward_sync.
        self.forward_sync._add_control_input(push.op)
      # Order stack push after the successor of forward_index
      add_op = self.forward_index.op.inputs[0].op
      push.op._add_control_input(add_op)
      return acc

  def AddBackpropAccumulatedValue(self, history_value, value,
                                  dead_branch=False):
    """Add the getter for an accumulated value in the grad context.

    This is added to the backprop loop. Called in the grad context to
    get the value of an accumulated value. The stack pop op must be guarded
    by the pred of the controlling cond.

    Args:
      history_value: The history (a stack) of a value.
      value: The value that is pushed onto the stack.
      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:
      The current value (the top of the stack).
    """
    history_ctxt = history_value.op._get_control_flow_context()
    # Find the cond context that controls history_value if any.
    cond_ctxt = None
    value_ctxt = value.op._get_control_flow_context()
    while value_ctxt and value_ctxt != history_ctxt:
      if isinstance(value_ctxt, CondContext):
        cond_ctxt = value_ctxt
        break
      value_ctxt = value_ctxt.outer_context
    with ops.control_dependencies(None):
      self.grad_context.Enter()
      if cond_ctxt:
        # Guard stack pop with a switch if it is controlled by a cond.
        grad_state = self
        pred = None
        while pred is None and grad_state:
          pred = grad_state.history_map.get(cond_ctxt.pred.name)
          grad_state = grad_state.outer_grad_state
        if pred is None:
          pred = cond_ctxt.pred
        branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
        history_value = _SwitchRefOrTensor(history_value, pred)[branch]
      # pylint: disable=protected-access
      pop = gen_data_flow_ops._stack_pop_v2(history_value,
                                            value.dtype.base_dtype)
      # pylint: enable=protected-access
      pop.set_shape(value.get_shape())
      self.grad_context.Exit()
    parallel_iterations = self.grad_context.parallel_iterations
    if parallel_iterations > 1:
      # All pops are ordered after pivot_for_body and before grad_sync.
      self.grad_sync._add_control_input(pop.op)
    return pop

  def GetRealValue(self, value):
    """Get the real value of `value`.

    If backprop "uses" a value produced by forward inference, an accumulator
    is added in the forward loop to accumulate its values. We use the
    accumulated value. This method must be called in the grad loop context.
    `value` must be in forward and needed for backprop.

    Args:
      value: A tensor to be captured.

    Returns:
      The same tensor obtained from the saved history.
    """
    assert value.op.type not in ["Variable", "VariableV2"]
    real_value = self._history_map.get(value.name)
    if real_value is None:
      cur_value = value
      cur_grad_state = self
      while True:
        enter_op = util.GetLoopConstantEnter(cur_value)
        if enter_op:
          # Special case: cur_value comes from a constant Enter node.
          cur_value = enter_op.inputs[0]
          cur_grad_state = cur_grad_state.outer_grad_state
          if cur_grad_state is None:
            # We are now outside all nested loops for this gradient(),
            # so `value` is a loop invariant and there is no need to
            # save the history of value. Just make cur_value enter
            # the right control flow context.
            real_value = self._grad_context.AddValue(cur_value)
            break
        elif constant_op.is_constant(cur_value):
          # If the value to be forwarded is a constant, clone the constant in
          # the gradient loop rather than using a stack.
          # TODO(phawkins): consider hoisting the constant out of the loop
          # instead.
          real_value = constant_op.constant(
              tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
          break
        else:
          # Record the history of this value in forward_ctxt.
          self._grad_context.Exit()
          history_value = cur_grad_state.AddForwardAccumulator(cur_value)
          self._grad_context.Enter()
          break

      if real_value is None:
        # Add the stack pop op in the grad context.
        real_value = cur_grad_state.AddBackpropAccumulatedValue(
            history_value, cur_value)
        if cur_grad_state != self:
          real_value = self._grad_context.AddValue(real_value)
      self._history_map[value.name] = real_value
    return real_value


def _GetWhileContext(op):
  """Get the WhileContext to which this op belongs."""
  ctxt = op._get_control_flow_context()
  if ctxt:
    ctxt = ctxt.GetWhileContext()
  return ctxt


class ControlFlowState(object):
  """Maintain the mapping from the loops to their grad states."""

  def __init__(self):
    self._map = {}  # maps forward loop context to GradLoopState

  def GetGradState(self, op, before):
    """Return the grad state for this op if it's in a forward loop context."""
    if before and util.IsLoopExit(op):
      forward_ctxt = op._get_control_flow_context()
      forward_ctxt = forward_ctxt.outer_context
      if forward_ctxt:
        forward_ctxt = forward_ctxt.GetWhileContext()
    else:
      forward_ctxt = _GetWhileContext(op)
    if forward_ctxt:
      return self._map.get(forward_ctxt)
    return None

  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
    """Process all the "unused" loop exits.

    The "unused" exits of the loops are added to `unused_exits`. An exit is
    unused if its pending_count is 0. If there is an exit with real gradient,
    all these deferred exits will enter the backprop loop with zero gradient.
    Otherwise, they will enter the backprop loop with None. As an example,
    people often write:

    ```python
    v1, _ = tf.while_loop(p, b, [x1, x2])
    result = gradients(v1, x1)
    ```

    The exit node for x2 is not included by the betweenness analysis. But we
    need to backprop x2 if x2 is involved in computing v1.

    Args:
      pending_count: The number of backprop inputs for every op.
      to_ops_set: The set of ops for ys in gradients(ys, xs)

    Returns:
      The set of unused loop exits that we know at this point we need
      to backprop.
    """
    loop_exits = []
    for _, grad_state in self._map.items():
      # pylint: disable=protected-access
      for y in grad_state.forward_loop_exits:
        if pending_count[y.op._id] == 0:
          grad_state.pending_exits_count -= 1
          if y.op._id not in to_ops_set:
            grad_state.unused_exits.append(y)
          if grad_state.pending_exits_count == 0:
            loop_exits.extend(grad_state.unused_exits)
      # Need to include Enters in backprop for higher-order gradients.
      for y in grad_state.forward_context.loop_enters:
        if pending_count[y.op._id] == 0:
          pending_count[y.op._id] = 1
      # pylint: enable=protected-access
    return loop_exits

  def EnterGradWhileContext(self, op, before):
    """Enter the WhileContext for gradient computation."""
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Enter()

  def ExitGradWhileContext(self, op, before):
    """Exit the WhileContext for gradient computation."""
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Exit()

  def AddWhileContext(self, op, between_op_list, between_ops):
    """Add the grad state for the while loop that op belongs to.

    Note that op is an Exit, and this method must be called in
    the control flow context where gradients() is called.

    Note that this method modifies `between_op_list` and `between_ops`.
    """
    forward_ctxt = _GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # This is a new while loop so create a grad state for it.
      outer_forward_ctxt = forward_ctxt.outer_context
      if outer_forward_ctxt:
        outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
      outer_grad_state = None
      if outer_forward_ctxt:
        outer_grad_state = self._map.get(outer_forward_ctxt)
      grad_state = GradLoopState(forward_ctxt, outer_grad_state)
      self._map[forward_ctxt] = grad_state

      # We need to include all exits of a loop for backprop.
      for loop_exit in grad_state.forward_loop_exits:
        if not between_ops[loop_exit.op._id]:
          between_ops[loop_exit.op._id] = True
          between_op_list.append(loop_exit.op)

  def ZerosLikeForExit(self, val):
    """Create zeros_like gradient for a loop exit.

    If the result of a loop variable is not used but is involved in
    computing the result of some needed loop variable, we create a
    zero-valued tensor that is fed as gradient for the Exit node of that
    loop variable. Note that val.op is an Exit, and this method must be
    called in the control flow context where gradients() is called.

    Args:
      val: The output tensor of an Exit op.

    Returns:
      A zero tensor of the same shape as val.
    """
    val_shape = val.get_shape()
    forward_ctxt = val.op._get_control_flow_context()
    outer_forward_ctxt = forward_ctxt.outer_context
    if outer_forward_ctxt:
      outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
    outer_grad_state = None
    if outer_forward_ctxt:
      outer_grad_state = self._map.get(outer_forward_ctxt)
    if outer_grad_state:
      # This is a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape in the right context.
        outer_grad_state.grad_context.Enter()
        result = array_ops.zeros(val_shape.dims, val.dtype)
        outer_grad_state.grad_context.Exit()
      else:
        # Only the shape of value is needed for backprop.
        forward_ctxt.outer_context.Enter()
        shape = array_ops.shape_internal(val, optimize=False)
        forward_ctxt.outer_context.Exit()
        # Save the shape to a stack.
        history_shape = outer_grad_state.AddForwardAccumulator(shape)
        # Get the shape back from the stack.
        outer_grad_ctxt = outer_grad_state.grad_context
        outer_grad_ctxt.Enter()
        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
            history_shape, shape)
        result = array_ops.zeros(real_shape, val.dtype)
        outer_grad_ctxt.Exit()
    else:
      # This is not a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape.
        result = array_ops.zeros(val_shape.dims, val.dtype)
      else:
        result = array_ops.zeros_like(val, optimize=False)
    return result

  def ZerosLike(self, op, index):
    """Create zeros_like for the specified output of an op.

    If op is in a while loop that is part of gradients(), this method
    must be called in its grad loop context.

    Args:
      op: A tensorflow operation.
      index: the index for a specific output of the op.

    Returns:
      A zero tensor of the same shape as op.outputs[index].
    """
    if util.IsLoopSwitch(op):
      return None
    dead_branch = util.IsSwitch(op)
    forward_ctxt = _GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # op is not in a while loop that is part of gradients().
      return ZerosLikeOutsideLoop(op, index)
    op_ctxt = op._get_control_flow_context()
    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
    shape = val.get_shape()
    if shape.is_fully_defined():
      # If the shape is known statically, just create a zero tensor with
      # the right shape in the grad loop context.
      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
      if dead_branch:
        # op is a cond switch. Guard the zero tensor with a switch.
        pred = grad_state.history_map.get(op_ctxt.pred.name)
        branch = op_ctxt.branch
        result = _SwitchRefOrTensor(result, pred)[1 - branch]
    else:
      # Unknown shape so keep a history of the shape at runtime.
      if dead_branch:
        # Need to add a special switch to guard the value.
        pred = op_ctxt.pred
        branch = op_ctxt.branch
        op_ctxt.outer_context.Enter()
        val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.outer_context.Exit()
        val.op._set_control_flow_context(op_ctxt)
        zeros_shape.op._set_control_flow_context(op_ctxt)
      else:
        op_ctxt.Enter()
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.Exit()

      # Add forward accumulator for shape.
      grad_state.grad_context.Exit()
      history_zeros_shape = grad_state.AddForwardAccumulator(
          zeros_shape, dead_branch=dead_branch)
      grad_state.grad_context.Enter()

      # Create a zero tensor with the right shape.
      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
                                                     zeros_shape, dead_branch)
      result = array_ops.zeros(shape, val.dtype)
    return result

  def PostProcessing(self):
    """Perform postprocessing at the end of gradients().

    We have created the gradient graph at this point. So this function
    can be used to perform any postprocessing on the gradient graph.
    We currently perform the following postprocessing:
    1. Patch the gradient graph if the output of a loop variable
       doesn't depend on its input.
    """
    for _, grad_state in self._map.items():
      for _, b_merge in grad_state.switch_map.items():
        if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
          # The value of this loop variable at iteration i+1 doesn't
          # depend on its value at iteration i. So use zeros as the
          # gradients for all iterations > 0.
          dtype = b_merge.op.inputs[0].dtype
          shape = b_merge.op.inputs[0].get_shape()
          # pylint: disable=protected-access
          if shape.is_fully_defined():
            grad_state.grad_context.Enter()
            # Create a zeros and use it for iterations > 0.
            grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
            next_grad_val = _NextIteration(grad_val)
            grad_state.grad_context.Exit()
          else:
            # Create a zeros in the outer grad context.
            outer_grad_ctxt = grad_state.grad_context.outer_context
            if outer_grad_ctxt:
              outer_grad_ctxt.Enter()
            enter_grad_op = b_merge.op.inputs[0].op
            enter_grad = enter_grad_op.inputs[0]
            grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
            grad_val = array_ops.zeros(grad_shape)
            if outer_grad_ctxt:
              outer_grad_ctxt.Exit()
            # Use the zeros for iterations > 0.
            grad_state.grad_context.Enter()
            next_grad_val = _NextIteration(grad_val)
            grad_state.grad_context.Exit()
          b_merge.op._update_input(1, next_grad_val)
          # pylint: enable=protected-access


def MaybeCreateControlFlowState(between_op_list, between_ops,
                                colocate_gradients_with_ops):
  """Create the state for all the while loops involved in one gradients().

  We create a ControlFlowState when there are while loops involved in
  gradients(). In gradients(), control flow logic is only invoked when
  the ControlFlowState is not None.

  Note that this method modifies `between_op_list` and `between_ops`.
  """
  loop_state = None
  for op in between_op_list:
    if util.IsLoopExit(op):
      if loop_state is None:
        loop_state = ControlFlowState()
      if colocate_gradients_with_ops:
        with ops.colocate_with(op):
          loop_state.AddWhileContext(op, between_op_list, between_ops)
      else:
        loop_state.AddWhileContext(op, between_op_list, between_ops)
  return loop_state


def ZerosLikeOutsideLoop(op, index):
  """Create zeros_like for the specified output of an op."""
  val = op.outputs[index]
  if not util.IsSwitch(op):
    return array_ops.zeros_like(val, optimize=False)
  else:
    op_ctxt = op._get_control_flow_context()
    if op_ctxt:
      # We are in a cond context. Use a switch to create zeros only when
      # needed.
      pred = op_ctxt.pred
      branch = op_ctxt.branch
      switch_val = switch(op.inputs[0], pred)[1 - branch]
      zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
      return array_ops.zeros(zeros_shape, dtype=val.dtype)
    else:
      return array_ops.zeros_like(val, optimize=False)


class ControlFlowContext(object):
  """The base class for control flow context.

  The usage pattern is a sequence of (Enter, Exit) followed by a final
  ExitResult.

  We maintain the following state for control flow contexts during graph
  construction:
  1. graph has _control_flow_context: the current context used to
     construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
  2. op has _control_flow_context: the context to which the op belongs.
     Set at the time the op is created. Immutable.
  3. A ControlFlowContext has _outer_context: the context in which this
     context is created. Set at the time a context is created. Immutable.
  4. A ControlFlowContext has _context_stack.
     Pushed and popped by ctxt.Enter() and ctxt.Exit()
  """

  def __init__(self, values_def=None, import_scope=None):
    self._nested_contexts = []
    self._outer_context = ops.get_default_graph()._get_control_flow_context()
    if self._outer_context:
      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
    self._context_stack = []
    if values_def:
      self._init_values_from_proto(values_def, import_scope=import_scope)
    else:
      # Values that have been already seen in this context.
      self._values = set()
      # Values referenced by but external to this context.
1513 self._external_values = {} 1514 1515 def _init_values_from_proto(self, values_def, import_scope=None): 1516 """Initializes values and external_values from `ValuesDef` protocol buffer. 1517 1518 Args: 1519 values_def: `ValuesDef` protocol buffer. 1520 import_scope: Optional `string`. Name scope to add. 1521 """ 1522 assert isinstance(values_def, control_flow_pb2.ValuesDef) 1523 self._values = set( 1524 ops.prepend_name_scope(value, import_scope) 1525 for value in values_def.values) 1526 g = ops.get_default_graph() 1527 self._external_values = {} 1528 for k, v in values_def.external_values.items(): 1529 k = ops.prepend_name_scope(k, import_scope) 1530 self._external_values[k] = g.as_graph_element( 1531 ops.prepend_name_scope(v, import_scope)) 1532 op_names = set([ 1533 op.split(":")[0] 1534 for op in self._values - set(self._external_values.keys()) 1535 ]) 1536 for op in op_names: 1537 # pylint: disable=protected-access 1538 g.as_graph_element(op)._set_control_flow_context(self) 1539 # pylint: enable=protected-access 1540 1541 @property 1542 def name(self): 1543 return self._name 1544 1545 @property 1546 def outer_context(self): 1547 """Return the context containing this context.""" 1548 return self._outer_context 1549 1550 @property 1551 def grad_state(self): 1552 raise NotImplementedError("Abstract method") 1553 1554 @property 1555 def back_prop(self): 1556 raise NotImplementedError("Abstract method") 1557 1558 @abc.abstractmethod 1559 def to_control_flow_context_def(self, context_def, export_scope=None): 1560 """Serializes this into `context_def`. 1561 1562 Args: 1563 context_def: a `ControlFlowContextDef` protocol buffer. 1564 export_scope: Optional `string`. Name scope to remove. 1565 """ 1566 raise NotImplementedError("Abstract method") 1567 1568 def _to_values_def(self, export_scope=None): 1569 """Converts the values to a `ValuesDef` protocol buffer. 1570 1571 Args: 1572 export_scope: Optional `string`. Name scope to remove. 1573 1574 Returns: 1575 A `ValuesDef` protocol buffer. 
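For example, exporting with `export_scope="outer"` strips that prefix from
every recorded name, while importing with `import_scope="outer"` puts it
back. A minimal sketch of that round trip (the tensor name is illustrative
only):

```python
from tensorflow.python.framework import ops

name = "outer/cond/Switch:1"
stripped = ops.strip_name_scope(name, "outer")        # "cond/Switch:1"
restored = ops.prepend_name_scope(stripped, "outer")  # "outer/cond/Switch:1"
assert restored == name
```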
1576 """ 1577 values_def = control_flow_pb2.ValuesDef() 1578 values_def.values.extend( 1579 [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)]) 1580 for k, v in self._external_values.items(): 1581 k = ops.strip_name_scope(k, export_scope) 1582 values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope) 1583 return values_def 1584 1585 def AddName(self, name): 1586 self._values.add(name) 1587 1588 # pylint: disable=protected-access 1589 def Enter(self): 1590 """Enter this control flow context.""" 1591 graph = ops.get_default_graph() 1592 self._context_stack.append(graph._get_control_flow_context()) 1593 graph._set_control_flow_context(self) 1594 1595 def Exit(self): 1596 """Exit this control flow context.""" 1597 graph = ops.get_default_graph() 1598 last_context = self._context_stack.pop() 1599 graph._set_control_flow_context(last_context) 1600 1601 def ExitResult(self, result): 1602 """Make a list of tensors available in the outer context.""" 1603 if self._outer_context: 1604 nest.map_structure(lambda x: self._outer_context.AddName(x.name), result) 1605 1606 def GetWhileContext(self): 1607 """Return the while context containing this context.""" 1608 if self._outer_context: 1609 return self._outer_context.GetWhileContext() 1610 return None 1611 1612 def _IsInOuterContext(self, op): 1613 op_ctxt = util.GetOutputContext(op) 1614 outer_ctxt = self.outer_context 1615 while outer_ctxt != op_ctxt: 1616 if outer_ctxt is None: 1617 return False 1618 outer_ctxt = outer_ctxt.outer_context 1619 return True 1620 1621 def _RemoveExternalControlEdges(self, op): 1622 """Remove any external control dependency on this op.""" 1623 while_ctxt = self.GetWhileContext() 1624 # A control input of `op` is internal if it is in the same while 1625 # loop context as the enclosing while loop context of self. 1626 if while_ctxt is None: 1627 internal_control_inputs = op.control_inputs 1628 else: 1629 internal_control_inputs = [] 1630 for x in op.control_inputs: 1631 ctxt = util.GetOutputContext(x) 1632 if ctxt is not None and ctxt.GetWhileContext() == while_ctxt: 1633 internal_control_inputs.append(x) 1634 external_control_inputs = [] 1635 if len(internal_control_inputs) != len(op.control_inputs): 1636 external_control_inputs = list(set(op.control_inputs) 1637 - set(internal_control_inputs)) 1638 op._remove_all_control_inputs() 1639 op._add_control_inputs(internal_control_inputs) 1640 return internal_control_inputs, external_control_inputs 1641 1642 # pylint: enable=protected-access 1643 1644 def AddInnerOp(self, op): 1645 """Notifies a scope about an operator added to an inner scope.""" 1646 if self._outer_context: 1647 self._outer_context.AddInnerOp(op) 1648 1649 def GetControlPivot(self): 1650 """Returns the pivot node for this context, or None.""" 1651 return None 1652 1653 def IsWhileContext(self): 1654 return False 1655 1656 def IsCondContext(self): 1657 return False 1658 1659 def IsXLAContext(self): 1660 return False 1661 1662 def __str__(self): 1663 return self.name 1664 1665 1666 class CondContext(ControlFlowContext): 1667 """The context for the conditional construct.""" 1668 1669 def __init__(self, 1670 pred=None, 1671 pivot=None, 1672 branch=None, 1673 name="cond_text", 1674 context_def=None, 1675 import_scope=None): 1676 """Creates a `CondContext`. 1677 1678 Args: 1679 pred: The `boolean` tensor for the conditional predicate. 1680 pivot: The predicate tensor in this branch. 1681 branch: 0 or 1 representing this branch. 1682 name: Name of the `CondContext` python object. 
1683 context_def: Optional `ContextDef` protocol buffer to initialize the
1684 `CondContext` object from.
1685 import_scope: Optional `string`. Name scope to add. Only used when
1686 initializing from protocol buffer.
1687 """
1688 self._name = ops.get_default_graph().unique_name(name)
1689
1690 if context_def:
1691 self._init_from_proto(context_def, import_scope=import_scope)
1692 else:
1693 # Initializes the default fields.
1694 ControlFlowContext.__init__(self)
1695 self._pred = pred # The boolean tensor for the cond predicate
1696 self._pivot = pivot # The predicate tensor in this branch
1697 self._branch = branch # 0 or 1 representing this branch
1698
1699 # Values considered to have been already seen in this context.
1700 self._values.add(pred.name)
1701 self._values.add(pivot.name)
1702
1703 def _init_from_proto(self, context_def, import_scope=None):
1704 """Creates a new `CondContext` from protocol buffer.
1705
1706 Args:
1707 context_def: `CondContextDef` protocol buffer.
1708 import_scope: Optional `string`. Name scope to add.
1709 """
1710 assert isinstance(context_def, control_flow_pb2.CondContextDef)
1711 # Create from context_def.
1712 g = ops.get_default_graph()
1713 self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
1714 self._pred = g.as_graph_element(
1715 ops.prepend_name_scope(context_def.pred_name, import_scope))
1716 self._pivot = g.as_graph_element(
1717 ops.prepend_name_scope(context_def.pivot_name, import_scope))
1718 self._branch = context_def.branch
1719 super(CondContext, self).__init__(
1720 values_def=context_def.values_def, import_scope=import_scope)
1721
1722 @property
1723 def pred(self):
1724 return self._pred
1725
1726 @property
1727 def pivot(self):
1728 return self._pivot
1729
1730 @property
1731 def branch(self):
1732 return self._branch
1733
1734 @property
1735 def grad_state(self):
1736 if self.GetWhileContext():
1737 return self.GetWhileContext().grad_state
1738 return None
1739
1740 @property
1741 def back_prop(self):
1742 if self.GetWhileContext():
1743 return self.GetWhileContext().back_prop
1744 return False
1745
1746 def GetControlPivot(self):
1747 return self._pivot
1748
1749 def to_proto(self, export_scope=None):
1750 """Converts a `CondContext` to a `CondContextDef` protocol buffer.
1751
1752 Args:
1753 export_scope: Optional `string`. Name scope to remove.
1754
1755 Returns:
1756 A `CondContextDef` protocol buffer.
1757 """
1758 if (export_scope is None or self.name.startswith(export_scope)):
1759 context_def = control_flow_pb2.CondContextDef()
1760 context_def.context_name = ops.strip_name_scope(self.name, export_scope)
1761 context_def.pred_name = ops.strip_name_scope(self._pred.name,
1762 export_scope)
1763 context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
1764 export_scope)
1765 context_def.branch = self._branch
1766 context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def(
1767 export_scope))
1768 # TODO(b/72868227): enable this once the corresponding control_flow.proto
1769 # changes have been checked in (they aren't checked in and this is
1770 # disabled for now to ensure forwards compatibility).
1771 if False: # pylint: disable=using-constant-test 1772 for nested in self._nested_contexts: 1773 nested_def = context_def.nested_contexts.add() 1774 nested.to_control_flow_context_def(nested_def) 1775 1776 return context_def 1777 else: 1778 return None 1779 1780 @staticmethod 1781 def from_proto(context_def, import_scope=None): 1782 """Returns a `CondContext` object created from `context_def`.""" 1783 ret = CondContext(context_def=context_def, 1784 import_scope=import_scope) 1785 1786 # TODO(b/72868227): remove "if hasattr(...)" once the corresponding 1787 # control_flow.proto changes have been checked in (they aren't checked in 1788 # and this is here for now to ensure forwards compatibility). 1789 if hasattr(context_def, "nested_contexts"): 1790 ret.Enter() 1791 for nested_def in context_def.nested_contexts: 1792 from_control_flow_context_def(nested_def) 1793 ret.Exit() 1794 return ret 1795 1796 def to_control_flow_context_def(self, context_def, export_scope=None): 1797 context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 1798 1799 def AddValue(self, val): 1800 """Add `val` to the current context and its outer context recursively.""" 1801 if val.name in self._values: 1802 # Use the real value if it comes from outer context. This is needed in 1803 # particular for nested conds. 1804 result = self._external_values.get(val.name) 1805 result = val if result is None else result 1806 else: 1807 result = val 1808 self._values.add(val.name) 1809 if self._outer_context: 1810 result = self._outer_context.AddValue(val) 1811 self._values.add(result.name) 1812 with ops.control_dependencies(None): 1813 result = _SwitchRefOrTensor(result, self._pred)[self._branch] 1814 if self._outer_context: 1815 self._outer_context.AddInnerOp(result.op) 1816 1817 result.op.graph.prevent_fetching(result.op) 1818 # pylint: disable=protected-access 1819 result.op._set_control_flow_context(self) 1820 # pylint: enable=protected-access 1821 1822 self._values.add(result.name) 1823 self._external_values[val.name] = result 1824 return result 1825 1826 def AddOp(self, op): 1827 self._AddOpInternal(op) 1828 1829 def _AddOpInternal(self, op): 1830 """Add `op` to the current context.""" 1831 if not op.inputs: 1832 # Remove any external control dependency on this op 1833 self._RemoveExternalControlEdges(op) 1834 # pylint: disable=protected-access 1835 op._add_control_input(self._pivot.op) 1836 # pylint: enable=protected-access 1837 for x in op.outputs: 1838 self._values.add(x.name) 1839 else: 1840 for index in range(len(op.inputs)): 1841 x = op.inputs[index] 1842 real_x = self.AddValue(x) 1843 if real_x != x: 1844 # pylint: disable=protected-access 1845 op._update_input(index, real_x) 1846 # pylint: enable=protected-access 1847 # Remove any external control dependency on this op. 
1848 self._RemoveExternalControlEdges(op) 1849 for x in op.outputs: 1850 self._values.add(x.name) 1851 # pylint: disable=protected-access 1852 if op.graph._is_function(op.type) or op.type == "SymbolicGradient": 1853 op._add_control_input(self._pivot.op) 1854 # pylint: enable=protected-access 1855 1856 if self._outer_context or not util.IsLoopExit(op): 1857 op.graph.prevent_fetching(op) 1858 1859 if self._outer_context: 1860 self._outer_context.AddInnerOp(op) 1861 1862 def _ProcessOutputTensor(self, val): 1863 """Process an output tensor of a conditional branch.""" 1864 real_val = val 1865 if val.name not in self._values: 1866 # Handle the special case of lambda: x 1867 self._values.add(val.name) 1868 if self._outer_context: 1869 real_val = self._outer_context.AddValue(val) 1870 self._values.add(real_val.name) 1871 real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch] 1872 self._external_values[val.name] = real_val 1873 else: 1874 external_val = self._external_values.get(val.name) 1875 if external_val is not None: 1876 real_val = external_val 1877 return real_val 1878 1879 def _BuildCondTensor(self, v): 1880 if isinstance(v, ops.Operation): 1881 # Use pivot as the proxy for this op. 1882 return with_dependencies([v], self._pivot) 1883 elif isinstance(v, (ops.IndexedSlices, sparse_tensor.SparseTensor)): 1884 values = self._ProcessOutputTensor(v.values) 1885 indices = self._ProcessOutputTensor(v.indices) 1886 if isinstance(v, ops.IndexedSlices): 1887 dense_shape = v.dense_shape 1888 if dense_shape is not None: 1889 dense_shape = self._ProcessOutputTensor(dense_shape) 1890 return ops.IndexedSlices(values, indices, dense_shape) 1891 else: 1892 dense_shape = self._ProcessOutputTensor(v.dense_shape) 1893 return sparse_tensor.SparseTensor(indices, values, dense_shape) 1894 else: 1895 v = nest.map_structure(_convert_tensorarray_to_flow, v) 1896 return self._ProcessOutputTensor(ops.convert_to_tensor(v)) 1897 1898 def BuildCondBranch(self, fn): 1899 """Add the subgraph defined by fn() to the graph.""" 1900 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1901 original_result = fn() 1902 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1903 if len(post_summaries) > len(pre_summaries): 1904 new_summaries = post_summaries[len(pre_summaries):] 1905 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1906 summary_ref[:] = pre_summaries 1907 with ops.control_dependencies(new_summaries): 1908 if original_result is None: 1909 return no_op(), None 1910 else: 1911 original_result = nest.map_structure(array_ops.identity, 1912 original_result) 1913 if original_result is None: 1914 return None, None 1915 1916 result = nest.map_structure(self._BuildCondTensor, original_result) 1917 if not isinstance(result, (list, _basetuple)): 1918 result = [result] 1919 return original_result, result 1920 1921 def IsCondContext(self): 1922 return True 1923 1924 1925 def _UnpackIfSingleton(res): 1926 if isinstance(res, (list, _basetuple)) and len(res) == 1: 1927 return res[0] 1928 else: 1929 return res 1930 1931 1932 # pylint: disable=redefined-outer-name 1933 # pylint: disable=g-doc-args 1934 @tf_export("cond") 1935 @deprecation.deprecated_args( 1936 None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", 1937 "fn1", "fn2") 1938 def cond(pred, 1939 true_fn=None, 1940 false_fn=None, 1941 strict=False, 1942 name=None, 1943 fn1=None, 
1944 fn2=None): 1945 """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. 1946 1947 `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and 1948 `false_fn` must have the same non-zero number and type of outputs. 1949 1950 Note that the conditional execution applies only to the operations defined in 1951 `true_fn` and `false_fn`. Consider the following simple program: 1952 1953 ```python 1954 z = tf.multiply(a, b) 1955 result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y)) 1956 ``` 1957 1958 If `x < y`, the `tf.add` operation will be executed and `tf.square` 1959 operation will not be executed. Since `z` is needed for at least one 1960 branch of the `cond`, the `tf.multiply` operation is always executed, 1961 unconditionally. 1962 Although this behavior is consistent with the dataflow model of TensorFlow, 1963 it has occasionally surprised some users who expected a lazier semantics. 1964 1965 Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the 1966 call to `cond`, and not at all during `Session.run()`). `cond` 1967 stitches together the graph fragments created during the `true_fn` and 1968 `false_fn` calls with some additional graph nodes to ensure that the right 1969 branch gets executed depending on the value of `pred`. 1970 1971 `tf.cond` supports nested structures as implemented in 1972 `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the 1973 same (possibly nested) value structure of lists, tuples, and/or named tuples. 1974 Singleton lists and tuples form the only exceptions to this: when returned by 1975 `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. 1976 This behavior is disabled by passing `strict=True`. 1977 1978 Args: 1979 pred: A scalar determining whether to return the result of `true_fn` or 1980 `false_fn`. 1981 true_fn: The callable to be performed if pred is true. 1982 false_fn: The callable to be performed if pred is false. 1983 strict: A boolean that enables/disables 'strict' mode; see above. 1984 name: Optional name prefix for the returned tensors. 1985 1986 Returns: 1987 Tensors returned by the call to either `true_fn` or `false_fn`. If the 1988 callables return a singleton list, the element is extracted from the list. 1989 1990 Raises: 1991 TypeError: if `true_fn` or `false_fn` is not callable. 1992 ValueError: if `true_fn` and `false_fn` do not return the same number of 1993 tensors, or return tensors of different types. 1994 1995 Example: 1996 1997 ```python 1998 x = tf.constant(2) 1999 y = tf.constant(5) 2000 def f1(): return tf.multiply(x, 17) 2001 def f2(): return tf.add(y, 23) 2002 r = tf.cond(tf.less(x, y), f1, f2) 2003 # r is set to f1(). 2004 # Operations in f2 (e.g., tf.add) are not executed. 2005 ``` 2006 2007 """ 2008 # We needed to make true_fn/false_fn keyword arguments for 2009 # backwards-compatibility. This check exists so that we can convert back to 2010 # having them be positional arguments. 2011 # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after 2012 # `fn1` and `fn2` are deleted. 
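  # Note: to make the multiply in the docstring example above conditional as
  # well, it can simply be created inside the branch callable, e.g. (a sketch
  # reusing the same placeholder names a, b, x, y):
  #
  #   result = tf.cond(x < y,
  #                    lambda: tf.add(x, tf.multiply(a, b)),
  #                    lambda: tf.square(y))
  #
  # Ops created inside `true_fn`/`false_fn` are placed in the cond context and
  # only run when their branch is taken.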
2013 if fn1 is not None: 2014 if true_fn is not None: 2015 raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.") 2016 true_fn = fn1 2017 elif true_fn is None: 2018 raise TypeError("cond(): true_fn argument required") 2019 if fn2 is not None: 2020 if false_fn is not None: 2021 raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.") 2022 false_fn = fn2 2023 elif false_fn is None: 2024 raise TypeError("cond(): false_fn argument required") 2025 2026 if not callable(true_fn): 2027 raise TypeError("true_fn must be callable.") 2028 if not callable(false_fn): 2029 raise TypeError("false_fn must be callable.") 2030 2031 with ops.name_scope(name, "cond", [pred]): 2032 if context.in_eager_mode(): 2033 if pred: 2034 return _UnpackIfSingleton(true_fn()) 2035 return _UnpackIfSingleton(false_fn()) 2036 2037 # Add the Switch to the graph. 2038 if isinstance(pred, bool): 2039 raise TypeError("pred must not be a Python bool") 2040 p_2, p_1 = switch(pred, pred) 2041 pivot_1 = array_ops.identity(p_1, name="switch_t") 2042 pivot_2 = array_ops.identity(p_2, name="switch_f") 2043 pred = array_ops.identity(pred, name="pred_id") 2044 # Disable the fetching of tensors that are only on one branch of cond. 2045 for tensor in [p_1, p_2, pivot_1, pivot_2, pred]: 2046 tensor.op.graph.prevent_fetching(tensor.op) 2047 2048 # Build the graph for the true branch in a new context. 2049 context_t = CondContext(pred, pivot_1, branch=1) 2050 context_t.Enter() 2051 orig_res_t, res_t = context_t.BuildCondBranch(true_fn) 2052 if orig_res_t is None: 2053 raise ValueError("true_fn must have a return value.") 2054 context_t.ExitResult(res_t) 2055 context_t.Exit() 2056 2057 # Build the graph for the false branch in a new context. 2058 context_f = CondContext(pred, pivot_2, branch=0) 2059 context_f.Enter() 2060 orig_res_f, res_f = context_f.BuildCondBranch(false_fn) 2061 if orig_res_f is None: 2062 raise ValueError("false_fn must have a return value.") 2063 context_f.ExitResult(res_f) 2064 context_f.Exit() 2065 2066 if not strict: 2067 orig_res_t = _UnpackIfSingleton(orig_res_t) 2068 orig_res_f = _UnpackIfSingleton(orig_res_f) 2069 2070 # Check that the return values of the two branches have the same structure. 2071 try: 2072 nest.assert_same_structure(orig_res_t, orig_res_f) 2073 except TypeError as e: 2074 raise TypeError( 2075 "Incompatible return types of true_fn and false_fn: {}".format(e)) 2076 except ValueError as e: 2077 raise ValueError( 2078 "Incompatible return values of true_fn and false_fn: {}".format(e)) 2079 2080 # Add the final merge to the graph. 
2081 if not res_t: 2082 raise ValueError("true_fn and false_fn must return at least one result.") 2083 2084 res_t_flat = nest.flatten(res_t) 2085 res_f_flat = nest.flatten(res_f) 2086 2087 for x, y in zip(res_t_flat, res_f_flat): 2088 assert ((isinstance(x, ops.IndexedSlices) and 2089 isinstance(y, ops.IndexedSlices)) or 2090 (isinstance(x, sparse_tensor.SparseTensor) and 2091 isinstance(y, sparse_tensor.SparseTensor)) or 2092 (isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor))) 2093 val_x = x if isinstance(x, ops.Tensor) else x.values 2094 val_y = y if isinstance(y, ops.Tensor) else y.values 2095 if val_x.dtype.base_dtype != val_y.dtype.base_dtype: 2096 raise ValueError( 2097 "Outputs of true_fn and false_fn must have the same type: %s, %s" % 2098 (val_x.dtype.name, val_y.dtype.name)) 2099 2100 merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)] 2101 merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges) 2102 2103 # Only add non-nested conds to the collection. Any nested control flow will 2104 # be encapsulated in the root context. 2105 assert context_t.outer_context == context_f.outer_context 2106 # TODO(b/72868227): remove "if True..." once the corresponding 2107 # control_flow.proto changes have been checked in (they aren't checked in 2108 # and this is disabled for now to ensure forwards compatibility). 2109 if True or context_t.outer_context is None: 2110 ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) 2111 ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) 2112 2113 merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges) 2114 2115 # Singleton lists and tuples are automatically unpacked if strict == False. 2116 if not strict: 2117 merges = _UnpackIfSingleton(merges) 2118 return merges 2119 2120 2121 # pylint: enable=g-doc-args 2122 # pylint: enable=redefined-outer-name 2123 2124 2125 def smart_cond(pred, true_fn=None, false_fn=None, name=None): 2126 """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. 2127 2128 If `pred` is a bool or has a constant value, we return either `true_fn()` 2129 or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. 2130 2131 Arguments: 2132 pred: A scalar determining whether to return the result of `true_fn` or 2133 `false_fn`. 2134 true_fn: The callable to be performed if pred is true. 2135 false_fn: The callable to be performed if pred is false. 2136 name: Optional name prefix when using `tf.cond`. 2137 2138 Returns: 2139 Tensors returned by the call to either `true_fn` or `false_fn`. 2140 2141 Raises: 2142 TypeError: If `true_fn` or `false_fn` is not callable. 2143 """ 2144 if not callable(true_fn): 2145 raise TypeError('`true_fn` must be callable.') 2146 if not callable(false_fn): 2147 raise TypeError('`false_fn` must be callable.') 2148 2149 pred_value = smart_constant_value(pred) 2150 if pred_value is not None: 2151 if pred_value: 2152 return true_fn() 2153 else: 2154 return false_fn() 2155 else: 2156 return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name) 2157 2158 2159 def smart_constant_value(pred): 2160 """Return the bool value for `pred`, or None if `pred` had a dynamic value. 2161 2162 Arguments: 2163 pred: A scalar, either a Python bool or tensor. 2164 2165 Returns: 2166 True or False if `pred` has a constant boolean value, None otherwise. 2167 2168 Raises: 2169 TypeError: If `pred` is not a Tensor or bool. 
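For example (a small sketch; `tf.placeholder` stands in for any value that
is only known at run time):

```python
smart_constant_value(True)                     # True
smart_constant_value(tf.constant(False))       # False
smart_constant_value(tf.placeholder(tf.bool))  # None
```

`smart_cond` relies on this: when the result is not None it builds only the
taken branch, adding none of the Switch/Merge nodes that `tf.cond` would
create.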
2170 """ 2171 if isinstance(pred, bool): 2172 pred_value = pred 2173 elif isinstance(pred, ops.Tensor): 2174 pred_value = tensor_util.constant_value(pred) 2175 else: 2176 raise TypeError('`pred` must be a Tensor or a Python bool.') 2177 return pred_value 2178 2179 2180 def _resource_safe_shape(t): 2181 """Returns the shape of t or the variable it points to.""" 2182 if t.dtype == dtypes.resource: 2183 while t.op.inputs: 2184 t = t.op.inputs[0] 2185 return tensor_shape.TensorShape(t.op.get_attr("shape")) 2186 return array_ops.shape_internal(t, optimize=False) 2187 2188 2189 # TODO(yuanbyu): Consider having a unified notion of context for 2190 # not only conditionals and loops but also control dependency and 2191 # subgraphs. 2192 class WhileContext(ControlFlowContext): 2193 """The context for the loop construct.""" 2194 2195 def __init__(self, 2196 maximum_iterations=None, 2197 parallel_iterations=10, 2198 back_prop=True, 2199 swap_memory=False, 2200 name="while_context", 2201 grad_state=None, 2202 context_def=None, 2203 import_scope=None): 2204 """"Creates a `WhileContext`. 2205 2206 Args: 2207 maximum_iterations: Optional upper bound on number of loop iterations. 2208 parallel_iterations: The number of iterations allowed to run in parallel. 2209 back_prop: Whether backprop is enabled for this while loop. 2210 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 2211 name: Optional name prefix for the returned tensors. 2212 grad_state: The gradient loop state. 2213 context_def: Optional `WhileContextDef` protocol buffer to initialize 2214 the `Whilecontext` python object from. 2215 import_scope: Optional `string`. Name scope to add. Only used when 2216 initialing from protocol buffer. 2217 """ 2218 if context_def: 2219 self._init_from_proto(context_def, import_scope=import_scope) 2220 else: 2221 ControlFlowContext.__init__(self) 2222 self._init_from_args(maximum_iterations, parallel_iterations, back_prop, 2223 swap_memory, name) 2224 # The gradient loop state. 2225 self._grad_state = grad_state 2226 2227 def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop, 2228 swap_memory, name): 2229 """Creates a new `WhileContext` from arguments. 2230 2231 Args: 2232 maximum_iterations: Optional upper bound on number of loop iterations. 2233 parallel_iterations: The number of iterations allowed to run in parallel. 2234 back_prop: Whether backprop is enabled for this while loop. 2235 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 2236 name: Optional name prefix for the returned tensors. 2237 2238 Raises: 2239 ValueError: If `parallel_iterations` has invalid value. 2240 """ 2241 if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0): 2242 raise ValueError("`parallel_iterations` must be a positive integer: " 2243 "%s" % parallel_iterations) 2244 self._name = ops.get_default_graph().unique_name(name) 2245 self._maximum_iterations = maximum_iterations 2246 self._parallel_iterations = parallel_iterations 2247 self._back_prop = back_prop 2248 self._swap_memory = swap_memory 2249 # We use this node to control constants created by the pred lambda. 2250 self._pivot_for_pred = None 2251 # We use this node to control constants created by the body lambda. 2252 self._pivot_for_body = None 2253 # The boolean tensor for loop termination condition. Used in code 2254 # generation for gradient computation 2255 self._pivot = None 2256 # The list of exit tensors for loop variables. 
2257 self._loop_exits = [] 2258 # The list of enter tensors for loop variables. 2259 self._loop_enters = [] 2260 2261 def _init_from_proto(self, context_def, import_scope=None): 2262 """Creates a new `WhileContext` from protocol buffer. 2263 2264 Args: 2265 context_def: `WhileContextDef` protocol buffer. 2266 import_scope: Optional `string`. Name scope to add. 2267 """ 2268 assert isinstance(context_def, control_flow_pb2.WhileContextDef) 2269 # Create from context_def. 2270 g = ops.get_default_graph() 2271 self._name = ops.prepend_name_scope(context_def.context_name, import_scope) 2272 if context_def.maximum_iterations_name: 2273 self._maximum_iterations = g.as_graph_element( 2274 ops.prepend_name_scope(context_def.maximum_iterations_name, 2275 import_scope)) 2276 else: 2277 self._maximum_iterations = None 2278 self._parallel_iterations = context_def.parallel_iterations 2279 self._back_prop = context_def.back_prop 2280 self._swap_memory = context_def.swap_memory 2281 self._pivot_for_pred = g.as_graph_element( 2282 ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope)) 2283 # We use this node to control constants created by the body lambda. 2284 self._pivot_for_body = g.as_graph_element( 2285 ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope)) 2286 # The boolean tensor for loop termination condition. Used in code 2287 # generation for gradient computation. 2288 self._pivot = g.as_graph_element( 2289 ops.prepend_name_scope(context_def.pivot_name, import_scope)) 2290 # The list of exit tensors for loop variables. 2291 self._loop_exits = [ 2292 g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) 2293 for exit_name in context_def.loop_exit_names 2294 ] 2295 # The list of enter tensors for loop variables. 2296 self._loop_enters = [ 2297 g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) 2298 for enter_name in context_def.loop_enter_names 2299 ] 2300 super(WhileContext, self).__init__( 2301 values_def=context_def.values_def, import_scope=import_scope) 2302 2303 # import_scope causes self.name to be different from the original serialized 2304 # context's name. Rewrite "frame_name" attrs with the new name. 
2305 if import_scope: 2306 for tensor_name in self._values: 2307 op = g.as_graph_element(tensor_name).op 2308 if util.IsLoopEnter(op): 2309 # pylint: disable=protected-access 2310 op._set_attr("frame_name", 2311 attr_value_pb2.AttrValue(s=compat.as_bytes(self.name))) 2312 # pylint: enable=protected-access 2313 2314 @property 2315 def maximum_iterations(self): 2316 """The maximum number of iterations that will be executed.""" 2317 return self._maximum_iterations 2318 2319 @property 2320 def parallel_iterations(self): 2321 """The number of iterations allowed to run in parallel.""" 2322 return self._parallel_iterations 2323 2324 @property 2325 def back_prop(self): 2326 """True iff backprop is enabled for this while loop.""" 2327 return self._back_prop 2328 2329 @property 2330 def swap_memory(self): 2331 """True iff GPU-CPU memory swap is enabled for this while loop.""" 2332 return self._swap_memory 2333 2334 @property 2335 def pivot(self): 2336 """The boolean tensor representing the loop termination condition.""" 2337 return self._pivot 2338 2339 @property 2340 def loop_enters(self): 2341 """The list of enter tensors for loop variables.""" 2342 return self._loop_enters 2343 2344 @property 2345 def loop_exits(self): 2346 """The list of exit tensors for loop variables.""" 2347 return self._loop_exits 2348 2349 @property 2350 def grad_state(self): 2351 """The gradient loop state.""" 2352 return self._grad_state 2353 2354 def to_proto(self, export_scope=None): 2355 """Converts a `WhileContext` to a `WhileContextDef` protocol buffer. 2356 2357 Args: 2358 export_scope: Optional `string`. Name scope to remove. 2359 2360 Returns: 2361 A `WhileContextDef` protocol buffer. 2362 """ 2363 if (export_scope is None or self.name.startswith(export_scope)): 2364 context_def = control_flow_pb2.WhileContextDef() 2365 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 2366 context_def.parallel_iterations = self._parallel_iterations 2367 if self._maximum_iterations is not None: 2368 context_def.maximum_iterations_name = ops.strip_name_scope( 2369 self._maximum_iterations.name, export_scope) 2370 context_def.back_prop = self._back_prop 2371 context_def.swap_memory = self._swap_memory 2372 context_def.pivot_for_pred_name = ops.strip_name_scope( 2373 self._pivot_for_pred.name, export_scope) 2374 context_def.pivot_for_body_name = ops.strip_name_scope( 2375 self._pivot_for_body.name, export_scope) 2376 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 2377 export_scope) 2378 context_def.loop_exit_names.extend([ 2379 ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits 2380 ]) 2381 context_def.loop_enter_names.extend([ 2382 ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters 2383 ]) 2384 context_def.values_def.MergeFrom( 2385 super(WhileContext, self)._to_values_def( 2386 export_scope=export_scope)) 2387 # TODO(b/72868227): remove "if True..." once the corresponding 2388 # control_flow.proto changes have been checked in (they aren't checked in 2389 # and this is disabled for now to ensure forwards compatibility). 
2390 if False: # pylint: disable=using-constant-test 2391 for nested in self._nested_contexts: 2392 nested_def = context_def.nested_contexts.add() 2393 nested.to_control_flow_context_def(nested_def) 2394 2395 return context_def 2396 else: 2397 return None 2398 2399 def to_control_flow_context_def(self, context_def, export_scope=None): 2400 context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 2401 2402 @staticmethod 2403 def from_proto(context_def, import_scope=None): 2404 """Returns a `WhileContext` object created from `context_def`. 2405 2406 Args: 2407 context_def: A `WhileContextDef` protocol buffer. 2408 import_scope: Optional `string`. Name scope to add. 2409 2410 Returns: 2411 A `WhileContext` Python object. 2412 """ 2413 ret = WhileContext(context_def=context_def, 2414 import_scope=import_scope) 2415 # TODO(b/72868227): remove "if hasattr(...)" once the corresponding 2416 # control_flow.proto changes have been checked in (they aren't checked in 2417 # and this is disabled for now to ensure forwards compatibility). 2418 if hasattr(context_def, "nested_contexts"): 2419 ret.Enter() 2420 for nested_def in context_def.nested_contexts: 2421 from_control_flow_context_def(nested_def, import_scope=import_scope) 2422 ret.Exit() 2423 return ret 2424 2425 def GetWhileContext(self): 2426 return self 2427 2428 def GetControlPivot(self): 2429 if self._pivot_for_body is not None: 2430 return self._pivot_for_body 2431 return self._pivot_for_pred 2432 2433 def AddValue(self, val): 2434 """Add `val` to the current context and its outer context recursively.""" 2435 result = val 2436 if val.name not in self._values: 2437 self._values.add(val.name) 2438 2439 # If we are in a grad context and val is from its forward context, 2440 # use GetRealValue(), which adds the logic to save the history of 2441 # val in forward. 2442 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 2443 if grad_ctxt: 2444 grad_ctxt = grad_ctxt.GetWhileContext() 2445 if grad_ctxt.grad_state: 2446 forward_ctxt = _GetWhileContext(val.op) 2447 if util.IsLoopExit(val.op): 2448 forward_ctxt = forward_ctxt.outer_context 2449 if forward_ctxt: 2450 forward_ctxt = forward_ctxt.GetWhileContext() 2451 if forward_ctxt == grad_ctxt.grad_state.forward_context: 2452 real_val = grad_ctxt.grad_state.GetRealValue(val) 2453 self._external_values[val.name] = real_val 2454 return real_val 2455 2456 if self._outer_context is not None: 2457 result = self._outer_context.AddValue(val) 2458 # Create an Enter to make `result` known to this loop context. 2459 with ops.control_dependencies(None): 2460 enter = _Enter( 2461 result, 2462 self._name, 2463 is_constant=True, 2464 parallel_iterations=self._parallel_iterations) 2465 enter.graph.prevent_feeding(enter) 2466 if self._outer_context: 2467 self._outer_context.AddInnerOp(enter.op) 2468 # Fix the control inputs and control flow context of these enter ops. 2469 self._FixControlInputsAndContext([enter]) 2470 2471 # Add `enter` in this context. 
2472 self._values.add(enter.name) 2473 self._external_values[val.name] = enter 2474 result = enter 2475 else: 2476 actual_val = self._external_values.get(val.name) 2477 if actual_val is not None: 2478 result = actual_val 2479 return result 2480 2481 def AddOp(self, op): 2482 """Add `op` to the current context.""" 2483 # For a reduction op, if op is in a grad context and its input is from 2484 # its forward context, moving op to the forward context means we would 2485 # store the tensor after the reduction as opposed to the tensor before 2486 # reduction, and therefore could significantly reduce memory consumption. 2487 # For now, we do this only for a few ops. 2488 if op.type in {"Shape", "Size", "Rank"}: 2489 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 2490 if grad_ctxt: 2491 grad_ctxt = grad_ctxt.GetWhileContext() 2492 if grad_ctxt.grad_state: 2493 op_input_forward_ctxt = _GetWhileContext(op.inputs[0].op) 2494 if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context: 2495 op_input_ctxt = op.inputs[0].op._get_control_flow_context() 2496 op._set_control_flow_context(op_input_ctxt) 2497 op_input_ctxt._AddOpInternal(op) 2498 return 2499 self._AddOpInternal(op) 2500 2501 def _AddOpInternal(self, op): 2502 """Add `op` to the current context. 2503 2504 We move any external control dependencies of the op to the loop pivot, to 2505 ensure they get executed. 2506 """ 2507 if not op.inputs: 2508 # Remove any external control dependency on this op 2509 control_inputs, external_inputs = self._RemoveExternalControlEdges(op) 2510 # Add a control edge from the control pivot to this op. 2511 if not control_inputs: 2512 # pylint: disable=protected-access 2513 op._add_control_input(self.GetControlPivot().op) 2514 # pylint: enable=protected-access 2515 for x in op.outputs: 2516 self._values.add(x.name) 2517 else: 2518 for index in range(len(op.inputs)): 2519 x = op.inputs[index] 2520 real_x = self.AddValue(x) 2521 if real_x != x: 2522 op._update_input(index, real_x) # pylint: disable=protected-access 2523 # Remove any external control dependency on this op. 2524 _, external_inputs = self._RemoveExternalControlEdges(op) 2525 # Add a control dependency to prevent loop invariants from 2526 # enabling ops that should not be executed. 2527 self._MaybeAddControlDependency(op) 2528 for x in op.outputs: 2529 self._values.add(x.name) 2530 if external_inputs: 2531 # Use an identity to pull control inputs as data inputs. Note that we 2532 # ignore ops which don't have outputs. 
TODO(apassos): fix that 2533 with ops.control_dependencies(None): 2534 self.Enter() 2535 external_inputs = [array_ops.identity(x.outputs[0]).op 2536 for x in external_inputs if x.outputs] 2537 self.Exit() 2538 op._add_control_inputs(external_inputs) # pylint: disable=protected-access 2539 if self._outer_context or not util.IsLoopExit(op): 2540 op.graph.prevent_fetching(op) 2541 for x in op.outputs: 2542 op.graph.prevent_feeding(x) 2543 2544 if self._outer_context: 2545 self._outer_context.AddInnerOp(op) 2546 2547 def _MaybeAddControlDependency(self, op): 2548 """Add a control input to the op if it only depends on loop invariants.""" 2549 2550 def _IsOpFree(op): 2551 """Determines if `op` needs a control dependency.""" 2552 if op.control_inputs: 2553 return False 2554 # pylint: disable=protected-access 2555 if op.graph._is_function(op.type) or op.type == "SymbolicGradient": 2556 return True 2557 # pylint: enable=protected-access 2558 for x in op.inputs: 2559 if not util.IsLoopConstantEnter(x.op): 2560 return False 2561 return True 2562 2563 if _IsOpFree(op): 2564 # pylint: disable=protected-access 2565 op._add_control_input(self.GetControlPivot().op) 2566 # pylint: enable=protected-access 2567 2568 def AddForwardLoopCounter(self, outer_grad_state): 2569 """Adds a loop that counts the number of iterations. 2570 2571 This is added to the forward loop at the time when we start to 2572 create the loop for backprop gradient computation. Called in 2573 the outer context of this forward context. 2574 2575 The pseudocode is: 2576 `n = 0; while (_pivot) { n++; }` 2577 2578 Note that a control dependency is added to `n` to ensure the correct 2579 execution order of stack push ops. 2580 2581 Args: 2582 outer_grad_state: The outer grad state. None if not nested. 2583 2584 Returns: 2585 The number of iterations taken by the forward loop and the loop index. 2586 """ 2587 n = constant_op.constant(0, name="f_count") 2588 if outer_grad_state is not None: 2589 # Force the stack pushes of i-th execution of an inner loop to be ordered 2590 # before the pushes of (i+1)-th execution of the same inner loop. 2591 outer_add_op = outer_grad_state.forward_index.op.inputs[0].op 2592 n.op._add_control_input(outer_add_op) # pylint: disable=protected-access 2593 2594 self.Enter() 2595 self.AddName(n.name) 2596 enter_n = _Enter( 2597 n, 2598 self._name, 2599 is_constant=False, 2600 parallel_iterations=self._parallel_iterations, 2601 name="f_count") 2602 self.loop_enters.append(enter_n) 2603 2604 merge_n = merge([enter_n, enter_n])[0] 2605 switch_n = switch(merge_n, self._pivot) 2606 2607 index = math_ops.add(switch_n[1], 1) 2608 next_n = _NextIteration(index) 2609 merge_n.op._update_input(1, next_n) 2610 2611 total_iterations = exit(switch_n[0], name="f_count") 2612 self.loop_exits.append(total_iterations) 2613 self.ExitResult([total_iterations]) 2614 self.Exit() 2615 return total_iterations, next_n 2616 2617 def AddBackpropLoopCounter(self, count, outer_grad_state): 2618 """Add the backprop loop that controls the iterations. 2619 2620 This is added to the backprop loop. It is used to control the loop 2621 termination of the backprop loop. Called in the outer context of 2622 this grad context. 2623 2624 The pseudocode is: 2625 `n = count; while (n >= 1) { n--; }` 2626 2627 Note that a control dependency is added to `final_zero` to ensure the 2628 correct execution order of stack pop ops. 2629 2630 Args: 2631 count: The number of iterations for backprop. 2632 outer_grad_state: The outer grad state. None if not nested. 
2633 2634 Returns: 2635 The loop index. 2636 """ 2637 one = constant_op.constant(1, name="b_count") 2638 2639 self.Enter() 2640 self.AddName(count.name) 2641 enter_count = _Enter( 2642 count, 2643 self._name, 2644 is_constant=False, 2645 parallel_iterations=self._parallel_iterations, 2646 name="b_count") 2647 self.loop_enters.append(enter_count) 2648 2649 merge_count = merge([enter_count, enter_count])[0] 2650 self._pivot_for_pred = merge_count 2651 2652 pred = math_ops.greater_equal(merge_count, one) 2653 self._pivot = loop_cond(pred, name="b_count") 2654 switch_count = switch(merge_count, self._pivot) 2655 2656 index = math_ops.subtract(switch_count[1], one) 2657 self._pivot_for_body = index 2658 next_count = _NextIteration(index) 2659 merge_count.op._update_input(1, next_count) 2660 2661 final_zero = exit(switch_count[0], name="b_count") 2662 self.loop_exits.append(final_zero) 2663 if outer_grad_state is not None: 2664 # Force the stack pops of i-th execution of an inner loop to be ordered 2665 # before the pops of (i+1)-th execution of the same inner loop. 2666 # pylint: disable=protected-access 2667 outer_grad_state.grad_sync._add_control_input(final_zero.op) 2668 # pylint: enable=protected-access 2669 2670 self.ExitResult([final_zero]) 2671 self.Exit() 2672 return next_count 2673 2674 def AddBackpropAccumulator(self, op, grad): 2675 """Add an accumulation loop for every loop invariant. 2676 2677 This is added to the backprop loop. It is used to accumulate partial 2678 gradients within each loop iteration. Called when in the gradient while 2679 context. 2680 2681 The pseudocode is: 2682 ``` 2683 acc = 0.0; 2684 while (_pivot) { 2685 acc += grad; 2686 } 2687 ``` 2688 2689 Args: 2690 op: The Enter op for a loop invariant. 2691 grad: The partial gradient of an iteration for a loop invariant. 2692 2693 Returns: 2694 The gradient for a loop invariant. 2695 """ 2696 self.Exit() 2697 # Create a zeros tensor with the right shape for acc. If we don't 2698 # know the full shape statically, we will have to get the shape 2699 # dynamically from the forward inference. Getting the shape right 2700 # for the zeros is only needed for the base case when the loop exits 2701 # without running any iterations. 2702 shape = grad.get_shape() 2703 if shape.is_fully_defined(): 2704 if self.outer_context: 2705 self.outer_context.Enter() 2706 acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc") 2707 if self.outer_context: 2708 self.outer_context.Exit() 2709 else: 2710 value = op.inputs[0] 2711 if (isinstance(self.outer_context, WhileContext) and 2712 self.outer_context.grad_state is not None): 2713 # We are in a nested while loop. 
2714 forward_ctxt = self.grad_state.forward_context 2715 forward_ctxt.outer_context.Enter() 2716 zeros_shape = array_ops.shape_internal(value, optimize=False) 2717 forward_ctxt.outer_context.Exit() 2718 outer_grad_state = self.grad_state.outer_grad_state 2719 history_zeros_shape = outer_grad_state.AddForwardAccumulator( 2720 zeros_shape) 2721 self.outer_context.Enter() 2722 real_shape = outer_grad_state.AddBackpropAccumulatedValue( 2723 history_zeros_shape, zeros_shape) 2724 acc = array_ops.zeros(real_shape, grad.dtype) 2725 self.outer_context.Exit() 2726 else: 2727 if self.outer_context: 2728 self.outer_context.Enter() 2729 zeros_shape = array_ops.shape_internal(value, optimize=False) 2730 acc = array_ops.zeros(zeros_shape, grad.dtype) 2731 if self.outer_context: 2732 self.outer_context.Exit() 2733 2734 self.Enter() 2735 self.AddName(acc.name) 2736 enter_acc = _Enter( 2737 acc, 2738 self._name, 2739 is_constant=False, 2740 parallel_iterations=self._parallel_iterations, 2741 name="b_acc") 2742 self.loop_enters.append(enter_acc) 2743 2744 merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0] 2745 switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot) 2746 2747 add_acc = math_ops.add(switch_acc_true, grad) 2748 next_acc = _NextIteration(add_acc) 2749 merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access 2750 2751 result_acc = exit(switch_acc_false, name="b_acc") 2752 self.loop_exits.append(result_acc) 2753 self.ExitResult([result_acc]) 2754 return result_acc 2755 2756 def AddBackpropIndexedSlicesAccumulator(self, op, grad): 2757 """This is used for accumulating gradients that are IndexedSlices. 2758 2759 This is essentially the equivalent of AddBackpropAccumulator but optimized 2760 for things like updating embeddings from within a while loop. 2761 2762 Args: 2763 op: The Enter op for a loop invariant. 2764 grad: The partial gradients represented as an IndexedSlices. 2765 2766 Returns: 2767 The accumulated IndexedSlices gradient of the loop invariant. 
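For context, a minimal graph-mode sketch of the kind of computation that
reaches this accumulator (an embedding-style lookup on a loop invariant;
shapes and names are illustrative only):

```python
emb = tf.ones([100, 8])  # a loop invariant

def body(i, acc):
  return i + 1, acc + tf.reduce_sum(tf.gather(emb, [i]))

_, total = tf.while_loop(lambda i, acc: i < 10, body, [0, 0.0])
grad = tf.gradients(total, emb)[0]
# `grad` is an IndexedSlices; its indices and values are accumulated across
# the ten iterations by this method.
```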
2768 """ 2769 values = grad.values 2770 indices = grad.indices 2771 dense_shape = grad.dense_shape 2772 2773 self.Exit() 2774 if self.outer_context: 2775 self.outer_context.Enter() 2776 if values.get_shape().is_fully_defined(): 2777 values_shape = tensor_shape.TensorShape( 2778 [tensor_shape.Dimension(1)] + values.get_shape().dims[1:]) 2779 if self.outer_context: 2780 self.outer_context.Enter() 2781 values_acc = constant_op.constant( 2782 0, values.dtype, shape=values_shape, name="b_acc") 2783 if self.outer_context: 2784 self.outer_context.Exit() 2785 else: 2786 values_shape = _resource_safe_shape(op.inputs[0])[1:] 2787 values_shape = array_ops.concat([[1], values_shape], 0) 2788 values_acc = array_ops.zeros(values_shape, dtype=values.dtype) 2789 indices_acc = constant_op.constant([0], indices.dtype) 2790 shape_acc = None 2791 if dense_shape is not None: 2792 if dense_shape.get_shape().is_fully_defined(): 2793 if self.outer_context: 2794 self.outer_context.Enter() 2795 shape_acc = constant_op.constant( 2796 0, dense_shape.dtype, shape=dense_shape.get_shape()) 2797 if self.outer_context: 2798 self.outer_context.Exit() 2799 else: 2800 shape_acc = array_ops.zeros_like( 2801 array_ops.shape_internal(op.inputs[0], optimize=False), 2802 optimize=False) 2803 2804 if self.outer_context: 2805 self.outer_context.Exit() 2806 2807 self.Enter() 2808 self.AddName(values_acc.name) 2809 self.AddName(indices_acc.name) 2810 init_acc = [indices_acc, values_acc] 2811 if shape_acc is not None: 2812 self.AddName(shape_acc.name) 2813 init_acc.append(shape_acc) 2814 2815 # Set use_input_shape=False since the accumulator tensors will grow in 2816 # size. If use_input_shape=True, the _update_input call below will result in 2817 # incompatible shapes. 2818 enter_acc = [ 2819 _Enter( 2820 x, 2821 self._name, 2822 is_constant=False, 2823 parallel_iterations=self._parallel_iterations, 2824 use_input_shape=False, 2825 name="b_acc") for x in init_acc 2826 ] 2827 # Manually set appropriate partial shapes. 2828 enter_acc[0].set_shape([None]) 2829 if values_acc.shape.dims is not None: 2830 enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:]) 2831 self.loop_enters.extend(enter_acc) 2832 2833 merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc] 2834 switch_acc = [switch(x, self._pivot) for x in merge_acc] 2835 2836 # The actual accumulation. 
2837 acc_indexed_slices = [ 2838 array_ops.concat([xa[1], xv], 0) 2839 for xa, xv in zip(switch_acc[:2], [indices, values]) 2840 ] 2841 if shape_acc is not None: 2842 # For the shape we just keep the maximum 2843 acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1])) 2844 2845 next_acc = [_NextIteration(x) for x in acc_indexed_slices] 2846 for xm, xn in zip(merge_acc, next_acc): 2847 xm.op._update_input(1, xn) # pylint: disable=protected-access 2848 2849 exit_acc = [exit(x[0], name="b_acc") for x in switch_acc] 2850 self.loop_exits.extend(exit_acc) 2851 2852 self.ExitResult(exit_acc) 2853 return ops.IndexedSlices( 2854 indices=exit_acc[0], 2855 values=exit_acc[1], 2856 dense_shape=exit_acc[2] if shape_acc is not None else None) 2857 2858 def _InitializeValues(self, values): 2859 """Makes the values known to this context.""" 2860 self._values = set() 2861 for x in values: 2862 if isinstance(x, ops.Tensor): 2863 self._values.add(x.name) 2864 else: 2865 self._values.add(x.values.name) 2866 self._values.add(x.indices.name) 2867 if isinstance(x, ops.IndexedSlices): 2868 dense_shape = x.dense_shape 2869 elif isinstance(x, sparse_tensor.SparseTensor): 2870 dense_shape = x.dense_shape 2871 else: 2872 raise TypeError("Type %s not supported" % type(x)) 2873 if dense_shape is not None: 2874 self._values.add(dense_shape.name) 2875 2876 def _BuildLoop(self, pred, body, original_loop_vars, loop_vars, 2877 shape_invariants): 2878 """Core: Add the loop termination condition and body to the graph.""" 2879 flat_loop_vars = nest.flatten(original_loop_vars) 2880 2881 # Let the context know the loop variables so the loop variables 2882 # would be added in the outer contexts properly. 2883 self._InitializeValues(loop_vars) 2884 real_vars = loop_vars 2885 if self._outer_context: 2886 real_vars = [self._outer_context.AddValue(x) for x in loop_vars] 2887 with ops.control_dependencies(None): 2888 enter_vars = [ 2889 _Enter( 2890 x, 2891 self._name, 2892 is_constant=False, 2893 parallel_iterations=self._parallel_iterations, 2894 use_input_shape=(shape_invariants is None)) for x in real_vars 2895 ] 2896 for x in enter_vars: 2897 x.graph.prevent_feeding(x) 2898 if self._outer_context: 2899 self._outer_context.AddInnerOp(x.op) 2900 2901 # Finds the closest enclosing non-None control pivot. 2902 outer_context = self._outer_context 2903 control_pivot = None 2904 while outer_context is not None and control_pivot is None: 2905 control_pivot = outer_context.GetControlPivot() 2906 # pylint: disable=protected-access 2907 outer_context = outer_context._outer_context 2908 # pylint: enable=protected-access 2909 2910 if control_pivot is not None: 2911 for var in enter_vars: 2912 if util.IsLoopConstantEnter(var.op.inputs[0].op): 2913 # pylint: disable=protected-access 2914 var.op._add_control_input(control_pivot.op) 2915 # pylint: enable=protected-access 2916 _SetShapeInvariants(real_vars, enter_vars, shape_invariants) 2917 2918 # Fix the control inputs and control flow context of these enter ops. 2919 self._FixControlInputsAndContext(enter_vars) 2920 self._InitializeValues(enter_vars) 2921 self._loop_enters = enter_vars 2922 2923 merge_vars = [merge([x, x])[0] for x in enter_vars] 2924 self._pivot_for_pred = merge_vars[0] 2925 2926 # Build the graph for pred. 
2927 merge_vars_with_tensor_arrays = ( 2928 _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars)) 2929 packed_vars = nest.pack_sequence_as( 2930 structure=original_loop_vars, 2931 flat_sequence=merge_vars_with_tensor_arrays) 2932 c = ops.convert_to_tensor(pred(*packed_vars)) 2933 self._pivot = loop_cond(c, name="LoopCond") 2934 switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars] 2935 2936 # Build the graph for body. 2937 vars_for_body = [_Identity(x[1]) for x in switch_vars] 2938 self._pivot_for_body = vars_for_body[0] 2939 # Convert TensorArray flow variables inside the context back into 2940 # their associated TensorArrays for calling the body. 2941 vars_for_body_with_tensor_arrays = ( 2942 _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body)) 2943 packed_vars_for_body = nest.pack_sequence_as( 2944 structure=original_loop_vars, 2945 flat_sequence=vars_for_body_with_tensor_arrays) 2946 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2947 body_result = body(*packed_vars_for_body) 2948 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2949 if not nest.is_sequence(body_result): 2950 body_result = [body_result] 2951 if len(post_summaries) > len(pre_summaries): 2952 new_summaries = post_summaries[len(pre_summaries):] 2953 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2954 summary_ref[:] = pre_summaries 2955 with ops.control_dependencies(new_summaries): 2956 2957 def map_fn(x): 2958 # TODO(apassos) figure out how to trigger with tensor arrays as well 2959 if isinstance(x, tensor_array_ops.TensorArray): 2960 return x 2961 return array_ops.identity(x) 2962 2963 body_result = nest.map_structure(map_fn, body_result) 2964 2965 # Compare the structure types of input and output of body. 2966 # For backwards compatibility, the first layer is forced to a list 2967 # during this comparison, because inputs are typically lists and 2968 # outputs of the body are typically tuples. 2969 nest.assert_same_structure(list(packed_vars_for_body), list(body_result)) 2970 2971 # Store body_result to keep track of TensorArrays returned by body 2972 original_body_result = body_result 2973 # Convert TensorArrays returned by body into their flow variables 2974 result = nest.map_structure(_convert_tensorarray_to_flow, 2975 nest.flatten(body_result)) 2976 result = ops.convert_n_to_tensor_or_indexed_slices(result) 2977 2978 # Add NextIteration and the back edges to complete the loop. 2979 if len(merge_vars) != len(result): 2980 raise ValueError("Number of inputs and outputs of body must match " 2981 "loop_vars: %d, %d" % (len(merge_vars), len(result))) 2982 next_vars = [] 2983 for m, v in zip(merge_vars, result): 2984 next_vars.append(_AddNextAndBackEdge(m, v)) 2985 2986 # Add the exit ops. 2987 exit_vars = [exit(x[0]) for x in switch_vars] 2988 self._loop_exits = exit_vars 2989 2990 # Exit the loop. 
2991 self.ExitResult(exit_vars) 2992 2993 return original_body_result, exit_vars 2994 2995 def BuildLoop(self, pred, body, loop_vars, shape_invariants): 2996 """Add the loop termination condition and body to the graph.""" 2997 2998 # Keep original_loop_vars to identify which are TensorArrays 2999 original_loop_vars = loop_vars 3000 # Convert TensorArrays to their flow variables 3001 loop_vars = nest.map_structure(_convert_tensorarray_to_flow, 3002 nest.flatten(loop_vars)) 3003 loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars) 3004 try: 3005 self.Enter() 3006 original_body_result, exit_vars = self._BuildLoop( 3007 pred, body, original_loop_vars, loop_vars, shape_invariants) 3008 finally: 3009 self.Exit() 3010 3011 flat_result = nest.flatten(original_body_result) 3012 # Convert TensorArray flow variables outside the context back into 3013 # their associated TensorArrays for returning to caller. 3014 exit_vars_with_tensor_arrays = ( 3015 _convert_flows_to_tensorarrays(flat_result, exit_vars)) 3016 packed_exit_vars = nest.pack_sequence_as( 3017 structure=original_body_result, 3018 flat_sequence=exit_vars_with_tensor_arrays) 3019 return (packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars) 3020 3021 def _FixControlInputsAndContext(self, enters): 3022 graph = ops.get_default_graph() 3023 # pylint: disable=protected-access 3024 for e in enters: 3025 if isinstance(e, ops.Tensor): 3026 xs = [e] 3027 else: 3028 if not isinstance(e, (ops.IndexedSlices, sparse_tensor.SparseTensor)): 3029 raise TypeError("Type %s not supported" % type(e)) 3030 xs = [e.values, e.indices] 3031 shape = e.dense_shape 3032 if shape is not None: 3033 xs.append(shape) 3034 for x in xs: 3035 inp_op = x.op.inputs[0].op 3036 control_inputs = graph._control_dependencies_for_inputs([inp_op]) 3037 outer_control_inputs = [ 3038 op for op in control_inputs if self._IsInOuterContext(op) 3039 ] 3040 x.op._set_control_flow_context(self) 3041 x.op._add_control_inputs(outer_control_inputs) 3042 graph._record_op_seen_by_control_dependencies(x.op) 3043 # pylint: enable=protected-access 3044 3045 def IsWhileContext(self): 3046 return True 3047 3048 3049 # pylint: disable=redefined-outer-name 3050 @tf_export("while_loop") 3051 def while_loop(cond, 3052 body, 3053 loop_vars, 3054 shape_invariants=None, 3055 parallel_iterations=10, 3056 back_prop=True, 3057 swap_memory=False, 3058 name=None, 3059 maximum_iterations=None): 3060 """Repeat `body` while the condition `cond` is true. 3061 3062 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 3063 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 3064 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 3065 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 3066 `cond` and `body`. `cond` and `body` both take as many arguments as there are 3067 `loop_vars`. 3068 3069 In addition to regular Tensors or IndexedSlices, the body may accept and 3070 return TensorArray objects. The flows of the TensorArray objects will 3071 be appropriately forwarded between loops and during gradient calculations. 3072 3073 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 3074 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 3075 stitches together the graph fragments created during the `cond` and `body` 3076 calls with some additional graph nodes to create the graph flow that 3077 repeats `body` until `cond` returns false. 
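For example, Python side effects in `body` happen once, while the loop body
graph is being built, not once per iteration (a small sketch):

```python
i0 = tf.constant(0)

def body(i):
  print("tracing body")  # executes a single time, at graph construction
  return i + 1

r = tf.while_loop(lambda i: i < 5, body, [i0])
# "tracing body" is printed once, even though evaluating `r` in a Session
# runs the loop five times.
```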
3078 3079 For correctness, `tf.while_loop()` strictly enforces shape invariants for 3080 the loop variables. A shape invariant is a (possibly partial) shape that 3081 is unchanged across the iterations of the loop. An error will be raised 3082 if the shape of a loop variable after an iteration is determined to be more 3083 general than or incompatible with its shape invariant. For example, a shape 3084 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 3085 compatible with [11, 17]. By default (if the argument `shape_invariants` is 3086 not specified), it is assumed that the initial shape of each tensor in 3087 `loop_vars` is the same in every iteration. The `shape_invariants` argument 3088 allows the caller to specify a less specific shape invariant for each loop 3089 variable, which is needed if the shape varies between iterations. The 3090 @{tf.Tensor.set_shape} 3091 function may also be used in the `body` function to indicate that 3092 the output loop variable has a particular shape. The shape invariant for 3093 SparseTensor and IndexedSlices are treated specially as follows: 3094 3095 a) If a loop variable is a SparseTensor, the shape invariant must be 3096 TensorShape([r]) where r is the rank of the dense tensor represented 3097 by the sparse tensor. It means the shapes of the three tensors of the 3098 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 3099 is the shape of the SparseTensor.dense_shape property. It must be the shape of 3100 a vector. 3101 3102 b) If a loop variable is an IndexedSlices, the shape invariant must be 3103 a shape invariant of the values tensor of the IndexedSlices. It means 3104 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 3105 [shape.ndims]). 3106 3107 `while_loop` implements non-strict semantics, enabling multiple iterations 3108 to run in parallel. The maximum number of parallel iterations can be 3109 controlled by `parallel_iterations`, which gives users some control over 3110 memory consumption and execution order. For correct programs, `while_loop` 3111 should return the same result for any parallel_iterations > 0. 3112 3113 For training, TensorFlow stores the tensors that are produced in the 3114 forward inference and are needed in back propagation. These tensors are a 3115 main source of memory consumption and often cause OOM errors when training 3116 on GPUs. When the flag swap_memory is true, we swap out these tensors from 3117 GPU to CPU. This for example allows us to train RNN models with very long 3118 sequences and large batches. 3119 3120 Args: 3121 cond: A callable that represents the termination condition of the loop. 3122 body: A callable that represents the loop body. 3123 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 3124 `Tensor`, and `TensorArray` objects. 3125 shape_invariants: The shape invariants for the loop variables. 3126 parallel_iterations: The number of iterations allowed to run in parallel. 3127 It must be a positive integer. 3128 back_prop: Whether backprop is enabled for this while loop. 3129 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 3130 name: Optional name prefix for the returned tensors. 3131 maximum_iterations: Optional maximum number of iterations of the while loop 3132 to run. If provided, the `cond` output is AND-ed with an additional 3133 condition ensuring the number of iterations executed is no greater than 3134 `maximum_iterations`. 
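      For example, the iteration cap composes with `cond` as described above
      (an illustrative sketch; the numbers are arbitrary):

      ```python
      i = tf.constant(0)
      r = tf.while_loop(lambda i: tf.less(i, 10000),
                        lambda i: tf.add(i, 1),
                        [i],
                        maximum_iterations=100)
      # When evaluated, r is 100: the loop stops after 100 iterations even
      # though `cond` would still be true.
      ```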
3135 3136 Returns: 3137 The output tensors for the loop variables after the loop. When the length 3138 of `loop_vars` is 1 this is a Tensor, TensorArray or IndexedSlice and when 3139 the length of `loop_vars` is greater than 1 it returns a list. 3140 3141 Raises: 3142 TypeError: if `cond` or `body` is not callable. 3143 ValueError: if `loop_vars` is empty. 3144 3145 Example: 3146 3147 ```python 3148 i = tf.constant(0) 3149 c = lambda i: tf.less(i, 10) 3150 b = lambda i: tf.add(i, 1) 3151 r = tf.while_loop(c, b, [i]) 3152 ``` 3153 3154 Example with nesting and a namedtuple: 3155 3156 ```python 3157 import collections 3158 Pair = collections.namedtuple('Pair', 'j, k') 3159 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 3160 c = lambda i, p: i < 10 3161 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 3162 ijk_final = tf.while_loop(c, b, ijk_0) 3163 ``` 3164 3165 Example using shape_invariants: 3166 3167 ```python 3168 i0 = tf.constant(0) 3169 m0 = tf.ones([2, 2]) 3170 c = lambda i, m: i < 10 3171 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 3172 tf.while_loop( 3173 c, b, loop_vars=[i0, m0], 3174 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 3175 ``` 3176 3177 Example which demonstrates non-strict semantics: In the following 3178 example, the final value of the counter `i` does not depend on `x`. So 3179 the `while_loop` can increment the counter parallel to updates of `x`. 3180 However, because the loop counter at one loop iteration depends 3181 on the value at the previous iteration, the loop counter itself cannot 3182 be incremented in parallel. Hence if we just want the final value of the 3183 counter (which we print on the line `print(sess.run(i))`), then 3184 `x` will never be incremented, but the counter will be updated on a 3185 single thread. Conversely, if we want the value of the output (which we 3186 print on the line `print(sess.run(out).shape)`), then the counter may be 3187 incremented on its own thread, while `x` can be incremented in 3188 parallel on a separate thread. In the extreme case, it is conceivable 3189 that the thread incrementing the counter runs until completion before 3190 `x` is incremented even a single time. The only thing that can never 3191 happen is that the thread updating `x` can never get ahead of the 3192 counter thread because the thread incrementing `x` depends on the value 3193 of the counter. 3194 ```python 3195 import tensorflow as tf 3196 3197 n = 10000 3198 x = tf.constant(list(range(n))) 3199 c = lambda i, x: i < n 3200 b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:")) 3201 i, out = tf.while_loop(c, b, (0, x)) 3202 with tf.Session() as sess: 3203 print(sess.run(i)) # prints [0] ... [9999] 3204 3205 # The following line may increment the counter and x in parallel. 3206 # The counter thread may get ahead of the other thread, but not the 3207 # other way around. 
So you may see things like 3208 # [9996] x:[9987] 3209 # meaning that the counter thread is on iteration 9996, 3210 # while the other thread is on iteration 9987 3211 print(sess.run(out).shape) 3212 ``` 3213 3214 """ 3215 with ops.name_scope(name, "while", loop_vars): 3216 if not loop_vars: 3217 raise ValueError("No loop variables provided") 3218 if not callable(cond): 3219 raise TypeError("cond must be callable.") 3220 if not callable(body): 3221 raise TypeError("body must be callable.") 3222 if parallel_iterations < 1: 3223 raise TypeError("parallel_iterations must be a positive integer.") 3224 3225 if maximum_iterations is not None: 3226 maximum_iterations = ops.convert_to_tensor( 3227 maximum_iterations, name="maximum_iterations") 3228 if maximum_iterations.shape.ndims != 0: 3229 raise ValueError("maximum_iterations must be a scalar, saw shape: %s" % 3230 maximum_iterations.shape) 3231 3232 counter = constant_op.constant( 3233 0, dtype=maximum_iterations.dtype, name="iteration_counter") 3234 orig_cond = cond 3235 orig_body = body 3236 if len(loop_vars) == 1: 3237 loop_vars = (counter, loop_vars[0]) 3238 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 3239 math_ops.logical_and(i < maximum_iterations, orig_cond(lv))) 3240 body = lambda i, lv: (i + 1, orig_body(lv)) 3241 else: 3242 loop_vars = (counter, loop_vars) 3243 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 3244 math_ops.logical_and(i < maximum_iterations, orig_cond(*lv))) 3245 body = lambda i, lv: (i + 1, orig_body(*lv)) 3246 3247 if context.in_eager_mode(): 3248 while cond(*loop_vars): 3249 loop_vars = body(*loop_vars) 3250 if maximum_iterations is not None: 3251 return loop_vars[1] 3252 else: 3253 return loop_vars 3254 3255 if shape_invariants is not None: 3256 if maximum_iterations is not None: 3257 shape_invariants = (tensor_shape.TensorShape([]), shape_invariants) 3258 nest.assert_same_structure(loop_vars, shape_invariants) 3259 3260 loop_context = WhileContext( 3261 maximum_iterations=maximum_iterations, 3262 parallel_iterations=parallel_iterations, 3263 back_prop=back_prop, 3264 swap_memory=swap_memory) 3265 # Only add non-nested loops to the collection. Any nested control flow will 3266 # be encapsulated in the root context. 3267 # TODO(b/72868227): enable condition once the corresponding 3268 # control_flow.proto changes have been checked in (they aren't checked in 3269 # and this is disabled for now to ensure forwards compatibility). 3270 if True or loop_context.outer_context is None: 3271 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) 3272 result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants) 3273 if maximum_iterations is not None: 3274 return result[1] 3275 else: 3276 return result 3277 3278 3279 # pylint: enable=redefined-outer-name 3280 3281 3282 def _AsTensorList(x, p): 3283 """Return x as a list of Tensors or IndexedSlices. 3284 3285 For entries of `x` that are Operations, this returns an Identity of `p` 3286 with a dependency on the operation. 3287 3288 Args: 3289 x: A Tensor/IndexedSlices/Operation or a list or tuple of them. 3290 p: A Tensor to return for entries in `x` that are Operations. 3291 3292 Returns: 3293 A list of Tensors or IndexedSlices. 
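    For example (illustrative only): given `x = [t, op]` with `t` a `Tensor`
    and `op` an `Operation`, the result is roughly
    `[array_ops.identity(t), array_ops.identity(with_dependencies([op], p))]`,
    i.e. the `Operation` entry is replaced by `p` gated on that operation.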
3294 """ 3295 if not isinstance(x, (list, _basetuple)): 3296 x = [x] 3297 3298 l = [] 3299 for v in x: 3300 if isinstance(v, ops.Operation): 3301 v = with_dependencies([v], p) 3302 v = ops.convert_to_tensor_or_indexed_slices(v) 3303 if isinstance(v, ops.Tensor): 3304 l.append(array_ops.identity(v)) 3305 else: 3306 l.append( 3307 ops.IndexedSlices( 3308 array_ops.identity(v.values), array_ops.identity(v.indices))) 3309 return l 3310 3311 3312 def _CheckResults(a, b): 3313 assert len(a) == len(b), ( 3314 "Values returned by a() and b() must have the same length.") 3315 for x, y in zip(a, b): 3316 assert x.dtype == y.dtype, ( 3317 "Values returned by a() [%s] and b() [%s] must have " 3318 "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name)) 3319 3320 3321 def with_dependencies(dependencies, output_tensor, name=None): 3322 """Produces the content of `output_tensor` only after `dependencies`. 3323 3324 In some cases, a user may want the output of an operation to be 3325 consumed externally only after some other dependencies have run 3326 first. This function ensures returns `output_tensor`, but only after all 3327 operations in `dependencies` have run. Note that this means that there is 3328 no guarantee that `output_tensor` will be evaluated after any `dependencies` 3329 have run. 3330 3331 See also @{tf.tuple$tuple} and @{tf.group$group}. 3332 3333 Args: 3334 dependencies: Iterable of operations to run before this op finishes. 3335 output_tensor: A `Tensor` or `IndexedSlices` that will be returned. 3336 name: (Optional) A name for this operation. 3337 3338 Returns: 3339 Same as `output_tensor`. 3340 3341 Raises: 3342 TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`. 3343 """ 3344 if context.in_eager_mode(): 3345 return output_tensor 3346 with ops.name_scope(name, "control_dependency", 3347 list(dependencies) + [output_tensor]) as name: 3348 with ops.colocate_with(output_tensor): 3349 with ops.control_dependencies(dependencies): 3350 output_tensor = ops.convert_to_tensor_or_indexed_slices(output_tensor) 3351 if isinstance(output_tensor, ops.Tensor): 3352 return _Identity(output_tensor, name=name) 3353 else: 3354 return ops.IndexedSlices( 3355 _Identity(output_tensor.values, name=name), output_tensor.indices, 3356 output_tensor.dense_shape) 3357 3358 3359 def _GroupControlDeps(dev, deps, name=None): 3360 with ops.control_dependencies(deps): 3361 if dev is None: 3362 return no_op(name=name) 3363 else: 3364 with ops.device(dev): 3365 return no_op(name=name) 3366 3367 3368 # TODO(touts): Accept "inputs" as a list. 3369 @tf_export("group") 3370 def group(*inputs, **kwargs): 3371 """Create an op that groups multiple operations. 3372 3373 When this op finishes, all ops in `inputs` have finished. This op has no 3374 output. 3375 3376 See also @{tf.tuple$tuple} and 3377 @{tf.control_dependencies$control_dependencies}. 3378 3379 Args: 3380 *inputs: Zero or more tensors to group. 3381 name: A name for this operation (optional). 3382 3383 Returns: 3384 An Operation that executes all its inputs. 3385 3386 Raises: 3387 ValueError: If an unknown keyword argument is provided. 3388 """ 3389 if context.in_eager_mode(): 3390 return None 3391 name = kwargs.pop("name", None) 3392 if kwargs: 3393 raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys())) 3394 with ops.name_scope(name, "group_deps", inputs) as name: 3395 # Grouping no inputs means do nothing 3396 if not inputs: 3397 return no_op(name=name) 3398 3399 # Sorts *inputs according to their devices. 
    ops_on_device = {}  # device -> operations specified on the device.
    for inp in nest.flatten(inputs):
      if not hasattr(inp, "device"):
        if isinstance(inp, list):
          raise TypeError("To call tf.group() with a list, use "
                          "tf.group(*[...]) not tf.group([...]).")
        raise TypeError("tf.group() expects Tensor arguments, got "
                        "'%s' with type '%s'." % (inp, type(inp)))
      dev = inp.device
      if dev in ops_on_device:
        ops_on_device[dev].append(inp)
      else:
        ops_on_device[dev] = [inp]
    if len(ops_on_device) == 1:
      # 1-level tree. The root node is the returned NoOp node.
      (dev, deps), = ops_on_device.items()
      return _GroupControlDeps(dev, deps, name=name)

    # 2-level tree. The root node is the returned NoOp node.
    # deps contains 1 NoOp node for each device.
    deps = []

    def device_key(dev):
      """A sort key that allows None to be compared to strings."""
      return "" if dev is None else dev

    for dev in sorted(six.iterkeys(ops_on_device), key=device_key):
      deps.append(_GroupControlDeps(dev, ops_on_device[dev]))

    with ops.control_dependencies(deps):
      return no_op(name=name)


@tf_export("tuple")
def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
  """Group tensors together.

  This creates a tuple of tensors with the same values as the `tensors`
  argument, except that the value of each tensor is only returned after the
  values of all tensors have been computed.

  `control_inputs` contains additional ops that have to finish before this op
  finishes, but whose outputs are not returned.

  This can be used as a "join" mechanism for parallel computations: all the
  argument tensors can be computed in parallel, but the values of any tensor
  returned by `tuple` are only available after all the parallel computations
  are done.

  See also @{tf.group$group} and
  @{tf.control_dependencies$control_dependencies}.

  Args:
    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
    name: (optional) A name to use as a `name_scope` for the operation.
    control_inputs: List of additional ops to finish before returning.

  Returns:
    Same as `tensors`.

  Raises:
    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
      objects.

  """
  if context.in_eager_mode():
    return tensors
  with ops.name_scope(name, "tuple", tensors) as name:
    gating_ops = [t.op for t in tensors if t is not None]
    if control_inputs:
      for c in control_inputs:
        if isinstance(c, ops.Tensor):
          c = c.op
        elif not isinstance(c, ops.Operation):
          raise TypeError("Control input must be Operation or Tensor: %s" % c)
        gating_ops.append(c)
    # Note that in order to ensure ordering in the pbtxt, we must take care to
    # ensure the order here.
    gating_ops = sorted(set(gating_ops), key=lambda op: op._id)  # Uniquify ops.
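    # Sketch of what follows (descriptive only): a single NoOp gate is created
    # with group(*gating_ops), and each non-None entry t of `tensors` is
    # replaced by with_dependencies([gate], t), so every returned value becomes
    # available only after all gating ops have run.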
3483 if not gating_ops: 3484 raise ValueError("Must have at least one Tensor: %s" % tensors) 3485 gate = group(*gating_ops) 3486 tpl = [] 3487 for t in tensors: 3488 if t is not None: 3489 tpl.append(with_dependencies([gate], t)) 3490 else: 3491 tpl.append(None) 3492 return tpl 3493 3494 3495 def _assert_at_most_n_true(predicates, n, msg): 3496 """Returns an Assert op that checks that at most n predicates are True. 3497 3498 Args: 3499 predicates: list of bool scalar tensors. 3500 n: maximum number of true predicates allowed. 3501 msg: Error message. 3502 """ 3503 preds_c = array_ops.stack(predicates, name="preds_c") 3504 num_true_conditions = math_ops.reduce_sum( 3505 math_ops.cast(preds_c, dtypes.int32), name="num_true_conds") 3506 condition = math_ops.less_equal(num_true_conditions, 3507 constant_op.constant(n, name="n_true_conds")) 3508 preds_names = ", ".join(getattr(p, "name", "?") for p in predicates) 3509 error_msg = [ 3510 "%s: more than %d conditions (%s) evaluated as True:" % 3511 (msg, n, preds_names), preds_c 3512 ] 3513 return Assert(condition, data=error_msg, summarize=len(predicates)) 3514 3515 3516 def _case_create_default_action(predicates, actions): 3517 """Creates default action for a list of actions and their predicates. 3518 3519 It uses the input actions to select an arbitrary as default and makes sure 3520 that corresponding predicates have valid values. 3521 3522 Args: 3523 predicates: a list of bool scalar tensors 3524 actions: a list of callable objects which return tensors. 3525 3526 Returns: 3527 a callable 3528 """ 3529 k = len(predicates) - 1 # could pick any 3530 predicate, action = predicates[k], actions[k] 3531 other_predicates, other_actions = predicates[:k], actions[:k] 3532 3533 def default_action(): 3534 others_msg = ("Implementation error: " 3535 "selected default action #%d was called, but some of other " 3536 "predicates are True: " % k) 3537 default_msg = ("Input error: " 3538 "None of conditions evaluated as True:", 3539 array_ops.stack(predicates, name="preds_c")) 3540 with ops.control_dependencies([ 3541 _assert_at_most_n_true(other_predicates, n=0, msg=others_msg), 3542 Assert(predicate, data=default_msg) 3543 ]): 3544 return action() 3545 3546 return default_action, other_predicates, other_actions 3547 3548 3549 def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name): 3550 """Verifies input arguments for the case function. 3551 3552 Args: 3553 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a 3554 callable which returns a list of tensors. 3555 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3556 name: A name for the case operation. 3557 3558 Raises: 3559 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3560 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3561 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3562 callable. 3563 3564 Returns: 3565 a tuple <list of scalar bool tensors, list of callables>. 3566 """ 3567 if not isinstance(pred_fn_pairs, (list, _basetuple, dict)): 3568 raise TypeError("fns must be a list, tuple, or dict") 3569 3570 if isinstance(pred_fn_pairs, collections.OrderedDict): 3571 pred_fn_pairs = pred_fn_pairs.items() 3572 elif isinstance(pred_fn_pairs, dict): 3573 pred_fn_pairs = sorted(pred_fn_pairs.items(), key=lambda item: item[0].name) 3574 if not exclusive: 3575 logging.warn("%s: An unordered dictionary of predicate/fn pairs was " 3576 "provided, but exclusive=False. 
The order of conditional " 3577 "tests is deterministic but not guaranteed.", name) 3578 for pred_fn_pair in pred_fn_pairs: 3579 if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2: 3580 raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple") 3581 pred, fn = pred_fn_pair 3582 if pred.dtype != dtypes.bool: 3583 raise TypeError("pred must be of type bool: %s", pred.name) 3584 if not callable(fn): 3585 raise TypeError("fn for pred %s must be callable." % pred.name) 3586 predicates, actions = zip(*pred_fn_pairs) 3587 return predicates, actions 3588 3589 3590 @tf_export("case") 3591 def case(pred_fn_pairs, 3592 default=None, 3593 exclusive=False, 3594 strict=False, 3595 name="case"): 3596 """Create a case operation. 3597 3598 The `pred_fn_pairs` parameter is a dict or list of pairs of size N. 3599 Each pair contains a boolean scalar tensor and a python callable that 3600 creates the tensors to be returned if the boolean evaluates to True. 3601 `default` is a callable generating a list of tensors. All the callables 3602 in `pred_fn_pairs` as well as `default` (if provided) should return the same 3603 number and types of tensors. 3604 3605 If `exclusive==True`, all predicates are evaluated, and an exception is 3606 thrown if more than one of the predicates evaluates to `True`. 3607 If `exclusive==False`, execution stops at the first predicate which 3608 evaluates to True, and the tensors generated by the corresponding function 3609 are returned immediately. If none of the predicates evaluate to True, this 3610 operation returns the tensors generated by `default`. 3611 3612 `tf.case` supports nested structures as implemented in 3613 `tensorflow.python.util.nest`. All of the callables must return the same 3614 (possibly nested) value structure of lists, tuples, and/or named tuples. 3615 Singleton lists and tuples form the only exceptions to this: when returned by 3616 a callable, they are implicitly unpacked to single values. This 3617 behavior is disabled by passing `strict=True`. 3618 3619 If an unordered dictionary is used for `pred_fn_pairs`, the order of the 3620 conditional tests is not guaranteed. However, the order is guaranteed to be 3621 deterministic, so that variables created in conditional branches are created 3622 in fixed order across runs. 3623 3624 **Example 1:** 3625 3626 Pseudocode: 3627 3628 ``` 3629 if (x < y) return 17; 3630 else return 23; 3631 ``` 3632 3633 Expressions: 3634 3635 ```python 3636 f1 = lambda: tf.constant(17) 3637 f2 = lambda: tf.constant(23) 3638 r = case([(tf.less(x, y), f1)], default=f2) 3639 ``` 3640 3641 **Example 2:** 3642 3643 Pseudocode: 3644 3645 ``` 3646 if (x < y && x > z) raise OpError("Only one predicate may evaluate true"); 3647 if (x < y) return 17; 3648 else if (x > z) return 23; 3649 else return -1; 3650 ``` 3651 3652 Expressions: 3653 3654 ```python 3655 def f1(): return tf.constant(17) 3656 def f2(): return tf.constant(23) 3657 def f3(): return tf.constant(-1) 3658 r = case({tf.less(x, y): f1, tf.greater(x, z): f2}, 3659 default=f3, exclusive=True) 3660 ``` 3661 3662 Args: 3663 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a 3664 callable which returns a list of tensors. 3665 default: Optional callable that returns a list of tensors. 3666 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3667 strict: A boolean that enables/disables 'strict' mode; see above. 3668 name: A name for this operation (optional). 
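    For example (illustrative): with the default `strict=False`, a branch
    callable returning the singleton list `[tf.constant(17)]` yields the bare
    tensor rather than a one-element list; passing `strict=True` preserves the
    singleton structure.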
3669 3670 Returns: 3671 The tensors returned by the first pair whose predicate evaluated to True, or 3672 those returned by `default` if none does. 3673 3674 Raises: 3675 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3676 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3677 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3678 callable. 3679 """ 3680 predicates, actions = _case_verify_and_canonicalize_args( 3681 pred_fn_pairs, exclusive, name) 3682 with ops.name_scope(name, "case", [predicates]): 3683 if default is None: 3684 default, predicates, actions = _case_create_default_action( 3685 predicates, actions) 3686 fn = default 3687 # To eval conditions in direct order we create nested conditions in reverse: 3688 # cond(c[0], true_fn=.., false_fn=cond(c[1], ...)) 3689 for predicate, action in reversed(list(zip(predicates, actions))): 3690 fn = functools.partial( 3691 cond, predicate, true_fn=action, false_fn=fn, strict=strict) 3692 if exclusive: 3693 with ops.control_dependencies([ 3694 _assert_at_most_n_true( 3695 predicates, n=1, msg="Input error: exclusive=True") 3696 ]): 3697 return fn() 3698 else: 3699 return fn() 3700 3701 3702 class XLAControlFlowContext(ControlFlowContext): 3703 """Base class for XLA and TPU control flow contexts.""" 3704 3705 def __init__(self): 3706 super(XLAControlFlowContext, self).__init__() 3707 self._name = "XLAControlFlowContext" 3708 3709 def IsXLAContext(self): 3710 return True 3711 3712 def AddOp(self, _): 3713 pass 3714 3715 def AddValue(self, x): 3716 return x 3717 3718 3719 def from_control_flow_context_def(context_def, import_scope=None): 3720 """Deserializes `context_def` into the appropriate ControlFlowContext. 3721 3722 Args: 3723 context_def: ControlFlowContextDef proto 3724 import_scope: Optional `string`. Name scope to add. 3725 3726 Returns: 3727 A ControlFlowContext subclass 3728 """ 3729 if context_def.HasField("cond_ctxt"): 3730 return CondContext.from_proto(context_def.cond_ctxt, 3731 import_scope=import_scope) 3732 if context_def.HasField("while_ctxt"): 3733 return WhileContext.from_proto(context_def.while_ctxt, 3734 import_scope=import_scope) 3735 raise NotImplementedError("Unknown ControlFlowContextDef field: %s" 3736 % context_def.WhichOneof("ctxt")) 3737 3738 3739 ops.register_proto_function( 3740 ops.GraphKeys.COND_CONTEXT, 3741 proto_type=control_flow_pb2.CondContextDef, 3742 to_proto=CondContext.to_proto, 3743 from_proto=CondContext.from_proto) 3744 3745 ops.register_proto_function( 3746 ops.GraphKeys.WHILE_CONTEXT, 3747 proto_type=control_flow_pb2.WhileContextDef, 3748 to_proto=WhileContext.to_proto, 3749 from_proto=WhileContext.from_proto) 3750
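# Illustrative sketch (hypothetical usage, not part of this module's API
# surface): the proto registrations above are what let control-flow contexts
# survive a MetaGraphDef round trip, roughly:
#
#   ctxt_def = while_ctxt.to_proto()  # WhileContextDef for a WhileContext
#   restored = from_control_flow_context_def(
#       control_flow_pb2.ControlFlowContextDef(while_ctxt=ctxt_def))
#   assert restored.IsWhileContext()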