1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Classes and functions used to construct graphs.""" 16 # pylint: disable=g-bad-name 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import collections 22 import copy 23 import linecache 24 import os 25 import re 26 import sys 27 import threading 28 29 import numpy as np 30 import six 31 from six.moves import xrange # pylint: disable=redefined-builtin 32 33 from tensorflow.core.framework import attr_value_pb2 34 from tensorflow.core.framework import function_pb2 35 from tensorflow.core.framework import graph_pb2 36 from tensorflow.core.framework import node_def_pb2 37 from tensorflow.core.framework import op_def_pb2 38 from tensorflow.core.framework import versions_pb2 39 from tensorflow.core.protobuf import config_pb2 40 from tensorflow.python import pywrap_tensorflow as c_api 41 from tensorflow.python.eager import context 42 from tensorflow.python.eager import core 43 from tensorflow.python.eager import tape 44 from tensorflow.python.framework import c_api_util 45 from tensorflow.python.framework import device as pydev 46 from tensorflow.python.framework import dtypes 47 from tensorflow.python.framework import errors 48 from tensorflow.python.framework import op_def_registry 49 
from tensorflow.python.framework import registry 50 from tensorflow.python.framework import tensor_shape 51 from tensorflow.python.framework import versions 52 from tensorflow.python.ops import control_flow_util 53 from tensorflow.python.platform import app 54 from tensorflow.python.platform import tf_logging as logging 55 from tensorflow.python.util import compat 56 from tensorflow.python.util import decorator_utils 57 from tensorflow.python.util import tf_contextlib 58 from tensorflow.python.util.tf_export import tf_export 59 60 61 # Temporary global switch determining if we should enable the work-in-progress 62 # calls to the C API. Currently disabled by default but can be manually enabled 63 # in code or via the environment variable. This will be removed once all 64 # functionality is supported and there's no performance penalty with it enabled. 65 _USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "0") is not "0" 66 67 68 def tensor_id(tensor): 69 """Returns a unique identifier for this Tensor.""" 70 return tensor._id # pylint: disable=protected-access 71 72 73 class _NullContextmanager(object): 74 75 def __enter__(self): 76 pass 77 78 def __exit__(self, type_arg, value_arg, traceback_arg): 79 return False # False values do not suppress exceptions 80 81 82 def _override_helper(clazz_object, operator, func): 83 """Overrides (string) operator on Tensors to call func. 84 85 Args: 86 clazz_object: the class to override for; either Tensor or SparseTensor. 87 operator: the string name of the operator to override. 88 func: the function that replaces the overridden operator. 89 90 Raises: 91 ValueError: If operator has already been overwritten, 92 or if operator is not allowed to be overwritten. 93 """ 94 existing = getattr(clazz_object, operator, None) 95 if existing is not None: 96 # Check to see if this is a default method-wrapper or slot wrapper which 97 # will be true for the comparison operators. 
98 if not isinstance(existing, type(object.__lt__)): 99 raise ValueError("operator %s cannot be overwritten again on class %s." % 100 (operator, clazz_object)) 101 if operator not in Tensor.OVERLOADABLE_OPERATORS: 102 raise ValueError("Overriding %s is disallowed" % operator) 103 setattr(clazz_object, operator, func) 104 105 106 def _as_graph_element(obj): 107 """Convert `obj` to a graph element if possible, otherwise return `None`. 108 109 Args: 110 obj: Object to convert. 111 112 Returns: 113 The result of `obj._as_graph_element()` if that method is available; 114 otherwise `None`. 115 """ 116 conv_fn = getattr(obj, "_as_graph_element", None) 117 if conv_fn and callable(conv_fn): 118 return conv_fn() 119 return None 120 121 122 _TENSOR_LIKE_TYPES = tuple() 123 124 125 def is_dense_tensor_like(t): 126 """EXPERIMENTAL: Returns true if `t` implements the tensor interface. 127 128 See `register_dense_tensor_like_type()` for the current definition of a 129 "tensor-like type". 130 131 Args: 132 t: An object. 133 134 Returns: 135 True iff `t` is an instance of one of the registered "tensor-like" types. 136 """ 137 return isinstance(t, _TENSOR_LIKE_TYPES) 138 139 140 def register_dense_tensor_like_type(tensor_type): 141 """EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface. 142 143 A "tensor-like type" can represent a single dense tensor, and implements 144 the `name` and `dtype` properties. 145 146 Args: 147 tensor_type: A type implementing the tensor interface. 148 149 Raises: 150 TypeError: If `tensor_type` does not implement the tensor interface. 
151 """ 152 try: 153 if not isinstance(tensor_type.name, property): 154 raise TypeError("Type %s does not define a `name` property" % 155 tensor_type.__name__) 156 except AttributeError: 157 raise TypeError("Type %s does not define a `name` property" % 158 tensor_type.__name__) 159 try: 160 if not isinstance(tensor_type.dtype, property): 161 raise TypeError("Type %s does not define a `dtype` property" % 162 tensor_type.__name__) 163 except AttributeError: 164 raise TypeError("Type %s does not define a `dtype` property" % 165 tensor_type.__name__) 166 # We expect this list to be small, so choose quadratic complexity 167 # for registration, so that we have a tuple that can be used for 168 # more efficient `isinstance` checks later. 169 global _TENSOR_LIKE_TYPES 170 _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type]) 171 172 173 def uid(): 174 """A unique (within this program execution) integer.""" 175 return c_api.TFE_Py_UID() 176 177 178 def numpy_text(tensor, is_repr=False): 179 """Human readable representation of a tensor's numpy value.""" 180 if tensor.dtype.is_numpy_compatible: 181 text = repr(tensor.numpy()) if is_repr else str(tensor.numpy()) 182 else: 183 text = "<unprintable>" 184 if "\n" in text: 185 text = "\n" + text 186 return text 187 188 189 # NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose. 190 class _TensorLike(object): 191 """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance.""" 192 pass 193 194 195 @tf_export("Tensor") 196 class Tensor(_TensorLike): 197 """Represents one of the outputs of an `Operation`. 198 199 A `Tensor` is a symbolic handle to one of the outputs of an 200 `Operation`. It does not hold the values of that operation's output, 201 but instead provides a means of computing those values in a 202 TensorFlow @{tf.Session}. 203 204 This class has two primary purposes: 205 206 1. A `Tensor` can be passed as an input to another `Operation`. 
207 This builds a dataflow connection between operations, which 208 enables TensorFlow to execute an entire `Graph` that represents a 209 large, multi-step computation. 210 211 2. After the graph has been launched in a session, the value of the 212 `Tensor` can be computed by passing it to 213 @{tf.Session.run}. 214 `t.eval()` is a shortcut for calling 215 `tf.get_default_session().run(t)`. 216 217 In the following example, `c`, `d`, and `e` are symbolic `Tensor` 218 objects, whereas `result` is a numpy array that stores a concrete 219 value: 220 221 ```python 222 # Build a dataflow graph. 223 c = tf.constant([[1.0, 2.0], [3.0, 4.0]]) 224 d = tf.constant([[1.0, 1.0], [0.0, 1.0]]) 225 e = tf.matmul(c, d) 226 227 # Construct a `Session` to execute the graph. 228 sess = tf.Session() 229 230 # Execute the graph and store the value that `e` represents in `result`. 231 result = sess.run(e) 232 ``` 233 """ 234 235 # List of Python operators that we allow to override. 236 OVERLOADABLE_OPERATORS = { 237 # Binary. 238 "__add__", 239 "__radd__", 240 "__sub__", 241 "__rsub__", 242 "__mul__", 243 "__rmul__", 244 "__div__", 245 "__rdiv__", 246 "__truediv__", 247 "__rtruediv__", 248 "__floordiv__", 249 "__rfloordiv__", 250 "__mod__", 251 "__rmod__", 252 "__lt__", 253 "__le__", 254 "__gt__", 255 "__ge__", 256 "__and__", 257 "__rand__", 258 "__or__", 259 "__ror__", 260 "__xor__", 261 "__rxor__", 262 "__getitem__", 263 "__pow__", 264 "__rpow__", 265 # Unary. 266 "__invert__", 267 "__neg__", 268 "__abs__", 269 "__matmul__", 270 "__rmatmul__" 271 } 272 273 def __init__(self, op, value_index, dtype): 274 """Creates a new `Tensor`. 275 276 Args: 277 op: An `Operation`. `Operation` that computes this tensor. 278 value_index: An `int`. Index of the operation's endpoint that produces 279 this tensor. 280 dtype: A `DType`. Type of elements stored in this tensor. 281 282 Raises: 283 TypeError: If the op is not an `Operation`. 
284 """ 285 if not isinstance(op, Operation): 286 raise TypeError("op needs to be an Operation: %s" % op) 287 self._op = op 288 self._value_index = value_index 289 self._dtype = dtypes.as_dtype(dtype) 290 self._shape_val = tensor_shape.unknown_shape() 291 # List of operations that use this Tensor as input. We maintain this list 292 # to easily navigate a computation graph. 293 self._consumers = [] 294 295 # Attributes used for C++ shape inference. Not inspected, only forwarded. 296 # If set, will be a HandleData object from cpp_shape_inference.proto. 297 self._handle_data = None 298 self._id = uid() 299 300 @property 301 def op(self): 302 """The `Operation` that produces this tensor as an output.""" 303 return self._op 304 305 @property 306 def dtype(self): 307 """The `DType` of elements in this tensor.""" 308 return self._dtype 309 310 @property 311 def graph(self): 312 """The `Graph` that contains this tensor.""" 313 return self._op.graph 314 315 @property 316 def name(self): 317 """The string name of this tensor.""" 318 if not self._op.name: 319 raise ValueError("Operation was not named: %s" % self._op) 320 return "%s:%d" % (self._op.name, self._value_index) 321 322 @property 323 def device(self): 324 """The name of the device on which this tensor will be produced, or None.""" 325 return self._op.device 326 327 @property 328 def shape(self): 329 """Returns the `TensorShape` that represents the shape of this tensor. 330 331 The shape is computed using shape inference functions that are 332 registered in the Op for each `Operation`. See 333 @{tf.TensorShape} 334 for more details of what a shape represents. 335 336 The inferred shape of a tensor is used to provide shape 337 information without having to launch the graph in a session. This 338 can be used for debugging, and providing early error messages. 
For 339 example: 340 341 ```python 342 c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 343 344 print(c.shape) 345 ==> TensorShape([Dimension(2), Dimension(3)]) 346 347 d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]) 348 349 print(d.shape) 350 ==> TensorShape([Dimension(4), Dimension(2)]) 351 352 # Raises a ValueError, because `c` and `d` do not have compatible 353 # inner dimensions. 354 e = tf.matmul(c, d) 355 356 f = tf.matmul(c, d, transpose_a=True, transpose_b=True) 357 358 print(f.shape) 359 ==> TensorShape([Dimension(3), Dimension(4)]) 360 ``` 361 362 In some cases, the inferred shape may have unknown dimensions. If 363 the caller has additional information about the values of these 364 dimensions, `Tensor.set_shape()` can be used to augment the 365 inferred shape. 366 367 Returns: 368 A `TensorShape` representing the shape of this tensor. 369 370 """ 371 if _USE_C_API: 372 graph = self._op._graph._c_graph # pylint: disable=protected-access 373 with errors.raise_exception_on_not_ok_status() as status: 374 num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(), 375 status) 376 if num_dims == -1: 377 dim_list = None 378 else: 379 with errors.raise_exception_on_not_ok_status() as status: 380 dim_list = c_api.TF_GraphGetTensorShape_wrapper( 381 graph, self._as_tf_output(), num_dims, status) 382 dim_list = [None if i == -1 else i for i in dim_list] 383 return tensor_shape.TensorShape(dim_list) 384 return self._shape_val 385 386 @property 387 def _shape(self): 388 logging.warning("Tensor._shape is private, use Tensor.shape " 389 "instead. Tensor._shape will eventually be removed.") 390 return self.shape 391 392 @_shape.setter 393 def _shape(self, value): 394 raise ValueError( 395 "Tensor._shape cannot be assigned, use Tensor.set_shape instead.") 396 397 def __iter__(self): 398 if context.in_graph_mode(): 399 raise TypeError( 400 "`Tensor` objects are not iterable when eager execution is not " 401 "enabled. 
To iterate over this tensor use `tf.map_fn`.") 402 shape = self._shape_tuple() 403 if shape is None: 404 raise TypeError("Cannot iterate over a tensor with unknown shape.") 405 if not shape: 406 raise TypeError("Cannot iterate over a scalar tensor.") 407 if shape[0] is None: 408 raise TypeError( 409 "Cannot iterate over a tensor with unknown first dimension.") 410 for i in xrange(shape[0]): 411 yield self[i] 412 413 def _shape_as_list(self): 414 if self.shape.ndims is not None: 415 return [dim.value for dim in self.shape.dims] 416 else: 417 return None 418 419 def _shape_tuple(self): 420 shape = self._shape_as_list() 421 if shape is None: 422 return None 423 return tuple(shape) 424 425 def _rank(self): 426 """Integer rank of this Tensor, if known, else None. 427 428 Returns: 429 Integer rank or None 430 """ 431 return self.shape.ndims 432 433 def get_shape(self): 434 """Alias of Tensor.shape.""" 435 return self.shape 436 437 def set_shape(self, shape): 438 """Updates the shape of this tensor. 439 440 This method can be called multiple times, and will merge the given 441 `shape` with the current shape of this tensor. It can be used to 442 provide additional information about the shape of this tensor that 443 cannot be inferred from the graph alone. For example, this can be used 444 to provide additional information about the shapes of images: 445 446 ```python 447 _, image_data = tf.TFRecordReader(...).read(...) 448 image = tf.image.decode_png(image_data, channels=3) 449 450 # The height and width dimensions of `image` are data dependent, and 451 # cannot be computed without executing the op. 452 print(image.shape) 453 ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)]) 454 455 # We know that each image in this dataset is 28 x 28 pixels. 
456 image.set_shape([28, 28, 3]) 457 print(image.shape) 458 ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)]) 459 ``` 460 461 Args: 462 shape: A `TensorShape` representing the shape of this tensor, a 463 `TensorShapeProto`, a list, a tuple, or None. 464 465 Raises: 466 ValueError: If `shape` is not compatible with the current shape of 467 this tensor. 468 """ 469 if not _USE_C_API: 470 self._shape_val = self._shape_val.merge_with(shape) 471 return 472 if not isinstance(shape, tensor_shape.TensorShape): 473 shape = tensor_shape.TensorShape(shape) 474 dim_list = [] 475 if shape.dims is None: 476 unknown_shape = True 477 else: 478 unknown_shape = False 479 for dim in shape.dims: 480 if dim.value is None: 481 dim_list.append(-1) 482 else: 483 dim_list.append(dim.value) 484 try: 485 with errors.raise_exception_on_not_ok_status() as status: 486 c_api.TF_GraphSetTensorShape_wrapper( 487 self._op._graph._c_graph, # pylint: disable=protected-access 488 self._as_tf_output(), 489 dim_list, 490 unknown_shape, 491 status) 492 except errors.InvalidArgumentError as e: 493 # Convert to ValueError for backwards compatibility. 494 raise ValueError(str(e)) 495 496 @property 497 def value_index(self): 498 """The index of this tensor in the outputs of its `Operation`.""" 499 return self._value_index 500 501 def consumers(self): 502 """Returns a list of `Operation`s that consume this tensor. 503 504 Returns: 505 A list of `Operation`s. 506 """ 507 if self._op._c_op: # pylint: disable=protected-access 508 consumer_names = c_api.TF_OperationOutputConsumers_wrapper( 509 self._as_tf_output()) 510 # pylint: disable=protected-access 511 return [ 512 self.graph._get_operation_by_name_unsafe(name) 513 for name in consumer_names 514 ] 515 # pylint: enable=protected-access 516 else: 517 return self._consumers 518 519 def _add_consumer(self, consumer): 520 """Add a consumer to this tensor. 521 522 Args: 523 consumer: an Operation. 
524 525 Raises: 526 TypeError: if the consumer is not an Operation. 527 """ 528 # pylint: disable=protected-access 529 assert not self._op._c_op, "Tensor._add_consumer doesn't work with C API" 530 # pylint: enable=protected-access 531 if not isinstance(consumer, Operation): 532 raise TypeError("Consumer must be an Operation: %s" % consumer) 533 self._consumers.append(consumer) 534 535 def _as_node_def_input(self): 536 """Return a value to use for the NodeDef "input" attribute. 537 538 The returned string can be used in a NodeDef "input" attribute 539 to indicate that the NodeDef uses this Tensor as input. 540 541 Raises: 542 ValueError: if this Tensor's Operation does not have a name. 543 544 Returns: 545 a string. 546 """ 547 if not self._op.name: 548 raise ValueError("Operation was not named: %s" % self._op) 549 if self._value_index == 0: 550 return self._op.name 551 else: 552 return "%s:%d" % (self._op.name, self._value_index) 553 554 def _as_tf_output(self): 555 # pylint: disable=protected-access 556 assert self.op._c_op 557 return c_api_util.tf_output(self.op._c_op, self.value_index) 558 # pylint: enable=protected-access 559 560 def __str__(self): 561 return "Tensor(\"%s\"%s%s%s)" % ( 562 self.name, (", shape=%s" % self.get_shape()) 563 if self.get_shape().ndims is not None else "", 564 (", dtype=%s" % self._dtype.name) 565 if self._dtype else "", (", device=%s" % self.device) 566 if self.device else "") 567 568 def __repr__(self): 569 return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(), 570 self._dtype.name) 571 572 def __hash__(self): 573 # Necessary to support Python's collection membership operators 574 return id(self) 575 576 def __eq__(self, other): 577 # Necessary to support Python's collection membership operators 578 return id(self) == id(other) 579 580 # NOTE(mrry): This enables the Tensor's overloaded "right" binary 581 # operators to run when the left operand is an ndarray, because it 582 # accords the Tensor class higher 
priority than an ndarray, or a 583 # numpy matrix. 584 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 585 # mechanism, which allows more control over how Tensors interact 586 # with ndarrays. 587 __array_priority__ = 100 588 589 @staticmethod 590 def _override_operator(operator, func): 591 _override_helper(Tensor, operator, func) 592 593 def __bool__(self): 594 """Dummy method to prevent a tensor from being used as a Python `bool`. 595 596 This overload raises a `TypeError` when the user inadvertently 597 treats a `Tensor` as a boolean (e.g. in an `if` statement). For 598 example: 599 600 ```python 601 if tf.constant(True): # Will raise. 602 # ... 603 604 if tf.constant(5) < tf.constant(7): # Will raise. 605 # ... 606 ``` 607 608 This disallows ambiguities between testing the Python value vs testing the 609 dynamic condition of the `Tensor`. 610 611 Raises: 612 `TypeError`. 613 """ 614 raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. " 615 "Use `if t is not None:` instead of `if t:` to test if a " 616 "tensor is defined, and use TensorFlow ops such as " 617 "tf.cond to execute subgraphs conditioned on the value of " 618 "a tensor.") 619 620 def __nonzero__(self): 621 """Dummy method to prevent a tensor from being used as a Python `bool`. 622 623 This is the Python 2.x counterpart to `__bool__()` above. 624 625 Raises: 626 `TypeError`. 627 """ 628 raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. " 629 "Use `if t is not None:` instead of `if t:` to test if a " 630 "tensor is defined, and use TensorFlow ops such as " 631 "tf.cond to execute subgraphs conditioned on the value of " 632 "a tensor.") 633 634 def eval(self, feed_dict=None, session=None): 635 """Evaluates this tensor in a `Session`. 636 637 Calling this method will execute all preceding operations that 638 produce the inputs needed for the operation that produces this 639 tensor. 
640 641 *N.B.* Before invoking `Tensor.eval()`, its graph must have been 642 launched in a session, and either a default session must be 643 available, or `session` must be specified explicitly. 644 645 Args: 646 feed_dict: A dictionary that maps `Tensor` objects to feed values. 647 See @{tf.Session.run} for a 648 description of the valid feed values. 649 session: (Optional.) The `Session` to be used to evaluate this tensor. If 650 none, the default session will be used. 651 652 Returns: 653 A numpy array corresponding to the value of this tensor. 654 655 """ 656 return _eval_using_default_session(self, feed_dict, self.graph, session) 657 658 659 # TODO(agarwal): consider getting rid of this. 660 class _EagerTensorBase(Tensor): 661 """Base class for EagerTensor.""" 662 663 @property 664 def dtype(self): 665 # Note: using the intern table directly here as this is 666 # performance-sensitive in some models. 667 return dtypes._INTERN_TABLE[self._datatype_enum()] # pylint: disable=protected-access 668 669 def numpy(self): 670 """Returns a numpy array or a scalar with the same contents as the Tensor. 671 672 TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying 673 buffer but instead always explicitly copy? Note that currently it may or may 674 not copy based on whether the numpy data is properly aligned or not. 675 676 Returns: 677 A numpy array or a scalar. Numpy array may share memory with the 678 Tensor object. Any changes to one may be reflected in the other. A scalar 679 value is returned when self has rank 0. 680 681 Raises: 682 ValueError: if the type of this Tensor is not representable in numpy. 683 """ 684 if self.dtype == dtypes.resource: 685 raise ValueError("Resource handles are not convertible to numpy.") 686 return self.cpu()._numpy() # pylint: disable=protected-access 687 688 # __int__ and __float__ may copy the tensor to CPU and 689 # only work for scalars; values are cast as per numpy. 
690 def __int__(self): 691 return int(self.numpy()) 692 693 def __float__(self): 694 return float(self.numpy()) 695 696 def __array__(self, dtype=None): 697 return np.array(self.numpy(), dtype=dtype) 698 699 def __format__(self, format_spec): 700 return self.numpy().__format__(format_spec) 701 702 def _numpy(self): 703 raise NotImplementedError() 704 705 def __copy__(self): 706 # Eager Tensors are immutable so it's safe to return themselves as a copy. 707 return self 708 709 def __deepcopy__(self, memo): 710 # Eager Tensors are immutable so it's safe to return themselves as a copy. 711 del memo 712 return self 713 714 def _datatype_enum(self): 715 raise NotImplementedError() 716 717 def _shape_tuple(self): 718 """The shape of this Tensor, as a tuple. 719 720 This is more performant than tuple(shape().as_list()) as it avoids 721 two list and one object creation. Marked private for now as from an API 722 perspective, it would be better to have a single performant way of 723 getting a shape rather than exposing shape() and shape_tuple() 724 (and heaven forbid, shape_list() etc. as well!). Punting on that for now, 725 but ideally one would work things out and remove the need for this method. 726 727 Returns: 728 tuple with the shape. 729 """ 730 raise NotImplementedError() 731 732 def _rank(self): 733 """Integer rank of this Tensor. 734 735 Unlike regular Tensors, the rank is always known for EagerTensors. 
736 737 This is more performant than len(self._shape_tuple()) 738 739 Returns: 740 Integer rank 741 """ 742 raise NotImplementedError() 743 744 def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name 745 raise NotImplementedError() 746 747 def __str__(self): 748 return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self), 749 self.shape, 750 self.dtype.name) 751 752 def __repr__(self): 753 return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % ( 754 self._id, self.shape, self.dtype.name, numpy_text(self, is_repr=True)) 755 756 @staticmethod 757 def _override_operator(name, func): 758 setattr(_EagerTensorBase, name, func) 759 760 def _copy(self, ctx=None, device_name=None): 761 """Copies tensor to dest device.""" 762 # pylint: disable=protected-access 763 # Creates a new tensor on the dest device. 764 if ctx is None: 765 ctx = context.context() 766 if device_name is None: 767 device_name = ctx.device_name 768 # pylint: disable=protected-access 769 try: 770 new_tensor = self._copy_to_device(context=ctx._handle, device=device_name) 771 except core._NotOkStatusException as e: 772 six.raise_from(core._status_to_exception(e.code, e.message), None) 773 774 # Record the copy on tape and define backprop copy as well. 
775 if not context.in_graph_mode(): 776 self_device = self.device 777 def grad_fun(dresult): 778 return [dresult._copy(device_name=self_device)] 779 tape.record_operation("_copy", [new_tensor], [self], grad_fun) 780 return new_tensor 781 # pylint: enable=protected-access 782 783 @property 784 def shape(self): 785 return tensor_shape.TensorShape(self._shape_tuple()) 786 787 def get_shape(self): 788 """Alias of Tensor.shape.""" 789 return self.shape 790 791 def _shape_as_list(self): 792 """The shape of the tensor as a list.""" 793 return list(self._shape_tuple()) 794 795 @property 796 def ndim(self): 797 """Returns the number of Tensor dimensions.""" 798 return self.shape.ndims 799 800 def cpu(self): 801 """A copy of this Tensor with contents backed by host memory.""" 802 return self._copy(context.context(), "CPU:0") 803 804 def gpu(self, gpu_index=0): 805 """A copy of this Tensor with contents backed by memory on the GPU. 806 807 Arguments: 808 gpu_index: Identifies which GPU to place the contents on the returned 809 Tensor in. 810 811 Returns: 812 A GPU-memory backed Tensor object initialized with the same contents 813 as this Tensor. 814 """ 815 return self._copy(context.context(), "GPU:" + str(gpu_index)) 816 817 def __bool__(self): 818 if self._shape_tuple() != (): # pylint: disable=g-explicit-bool-comparison 819 raise ValueError( 820 "Non-scalar tensor %s cannot be converted to boolean." % repr(self)) 821 if self.dtype != dtypes.bool: 822 raise ValueError( 823 "Non-boolean tensor %s cannot be converted to boolean." % repr(self)) 824 return bool(self.cpu().numpy()) 825 826 def __nonzero__(self): 827 return self.__bool__() 828 829 def set_shape(self, shape): 830 if not self.shape.is_compatible_with(shape): 831 raise ValueError( 832 "EagerTensor's shape %s is not compatible with supplied shape %s" % 833 (self.shape, shape)) 834 835 # Methods not supported / implemented for Eager Tensors. 
836 @property 837 def op(self): 838 raise AttributeError("op not supported for Eager Tensors.") 839 840 @property 841 def graph(self): 842 raise AttributeError("graph not supported for Eager Tensors.") 843 844 @property 845 def name(self): 846 raise AttributeError("name not supported for Eager Tensors.") 847 848 @property 849 def value_index(self): 850 raise AttributeError("value_index not supported for Eager Tensors.") 851 852 def consumers(self): 853 raise NotImplementedError("consumers not supported for Eager Tensors.") 854 855 def _add_consumer(self, consumer): 856 raise NotImplementedError("_add_consumer not supported for Eager Tensors.") 857 858 def _as_node_def_input(self): 859 raise NotImplementedError( 860 "_as_node_def_input not supported for Eager Tensors.") 861 862 def _as_tf_output(self): 863 raise NotImplementedError("_as_tf_output not supported for Eager Tensors.") 864 865 def eval(self, feed_dict=None, session=None): 866 raise NotImplementedError("eval not supported for Eager Tensors.") 867 868 869 # This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and 870 # registers it with the current module. 871 EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase) 872 873 874 def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False): 875 _ = name, as_ref 876 if dtype and not dtype.is_compatible_with(t.dtype): 877 raise ValueError( 878 "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" % 879 (dtype.name, t.dtype.name, str(t))) 880 return t 881 882 883 _tensor_conversion_func_registry = { 884 0: [(Tensor, _TensorTensorConversionFunction)] 885 } 886 _tensor_conversion_func_cache = {} 887 _tensor_conversion_func_lock = threading.Lock() 888 register_dense_tensor_like_type(Tensor) 889 890 891 @tf_export("convert_to_tensor") 892 def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None): 893 """Converts the given `value` to a `Tensor`. 
894 895 This function converts Python objects of various types to `Tensor` 896 objects. It accepts `Tensor` objects, numpy arrays, Python lists, 897 and Python scalars. For example: 898 899 ```python 900 import numpy as np 901 902 def my_func(arg): 903 arg = tf.convert_to_tensor(arg, dtype=tf.float32) 904 return tf.matmul(arg, arg) + arg 905 906 # The following calls are equivalent. 907 value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]])) 908 value_2 = my_func([[1.0, 2.0], [3.0, 4.0]]) 909 value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)) 910 ``` 911 912 This function can be useful when composing a new operation in Python 913 (such as `my_func` in the example above). All standard Python op 914 constructors apply this function to each of their Tensor-valued 915 inputs, which allows those ops to accept numpy arrays, Python lists, 916 and scalars in addition to `Tensor` objects. 917 918 Note: This function diverges from default Numpy behavior for `float` and 919 `string` types when `None` is present in a Python list or scalar. Rather 920 than silently converting `None` values, an error will be thrown. 921 922 Args: 923 value: An object whose type has a registered `Tensor` conversion function. 924 dtype: Optional element type for the returned tensor. If missing, the 925 type is inferred from the type of `value`. 926 name: Optional name to use if a new `Tensor` is created. 927 preferred_dtype: Optional element type for the returned tensor, 928 used when dtype is None. In some cases, a caller may not have a 929 dtype in mind when converting to a tensor, so preferred_dtype 930 can be used as a soft preference. If the conversion to 931 `preferred_dtype` is not possible, this argument has no effect. 932 933 Returns: 934 An `Output` based on `value`. 935 936 Raises: 937 TypeError: If no conversion function is registered for `value`. 938 RuntimeError: If a registered conversion function returns an invalid value. 
939 940 """ 941 return internal_convert_to_tensor( 942 value=value, 943 dtype=dtype, 944 name=name, 945 preferred_dtype=preferred_dtype, 946 as_ref=False) 947 948 949 def _error_prefix(name): 950 return "" if name is None else "%s: " % name 951 952 953 def internal_convert_to_tensor(value, 954 dtype=None, 955 name=None, 956 as_ref=False, 957 preferred_dtype=None, 958 ctx=None): 959 """Converts the given `value` to an `Tensor`. 960 961 This function converts Python objects of various types to `Tensor` 962 objects. It accepts `Tensor` objects, numpy arrays, Python lists, 963 and Python scalars. For example: 964 965 This function can be useful when composing a new operation in Python 966 All standard Python op constructors apply this function to each of their 967 Tensor-valued inputs, which allows those ops to accept numpy arrays, Python 968 lists, and scalars in addition to `Tensor` objects. 969 970 Args: 971 value: An object whose type has a registered `Tensor` conversion function. 972 dtype: Optional element type for the returned tensor. If missing, the 973 type is inferred from the type of `value`. 974 name: Optional name to use if a new `Tensor` is created. 975 as_ref: True if we want the mutable view of Variables, if applicable. 976 preferred_dtype: Optional element type for the returned tensor, 977 used when dtype is None. In some cases, a caller may not have a 978 dtype in mind when converting to a tensor, so preferred_dtype 979 can be used as a soft preference. If the conversion to 980 `preferred_dtype` is not possible, this argument has no effect. 981 ctx: Optional: The value of context.context(). 982 983 Returns: 984 A `Tensor` based on `value`. 985 986 Raises: 987 TypeError: If no conversion function is registered for `value`. 988 RuntimeError: If a registered conversion function returns an invalid value. 989 990 """ 991 if ctx is None: ctx = context.context() 992 if ctx.in_eager_mode(): 993 # Fast path for EagerTensors that don't need any conversion. 
994 if isinstance(value, EagerTensor): 995 # Note that we don't check that value's dtype matches the dtype 996 # argument. We expect that the C runtime will do that checking 997 # when we execute the kernel. 998 return value 999 1000 if dtype is not None: 1001 dtype = dtypes.as_dtype(dtype) 1002 unwrapped_type = type(value) 1003 conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None) 1004 if conversion_func_list is None: 1005 with _tensor_conversion_func_lock: 1006 conversion_func_list = [] 1007 for _, funcs_at_priority in sorted( 1008 _tensor_conversion_func_registry.items()): 1009 for base_type, conversion_func in funcs_at_priority: 1010 if isinstance(value, base_type): 1011 conversion_func_list.append((base_type, conversion_func)) 1012 _tensor_conversion_func_cache[unwrapped_type] = conversion_func_list 1013 1014 for base_type, conversion_func in conversion_func_list: 1015 # If dtype is None but preferred_dtype is not None, we try to 1016 # cast to preferred_dtype first. 1017 ret = None 1018 if dtype is None and preferred_dtype is not None: 1019 try: 1020 ret = conversion_func( 1021 value, dtype=preferred_dtype, name=name, as_ref=as_ref) 1022 except (TypeError, ValueError, errors.UnimplementedError, 1023 errors.InvalidArgumentError): 1024 # Could not coerce the conversion to use the preferred dtype. 
1025 ret = None 1026 1027 if ret is not None and ret is not NotImplemented: 1028 if (ret.dtype.base_dtype != 1029 dtypes.as_dtype(preferred_dtype).base_dtype): 1030 raise TypeError("convert_to_tensor did not convert to " 1031 "the preferred dtype: %s vs %s " % 1032 (ret.dtype.base_dtype, 1033 dtypes.as_dtype(preferred_dtype).base_dtype)) 1034 1035 if ret is None: 1036 ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) 1037 1038 if ret is NotImplemented: 1039 continue 1040 1041 if not isinstance(ret, Tensor): 1042 raise RuntimeError( 1043 "%sConversion function %r for type %s returned non-Tensor: %r" % 1044 (_error_prefix(name), conversion_func, base_type, ret)) 1045 if dtype and not dtype.is_compatible_with(ret.dtype): 1046 raise RuntimeError( 1047 "%sConversion function %r for type %s returned incompatible " 1048 "dtype: requested = %s, actual = %s" % 1049 (_error_prefix(name), conversion_func, base_type, dtype.name, 1050 ret.dtype.name)) 1051 return ret 1052 raise TypeError("%sCannot convert %r with type %s to Tensor: " 1053 "no conversion function registered." % 1054 (_error_prefix(name), value, unwrapped_type)) 1055 1056 1057 def internal_convert_n_to_tensor(values, 1058 dtype=None, 1059 name=None, 1060 as_ref=False, 1061 preferred_dtype=None, 1062 ctx=None): 1063 """Converts `values` to a list of `Tensor` objects. 1064 1065 Args: 1066 values: A list of objects that can be consumed by `tf.convert_to_tensor()`. 1067 dtype: (Optional.) The required `DType` of the returned `Tensor` objects. 1068 name: (Optional.) A name prefix to used when a new `Tensor` is 1069 created, in which case element `i` will be given the name `name 1070 + '_' + i`. 1071 as_ref: True if the caller wants the results as ref tensors. 1072 preferred_dtype: Optional element type for the returned tensors, 1073 used when dtype is None. In some cases, a caller may not have a 1074 dtype in mind when converting to a tensor, so preferred_dtype 1075 can be used as a soft preference. 
If the conversion to 1076 `preferred_dtype` is not possible, this argument has no effect. 1077 ctx: The value of context.context(). 1078 1079 Returns: 1080 A list of `Tensor` and/or `IndexedSlices` objects. 1081 1082 Raises: 1083 TypeError: If no conversion function is registered for an element in 1084 `values`. 1085 RuntimeError: If a registered conversion function returns an invalid 1086 value. 1087 """ 1088 if not isinstance(values, collections.Sequence): 1089 raise TypeError("values must be a list.") 1090 ret = [] 1091 if ctx is None: ctx = context.context() 1092 for i, value in enumerate(values): 1093 n = None if name is None else "%s_%d" % (name, i) 1094 ret.append( 1095 internal_convert_to_tensor( 1096 value, 1097 dtype=dtype, 1098 name=n, 1099 as_ref=as_ref, 1100 preferred_dtype=preferred_dtype, 1101 ctx=ctx)) 1102 return ret 1103 1104 1105 def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None): 1106 """Converts `values` to a list of `Tensor` objects. 1107 1108 Args: 1109 values: A list of objects that can be consumed by `tf.convert_to_tensor()`. 1110 dtype: (Optional.) The required `DType` of the returned `Tensor` objects. 1111 name: (Optional.) A name prefix to used when a new `Tensor` is 1112 created, in which case element `i` will be given the name `name 1113 + '_' + i`. 1114 preferred_dtype: Optional element type for the returned tensors, 1115 used when dtype is None. In some cases, a caller may not have a 1116 dtype in mind when converting to a tensor, so preferred_dtype 1117 can be used as a soft preference. If the conversion to 1118 `preferred_dtype` is not possible, this argument has no effect. 1119 1120 Returns: 1121 A list of `Tensor` and/or `IndexedSlices` objects. 1122 1123 Raises: 1124 TypeError: If no conversion function is registered for an element in 1125 `values`. 1126 RuntimeError: If a registered conversion function returns an invalid 1127 value. 
1128 """ 1129 return internal_convert_n_to_tensor( 1130 values=values, 1131 dtype=dtype, 1132 name=name, 1133 preferred_dtype=preferred_dtype, 1134 as_ref=False) 1135 1136 1137 @tf_export("convert_to_tensor_or_indexed_slices") 1138 def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None): 1139 """Converts the given object to a `Tensor` or an `IndexedSlices`. 1140 1141 If `value` is an `IndexedSlices` or `SparseTensor` it is returned 1142 unmodified. Otherwise, it is converted to a `Tensor` using 1143 `convert_to_tensor()`. 1144 1145 Args: 1146 value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed 1147 by `convert_to_tensor()`. 1148 dtype: (Optional.) The required `DType` of the returned `Tensor` or 1149 `IndexedSlices`. 1150 name: (Optional.) A name to use if a new `Tensor` is created. 1151 1152 Returns: 1153 An `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`. 1154 1155 Raises: 1156 ValueError: If `dtype` does not match the element type of `value`. 1157 """ 1158 return internal_convert_to_tensor_or_indexed_slices( 1159 value=value, dtype=dtype, name=name, as_ref=False) 1160 1161 1162 def internal_convert_to_tensor_or_indexed_slices(value, 1163 dtype=None, 1164 name=None, 1165 as_ref=False): 1166 """Converts the given object to an `Tensor` or an `IndexedSlices`. 1167 1168 If `value` is an `IndexedSlices` or `SparseTensor` it is returned 1169 unmodified. Otherwise, it is converted to a `Tensor` using 1170 `convert_to_tensor()`. 1171 1172 Args: 1173 value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed 1174 by `convert_to_tensor()`. 1175 dtype: (Optional.) The required `DType` of the returned `Tensor` or 1176 `IndexedSlices`. 1177 name: (Optional.) A name to use if a new `Tensor` is created. 1178 as_ref: True if the caller wants the results as ref tensors. 1179 1180 Returns: 1181 An `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`. 
1182 1183 Raises: 1184 ValueError: If `dtype` does not match the element type of `value`. 1185 """ 1186 if isinstance(value, _TensorLike): 1187 if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype): 1188 raise ValueError( 1189 "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" % 1190 (dtypes.as_dtype(dtype).name, value.dtype.name, str(value))) 1191 return value 1192 else: 1193 return internal_convert_to_tensor( 1194 value, dtype=dtype, name=name, as_ref=as_ref) 1195 1196 1197 def internal_convert_n_to_tensor_or_indexed_slices(values, 1198 dtype=None, 1199 name=None, 1200 as_ref=False): 1201 """Converts `values` to a list of `Tensor` or `IndexedSlices` objects. 1202 1203 Any `IndexedSlices` or `SparseTensor` objects in `values` are returned 1204 unmodified. 1205 1206 Args: 1207 values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that 1208 can be consumed by `convert_to_tensor()`. 1209 dtype: (Optional.) The required `DType` of the returned `Tensor` 1210 `IndexedSlices`. 1211 name: (Optional.) A name prefix to used when a new `Tensor` is 1212 created, in which case element `i` will be given the name `name 1213 + '_' + i`. 1214 as_ref: True if the caller wants the results as ref tensors. 1215 1216 Returns: 1217 A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects. 1218 1219 Raises: 1220 TypeError: If no conversion function is registered for an element in 1221 `values`. 1222 RuntimeError: If a registered conversion function returns an invalid 1223 value. 
1224 """ 1225 if not isinstance(values, collections.Sequence): 1226 raise TypeError("values must be a list.") 1227 ret = [] 1228 for i, value in enumerate(values): 1229 if value is None: 1230 ret.append(value) 1231 else: 1232 n = None if name is None else "%s_%d" % (name, i) 1233 ret.append( 1234 internal_convert_to_tensor_or_indexed_slices( 1235 value, dtype=dtype, name=n, as_ref=as_ref)) 1236 return ret 1237 1238 1239 def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None): 1240 """Converts `values` to a list of `Output` or `IndexedSlices` objects. 1241 1242 Any `IndexedSlices` or `SparseTensor` objects in `values` are returned 1243 unmodified. 1244 1245 Args: 1246 values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that 1247 can be consumed by `convert_to_tensor()`. 1248 dtype: (Optional.) The required `DType` of the returned `Tensor` 1249 `IndexedSlices`. 1250 name: (Optional.) A name prefix to used when a new `Tensor` is 1251 created, in which case element `i` will be given the name `name 1252 + '_' + i`. 1253 1254 Returns: 1255 A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects. 1256 1257 Raises: 1258 TypeError: If no conversion function is registered for an element in 1259 `values`. 1260 RuntimeError: If a registered conversion function returns an invalid 1261 value. 1262 """ 1263 return internal_convert_n_to_tensor_or_indexed_slices( 1264 values=values, dtype=dtype, name=name, as_ref=False) 1265 1266 1267 # TODO(josh11b): Add ctx argument to conversion_func() signature. 1268 @tf_export("register_tensor_conversion_function") 1269 def register_tensor_conversion_function(base_type, 1270 conversion_func, 1271 priority=100): 1272 """Registers a function for converting objects of `base_type` to `Tensor`. 1273 1274 The conversion function must have the following signature: 1275 1276 ```python 1277 def conversion_func(value, dtype=None, name=None, as_ref=False): 1278 # ... 
1279 ``` 1280 1281 It must return a `Tensor` with the given `dtype` if specified. If the 1282 conversion function creates a new `Tensor`, it should use the given 1283 `name` if specified. All exceptions will be propagated to the caller. 1284 1285 The conversion function may return `NotImplemented` for some 1286 inputs. In this case, the conversion process will continue to try 1287 subsequent conversion functions. 1288 1289 If `as_ref` is true, the function must return a `Tensor` reference, 1290 such as a `Variable`. 1291 1292 NOTE: The conversion functions will execute in order of priority, 1293 followed by order of registration. To ensure that a conversion function 1294 `F` runs before another conversion function `G`, ensure that `F` is 1295 registered with a smaller priority than `G`. 1296 1297 Args: 1298 base_type: The base type or tuple of base types for all objects that 1299 `conversion_func` accepts. 1300 conversion_func: A function that converts instances of `base_type` to 1301 `Tensor`. 1302 priority: Optional integer that indicates the priority for applying this 1303 conversion function. Conversion functions with smaller priority values 1304 run earlier than conversion functions with larger priority values. 1305 Defaults to 100. 1306 1307 Raises: 1308 TypeError: If the arguments do not have the appropriate type. 
1309 1310 """ 1311 global _tensor_conversion_func_cache 1312 with _tensor_conversion_func_lock: 1313 if not (isinstance(base_type, type) or 1314 (isinstance(base_type, tuple) and 1315 all(isinstance(x, type) for x in base_type))): 1316 raise TypeError("base_type must be a type or a tuple of types.") 1317 if not callable(conversion_func): 1318 raise TypeError("conversion_func must be callable.") 1319 1320 try: 1321 funcs_at_priority = _tensor_conversion_func_registry[priority] 1322 except KeyError: 1323 funcs_at_priority = [] 1324 _tensor_conversion_func_registry[priority] = funcs_at_priority 1325 funcs_at_priority.append((base_type, conversion_func)) 1326 _tensor_conversion_func_cache = {} 1327 1328 1329 @tf_export("IndexedSlices") 1330 class IndexedSlices(_TensorLike): 1331 """A sparse representation of a set of tensor slices at given indices. 1332 1333 This class is a simple wrapper for a pair of `Tensor` objects: 1334 1335 * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`. 1336 * `indices`: A 1-D integer `Tensor` with shape `[D0]`. 1337 1338 An `IndexedSlices` is typically used to represent a subset of a larger 1339 tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`. 1340 The values in `indices` are the indices in the first dimension of 1341 the slices that have been extracted from the larger tensor. 1342 1343 The dense tensor `dense` represented by an `IndexedSlices` `slices` has 1344 1345 ```python 1346 dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...] 1347 ``` 1348 1349 The `IndexedSlices` class is used principally in the definition of 1350 gradients for operations that have sparse gradients 1351 (e.g. @{tf.gather}). 1352 1353 Contrast this representation with 1354 @{tf.SparseTensor}, 1355 which uses multi-dimensional indices and scalar values. 
1356 """ 1357 1358 def __init__(self, values, indices, dense_shape=None): 1359 """Creates an `IndexedSlices`.""" 1360 _get_graph_from_inputs([values, indices, dense_shape]) 1361 self._values = values 1362 self._indices = indices 1363 self._dense_shape = dense_shape 1364 1365 @property 1366 def values(self): 1367 """A `Tensor` containing the values of the slices.""" 1368 return self._values 1369 1370 @property 1371 def indices(self): 1372 """A 1-D `Tensor` containing the indices of the slices.""" 1373 return self._indices 1374 1375 @property 1376 def dense_shape(self): 1377 """A 1-D `Tensor` containing the shape of the corresponding dense tensor.""" 1378 return self._dense_shape 1379 1380 @property 1381 def name(self): 1382 """The name of this `IndexedSlices`.""" 1383 return self.values.name 1384 1385 @property 1386 def device(self): 1387 """The name of the device on which `values` will be produced, or `None`.""" 1388 return self.values.device 1389 1390 @property 1391 def op(self): 1392 """The `Operation` that produces `values` as an output.""" 1393 return self.values.op 1394 1395 @property 1396 def dtype(self): 1397 """The `DType` of elements in this tensor.""" 1398 return self.values.dtype 1399 1400 @property 1401 def graph(self): 1402 """The `Graph` that contains the values, indices, and shape tensors.""" 1403 return self._values.graph 1404 1405 def __str__(self): 1406 return "IndexedSlices(indices=%s, values=%s%s)" % ( 1407 self._indices, self._values, (", dense_shape=%s" % self._dense_shape) 1408 if self._dense_shape is not None else "") 1409 1410 def __neg__(self): 1411 return IndexedSlices(-self.values, self.indices, self.dense_shape) 1412 1413 1414 IndexedSlicesValue = collections.namedtuple( 1415 "IndexedSlicesValue", ["values", "indices", "dense_shape"]) 1416 1417 1418 def _device_string(dev_spec): 1419 if isinstance(dev_spec, pydev.DeviceSpec): 1420 return dev_spec.to_string() 1421 else: 1422 return dev_spec 1423 1424 1425 def _NodeDef(op_type, name, 
device=None, attrs=None): # pylint: disable=redefined-outer-name 1426 """Create a NodeDef proto. 1427 1428 Args: 1429 op_type: Value for the "op" attribute of the NodeDef proto. 1430 name: Value for the "name" attribute of the NodeDef proto. 1431 device: string, device, or function from NodeDef to string. 1432 Value for the "device" attribute of the NodeDef proto. 1433 attrs: Optional dictionary where the key is the attribute name (a string) 1434 and the value is the respective "attr" attribute of the NodeDef proto (an 1435 AttrValue). 1436 1437 Returns: 1438 A node_def_pb2.NodeDef protocol buffer. 1439 """ 1440 node_def = node_def_pb2.NodeDef() 1441 node_def.op = compat.as_bytes(op_type) 1442 node_def.name = compat.as_bytes(name) 1443 if attrs is not None: 1444 for k, v in six.iteritems(attrs): 1445 node_def.attr[k].CopyFrom(v) 1446 if device is not None: 1447 if callable(device): 1448 node_def.device = device(node_def) 1449 else: 1450 node_def.device = _device_string(device) 1451 return node_def 1452 1453 1454 # Copied from core/framework/node_def_util.cc 1455 # TODO(mrry,josh11b): Consolidate this validation in C++ code. 1456 _VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$") 1457 _VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$") 1458 1459 1460 def _create_c_op(graph, node_def, inputs, control_inputs): 1461 """Creates a TF_Operation. 1462 1463 Args: 1464 graph: a `Graph`. 1465 node_def: `node_def_pb2.NodeDef` for the operation to create. 1466 inputs: A list of `Tensor`s (corresponding to scalar inputs) and lists of 1467 `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N", 1468 "list(int64)"). The length of the list should be equal to the number of 1469 inputs specified by this operation's op def. 1470 control_inputs: A list of `Operation`s to set as control dependencies. 1471 1472 Returns: 1473 A wrapped TF_Operation*. 
1474 """ 1475 # pylint: disable=protected-access 1476 op_desc = c_api.TF_NewOperation(graph._c_graph, 1477 compat.as_str(node_def.op), 1478 compat.as_str(node_def.name)) 1479 # Add inputs 1480 for op_input in inputs: 1481 if isinstance(op_input, (list, tuple)): 1482 c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input]) 1483 else: 1484 c_api.TF_AddInput(op_desc, op_input._as_tf_output()) 1485 1486 # Add control inputs 1487 for control_input in control_inputs: 1488 c_api.TF_AddControlInput(op_desc, control_input._c_op) 1489 # pylint: enable=protected-access 1490 1491 # Add attrs 1492 for name, attr_value in node_def.attr.items(): 1493 serialized = attr_value.SerializeToString() 1494 # TODO(skyewm): this creates and deletes a new TF_Status for every attr. 1495 # It might be worth creating a convenient way to re-use the same status. 1496 with errors.raise_exception_on_not_ok_status() as status: 1497 c_api.TF_SetAttrValueProto(op_desc, 1498 compat.as_str(name), serialized, status) 1499 1500 try: 1501 with errors.raise_exception_on_not_ok_status() as status: 1502 c_op = c_api.TF_FinishOperation(op_desc, status) 1503 except errors.InvalidArgumentError as e: 1504 # Convert to ValueError for backwards compatibility. 1505 raise ValueError(str(e)) 1506 1507 return c_op 1508 1509 1510 @tf_export("Operation") 1511 class Operation(object): 1512 """Represents a graph node that performs computation on tensors. 1513 1514 An `Operation` is a node in a TensorFlow `Graph` that takes zero or 1515 more `Tensor` objects as input, and produces zero or more `Tensor` 1516 objects as output. Objects of type `Operation` are created by 1517 calling a Python op constructor (such as 1518 @{tf.matmul}) 1519 or @{tf.Graph.create_op}. 1520 1521 For example `c = tf.matmul(a, b)` creates an `Operation` of type 1522 "MatMul" that takes tensors `a` and `b` as input, and produces `c` 1523 as output. 
1524 1525 After the graph has been launched in a session, an `Operation` can 1526 be executed by passing it to 1527 @{tf.Session.run}. 1528 `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. 1529 """ 1530 1531 def __init__(self, 1532 node_def, 1533 g, 1534 inputs=None, 1535 output_types=None, 1536 control_inputs=None, 1537 input_types=None, 1538 original_op=None, 1539 op_def=None): 1540 r"""Creates an `Operation`. 1541 1542 NOTE: This constructor validates the name of the `Operation` (passed 1543 as `node_def.name`). Valid `Operation` names match the following 1544 regular expression: 1545 1546 [A-Za-z0-9.][A-Za-z0-9_.\\-/]* 1547 1548 Args: 1549 node_def: `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`. 1550 Used for attributes of `node_def_pb2.NodeDef`, typically `name`, 1551 `op`, and `device`. The `input` attribute is irrelevant here 1552 as it will be computed when generating the model. 1553 g: `Graph`. The parent graph. 1554 inputs: list of `Tensor` objects. The inputs to this `Operation`. 1555 output_types: list of `DType` objects. List of the types of the 1556 `Tensors` computed by this operation. The length of this list indicates 1557 the number of output endpoints of the `Operation`. 1558 control_inputs: list of operations or tensors from which to have a 1559 control dependency. 1560 input_types: List of `DType` objects representing the 1561 types of the tensors accepted by the `Operation`. By default 1562 uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect 1563 reference-typed inputs must specify these explicitly. 1564 original_op: Optional. Used to associate the new `Operation` with an 1565 existing `Operation` (for example, a replica with the op that was 1566 replicated). 1567 op_def: Optional. The `op_def_pb2.OpDef` proto that describes the 1568 op type that this `Operation` represents. 
1569 1570 Raises: 1571 TypeError: if control inputs are not Operations or Tensors, 1572 or if `node_def` is not a `NodeDef`, 1573 or if `g` is not a `Graph`, 1574 or if `inputs` are not tensors, 1575 or if `inputs` and `input_types` are incompatible. 1576 ValueError: if the `node_def` name is not valid. 1577 """ 1578 # For internal use only: `node_def` can be set to a TF_Operation to create 1579 # an Operation for that op. This is useful for creating Operations for ops 1580 # indirectly created by C API methods, e.g. the ops created by 1581 # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields 1582 # should be None. 1583 1584 if isinstance(node_def, node_def_pb2.NodeDef): 1585 if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0: 1586 raise ValueError( 1587 "Cannot create a tensor proto whose content is larger than 2GB.") 1588 if not _VALID_OP_NAME_REGEX.match(node_def.name): 1589 raise ValueError("'%s' is not a valid node name" % node_def.name) 1590 c_op = None 1591 elif type(node_def).__name__ == "SwigPyObject": 1592 assert inputs is None 1593 assert output_types is None 1594 assert control_inputs is None 1595 assert input_types is None 1596 assert original_op is None 1597 assert op_def is None 1598 c_op = node_def 1599 else: 1600 raise TypeError("node_def needs to be a NodeDef: %s" % node_def) 1601 1602 if not isinstance(g, Graph): 1603 raise TypeError("g needs to be a Graph: %s" % g) 1604 self._graph = g 1605 1606 if inputs is None: 1607 inputs = [] 1608 elif not isinstance(inputs, list): 1609 raise TypeError("inputs needs to be a list of Tensors: %s" % inputs) 1610 for a in inputs: 1611 if not isinstance(a, Tensor): 1612 raise TypeError("input needs to be a Tensor: %s" % a) 1613 if input_types is None: 1614 input_types = [i.dtype.base_dtype for i in inputs] 1615 else: 1616 if not all( 1617 x.is_compatible_with(i.dtype) 1618 for i, x in zip(inputs, input_types)): 1619 raise TypeError("In op '%s', input types (%s) are not 
compatible " 1620 "with expected types (%s)" % 1621 (node_def.name, [i.dtype for i in inputs], 1622 input_types)) 1623 1624 # Build the list of control inputs. 1625 control_input_ops = [] 1626 if control_inputs: 1627 for c in control_inputs: 1628 control_op = None 1629 if isinstance(c, Operation): 1630 control_op = c 1631 elif isinstance(c, (Tensor, IndexedSlices)): 1632 control_op = c.op 1633 else: 1634 raise TypeError("Control input must be an Operation, " 1635 "a Tensor, or IndexedSlices: %s" % c) 1636 control_input_ops.append(control_op) 1637 1638 # Don't set private fields with C API enabled to catch users who need to 1639 # switch to public API. 1640 # TODO(skyewm): delete these fields once we remove _USE_C_API 1641 if not self._graph._c_graph: 1642 self._inputs_val = list(inputs) # Defensive copy. 1643 self._input_types_val = input_types 1644 self._control_inputs_val = control_input_ops 1645 self._node_def_val = copy.deepcopy(node_def) 1646 self._op_def_val = op_def 1647 1648 self._id_value = self._graph._next_id() # pylint: disable=protected-access 1649 self._original_op = original_op 1650 self._traceback = self._graph._extract_stack() # pylint: disable=protected-access 1651 self._control_flow_context = self.graph._get_control_flow_context() # pylint: disable=protected-access 1652 1653 # Initialize self._c_op. 1654 if c_op: 1655 # TODO(skyewm): remove this assert when we remove USE_C_API 1656 assert self._graph._c_graph # pylint: disable=protected-access 1657 self._c_op = c_op 1658 elif self._graph._c_graph: # pylint: disable=protected-access 1659 if op_def is None: 1660 op_def = self._graph._get_op_def(node_def.op) 1661 # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs. 1662 # Refactor so we don't have to do this here. 
1663 grouped_inputs = self._reconstruct_sequence_inputs( 1664 op_def, inputs, node_def.attr) 1665 self._c_op = _create_c_op(self._graph, node_def, grouped_inputs, 1666 control_input_ops) 1667 else: 1668 self._c_op = None 1669 1670 # Mark that we consume the inputs. This is unnecessary and unsupported with 1671 # the C API enabled, since the C API tracks the tensor consumers instead. 1672 if not self._c_op: 1673 for input_tensor in self._inputs_val: 1674 input_tensor._add_consumer(self) # pylint: disable=protected-access 1675 1676 # Initialize self._outputs. 1677 if self._c_op: 1678 num_outputs = c_api.TF_OperationNumOutputs(self._c_op) 1679 output_types = [ 1680 c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i)) 1681 for i in range(num_outputs)] 1682 assert output_types is not None 1683 elif output_types is None: 1684 output_types = [] 1685 self._output_types_val = output_types 1686 self._outputs = [ 1687 Tensor(self, i, output_type) 1688 for i, output_type in enumerate(output_types) 1689 ] 1690 1691 if not c_op: 1692 self._control_flow_post_processing() 1693 1694 def _control_flow_post_processing(self): 1695 """Add this op to its control flow context. 1696 1697 This may add new ops and change this op's inputs. self.inputs must be 1698 available before calling this method. 1699 """ 1700 for input_tensor in self.inputs: 1701 control_flow_util.CheckInputFromValidContext(self, input_tensor.op) 1702 if self._control_flow_context is not None: 1703 self._control_flow_context.AddOp(self) 1704 self._recompute_node_def() 1705 1706 def _reconstruct_sequence_inputs(self, op_def, inputs, attrs): 1707 """Regroups a flat list of input tensors into scalar and sequence inputs. 1708 1709 Args: 1710 op_def: The `op_def_pb2.OpDef` (for knowing the input types) 1711 inputs: a list of input `Tensor`s to the op. 
1712 attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define 1713 how long each sequence is) 1714 1715 Returns: 1716 A list of `Tensor`s (corresponding to scalar inputs) and lists of 1717 `Tensor`s (corresponding to sequence inputs). 1718 """ 1719 grouped_inputs = [] 1720 i = 0 1721 for input_arg in op_def.input_arg: 1722 if input_arg.number_attr: 1723 input_len = attrs[input_arg.number_attr].i 1724 is_sequence = True 1725 elif input_arg.type_list_attr: 1726 input_len = len(attrs[input_arg.type_list_attr].list.type) 1727 is_sequence = True 1728 else: 1729 input_len = 1 1730 is_sequence = False 1731 1732 if is_sequence: 1733 grouped_inputs.append(inputs[i:i + input_len]) 1734 else: 1735 grouped_inputs.append(inputs[i]) 1736 i += input_len 1737 1738 assert i == len(inputs) 1739 return grouped_inputs 1740 1741 def colocation_groups(self): 1742 """Returns the list of colocation groups of the op.""" 1743 default_colocation_group = [ 1744 compat.as_bytes("loc:@%s" % self.name) 1745 ] 1746 try: 1747 class_attr = self.get_attr("_class") 1748 except ValueError: 1749 # This op has no explicit colocation group, so it is itself its 1750 # own root of a colocation group. 1751 return default_colocation_group 1752 1753 attr_groups = [ 1754 class_name for class_name in class_attr 1755 if class_name.startswith(b"loc:@") 1756 ] 1757 1758 # If there are no colocation groups in the explicit _class field, 1759 # return the default colocation group. 1760 return attr_groups if attr_groups else default_colocation_group 1761 1762 def values(self): 1763 """DEPRECATED: Use outputs.""" 1764 return tuple(self.outputs) 1765 1766 def _get_control_flow_context(self): 1767 """Returns the control flow context of this op. 1768 1769 Returns: 1770 A context object. 1771 """ 1772 return self._control_flow_context 1773 1774 def _set_control_flow_context(self, ctx): 1775 """Sets the current control flow context of this op. 1776 1777 Args: 1778 ctx: a context object. 
1779 """ 1780 self._control_flow_context = ctx 1781 1782 @property 1783 def name(self): 1784 """The full name of this operation.""" 1785 if self._c_op: 1786 return c_api.TF_OperationName(self._c_op) 1787 else: 1788 return self._node_def_val.name 1789 1790 @property 1791 def _id(self): 1792 """The unique integer id of this operation.""" 1793 return self._id_value 1794 1795 @property 1796 def device(self): 1797 """The name of the device to which this op has been assigned, if any. 1798 1799 Returns: 1800 The string name of the device to which this op has been 1801 assigned, or an empty string if it has not been assigned to a 1802 device. 1803 """ 1804 if self._c_op: 1805 return c_api.TF_OperationDevice(self._c_op) 1806 else: 1807 return self._node_def_val.device 1808 1809 @property 1810 def _output_types(self): 1811 """List this operation's output types. 1812 1813 Returns: 1814 List of the types of the Tensors computed by this operation. 1815 Each element in the list is an integer whose value is one of 1816 the TF_DataType enums defined in c_api.h 1817 The length of this list indicates the number of output endpoints 1818 of the operation. 1819 """ 1820 if self._c_op: 1821 num_outputs = c_api.TF_OperationNumOutputs(self._c_op) 1822 output_types = [ 1823 c_api.TF_OperationOutputType(self._tf_output(i)) 1824 for i in xrange(num_outputs) 1825 ] 1826 # TODO(iga): Remove this assert after converting to C API by default. 1827 # Just being a bit paranoid here. 1828 assert self._output_types_val == output_types 1829 # In all the tests we have output_types that are passed into 1830 # Operation.__init__ are a list of ints (which is illegal according 1831 # to the docstring), but input_types are instances of DType. 1832 # This extra assert is to catch if we ever use DType for output_types. 
1833 if output_types: 1834 assert isinstance(output_types[0], int) 1835 return output_types 1836 else: 1837 return self._output_types_val 1838 1839 def _tf_output(self, output_idx): 1840 """Create and return a new TF_Output for output_idx'th output of this op.""" 1841 assert self._c_op 1842 tf_output = c_api.TF_Output() 1843 tf_output.oper = self._c_op 1844 tf_output.index = output_idx 1845 return tf_output 1846 1847 def _tf_input(self, input_idx): 1848 """Create and return a new TF_Input for input_idx'th input of this op.""" 1849 assert self._c_op 1850 tf_input = c_api.TF_Input() 1851 tf_input.oper = self._c_op 1852 tf_input.index = input_idx 1853 return tf_input 1854 1855 def _set_device(self, device): # pylint: disable=redefined-outer-name 1856 """Set the device of this operation. 1857 1858 Args: 1859 device: string or device.. The device to set. 1860 """ 1861 if self._c_op: 1862 c_api.SetRequestedDevice( 1863 self._graph._c_graph, # pylint: disable=protected-access 1864 self._c_op, # pylint: disable=protected-access 1865 compat.as_str(_device_string(device))) 1866 else: 1867 self._node_def_val.device = _device_string(device) 1868 1869 def _add_input(self, tensor, dtype=None): 1870 """Add a new input to this operation. 1871 1872 Args: 1873 tensor: the Tensor to add as an input. 1874 dtype: tf.DType: type of the input; defaults to 1875 the tensor's dtype. 1876 1877 Raises: 1878 TypeError: if tensor is not a Tensor, 1879 or if input tensor type is not convertible to dtype. 1880 ValueError: if the Tensor is from a different graph. 
1881 """ 1882 assert not self._c_op, ( 1883 "Operation._add_input doesn't work with C API") 1884 if not isinstance(tensor, Tensor): 1885 raise TypeError("tensor must be a Tensor: %s" % tensor) 1886 _assert_same_graph(self, tensor) 1887 if dtype is None: 1888 dtype = tensor.dtype 1889 else: 1890 dtype = dtypes.as_dtype(dtype) 1891 if not dtype.is_compatible_with(tensor.dtype): 1892 raise TypeError( 1893 "Cannot convert a tensor of type %s to an input of type %s" % 1894 (tensor.dtype.name, dtype.name)) 1895 self._inputs_val.append(tensor) 1896 self._input_types_val.append(dtype) 1897 tensor._add_consumer(self) # pylint: disable=protected-access 1898 self._recompute_node_def() 1899 1900 def _update_input(self, index, tensor): 1901 """Update the input to this operation at the given index. 1902 1903 NOTE: This is for TF internal use only. Please don't use it. 1904 1905 Args: 1906 index: the index of the input to update. 1907 tensor: the Tensor to be used as the input at the given index. 1908 1909 Raises: 1910 TypeError: if tensor is not a Tensor, 1911 or if input tensor type is not convertible to dtype. 1912 ValueError: if the Tensor is from a different graph. 1913 """ 1914 if not isinstance(tensor, Tensor): 1915 raise TypeError("tensor must be a Tensor: %s" % tensor) 1916 _assert_same_graph(self, tensor) 1917 if self._c_op: 1918 with errors.raise_exception_on_not_ok_status() as status: 1919 c_api.UpdateEdge( 1920 self._graph._c_graph, # pylint: disable=protected-access 1921 tensor._as_tf_output(), # pylint: disable=protected-access 1922 self._tf_input(index), 1923 status) 1924 else: 1925 self._inputs_val[index].consumers().remove(self) 1926 self._inputs_val[index] = tensor 1927 self._input_types_val[index] = tensor.dtype 1928 tensor._add_consumer(self) # pylint: disable=protected-access 1929 self._recompute_node_def() 1930 1931 def _add_control_inputs(self, ops): 1932 """Add a list of new control inputs to this operation. 
1933 1934 Args: 1935 ops: the list of Operations to add as control input. 1936 1937 Raises: 1938 TypeError: if ops is not a list of Operations. 1939 ValueError: if any op in ops is from a different graph. 1940 """ 1941 if self._c_op: 1942 for op in ops: 1943 if not isinstance(op, Operation): 1944 raise TypeError("op must be an Operation: %s" % op) 1945 c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access 1946 else: 1947 if ops: 1948 for op in ops: 1949 if not isinstance(op, Operation): 1950 raise TypeError("op must be an Operation: %s" % op) 1951 _assert_same_graph(self, op) 1952 self._control_inputs_val.append(op) 1953 self._recompute_node_def() 1954 1955 def _add_control_input(self, op): 1956 """Add a new control input to this operation. 1957 1958 Args: 1959 op: the Operation to add as control input. 1960 1961 Raises: 1962 TypeError: if op is not an Operation. 1963 ValueError: if op is from a different graph. 1964 """ 1965 if self._c_op: 1966 if not isinstance(op, Operation): 1967 raise TypeError("op must be an Operation: %s" % op) 1968 c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access 1969 else: 1970 self._add_control_inputs([op]) 1971 1972 def _remove_all_control_inputs(self): 1973 """Removes any control inputs to this operation.""" 1974 if self._c_op: 1975 c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access 1976 else: 1977 del self.control_inputs[:] 1978 1979 # Methods below are used when building the NodeDef and Graph proto. 
1980 def _recompute_node_def(self): 1981 # TODO(skyewm): remove this function when we switch to C API 1982 if self._c_op: return 1983 1984 del self._node_def_val.input[:] 1985 # pylint: disable=protected-access 1986 self._node_def_val.input.extend( 1987 [t._as_node_def_input() for t in self._inputs_val]) 1988 # pylint: enable=protected-access 1989 if self._control_inputs_val: 1990 self._node_def_val.input.extend( 1991 ["^%s" % op.name for op in self._control_inputs_val]) 1992 1993 def __str__(self): 1994 return str(self.node_def) 1995 1996 def __repr__(self): 1997 return "<tf.Operation '%s' type=%s>" % (self.name, self.type) 1998 1999 @property 2000 def outputs(self): 2001 """The list of `Tensor` objects representing the outputs of this op.""" 2002 return self._outputs 2003 2004 # pylint: disable=protected-access 2005 2006 class _InputList(object): 2007 """Immutable input list wrapper.""" 2008 2009 def __init__(self, inputs): 2010 self._inputs = inputs 2011 2012 def __iter__(self): 2013 return iter(self._inputs) 2014 2015 def __len__(self): 2016 return len(self._inputs) 2017 2018 def __bool__(self): 2019 return bool(self._inputs) 2020 2021 # Python 3 wants __bool__, Python 2.7 wants __nonzero__ 2022 __nonzero__ = __bool__ 2023 2024 def __getitem__(self, i): 2025 return self._inputs[i] 2026 2027 # pylint: enable=protected-access 2028 2029 @property 2030 def inputs(self): 2031 """The list of `Tensor` objects representing the data inputs of this op.""" 2032 if self._c_op: 2033 tf_outputs = c_api.GetOperationInputs(self._c_op) 2034 # pylint: disable=protected-access 2035 retval = [ 2036 self.graph._get_tensor_by_tf_output(tf_output) 2037 for tf_output in tf_outputs 2038 ] 2039 # pylint: enable=protected-access 2040 return Operation._InputList(retval) 2041 return Operation._InputList(self._inputs_val) 2042 2043 @property 2044 def _inputs(self): 2045 logging.warning("Operation._inputs is private, use Operation.inputs " 2046 "instead. 
Operation._inputs will eventually be removed.") 2047 return self.inputs 2048 2049 @_inputs.setter 2050 def _inputs(self, value): 2051 raise ValueError("Cannot assign _inputs") 2052 2053 @property 2054 def _input_types(self): 2055 if self._c_op: 2056 num_inputs = c_api.TF_OperationNumInputs(self._c_op) 2057 input_types = [ 2058 dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i))) 2059 for i in xrange(num_inputs) 2060 ] 2061 return input_types 2062 else: 2063 return self._input_types_val 2064 2065 @_input_types.setter 2066 def _input_types(self, value): 2067 raise ValueError("Cannot assign _input_types") 2068 2069 @property 2070 def control_inputs(self): 2071 """The `Operation` objects on which this op has a control dependency. 2072 2073 Before this op is executed, TensorFlow will ensure that the 2074 operations in `self.control_inputs` have finished executing. This 2075 mechanism can be used to run ops sequentially for performance 2076 reasons, or to ensure that the side effects of an op are observed 2077 in the correct order. 2078 2079 Returns: 2080 A list of `Operation` objects. 2081 2082 """ 2083 if self._c_op: 2084 control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op) 2085 # pylint: disable=protected-access 2086 return [ 2087 self.graph._get_operation_by_name_unsafe( 2088 c_api.TF_OperationName(c_op)) for c_op in control_c_ops 2089 ] 2090 # pylint: enable=protected-access 2091 else: 2092 return self._control_inputs_val 2093 2094 @property 2095 def _control_inputs(self): 2096 logging.warning("Operation._control_inputs is private, use " 2097 "Operation.control_inputs instead. " 2098 "Operation._control_inputs will eventually be removed.") 2099 return self.control_inputs 2100 2101 @_control_inputs.setter 2102 def _control_inputs(self, value): 2103 logging.warning("Operation._control_inputs is private, use " 2104 "Operation.control_inputs instead. 
" 2105 "Operation._control_inputs will eventually be removed.") 2106 # Copy value because it may be self._control_inputs_val (in particular if 2107 # this is called from self._control_inputs += ...), and we don't want to 2108 # clear value below. 2109 value = copy.copy(value) 2110 self._remove_all_control_inputs() 2111 self._add_control_inputs(value) 2112 2113 @property 2114 def type(self): 2115 """The type of the op (e.g. `"MatMul"`).""" 2116 if self._c_op: 2117 op_type = c_api.TF_OperationOpType(self._c_op) 2118 return op_type 2119 else: 2120 return self._node_def_val.op 2121 2122 @property 2123 def graph(self): 2124 """The `Graph` that contains this operation.""" 2125 return self._graph 2126 2127 @property 2128 def node_def(self): 2129 # pylint: disable=line-too-long 2130 """Returns the `NodeDef` representation of this operation. 2131 2132 Returns: 2133 A 2134 [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto) 2135 protocol buffer. 2136 """ 2137 # pylint: enable=line-too-long 2138 if self._c_op: 2139 with c_api_util.tf_buffer() as buf: 2140 with errors.raise_exception_on_not_ok_status() as status: 2141 c_api.TF_OperationToNodeDef(self._c_op, buf, status) 2142 data = c_api.TF_GetBuffer(buf) 2143 node_def = node_def_pb2.NodeDef() 2144 node_def.ParseFromString(compat.as_bytes(data)) 2145 return node_def 2146 else: 2147 return self._node_def_val 2148 2149 @property 2150 def _node_def(self): 2151 logging.warning("Operation._node_def is private, use Operation.node_def " 2152 "instead. Operation._node_def will eventually be removed.") 2153 return self.node_def 2154 2155 @property 2156 def op_def(self): 2157 # pylint: disable=line-too-long 2158 """Returns the `OpDef` proto that represents the type of this op. 2159 2160 Returns: 2161 An 2162 [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto) 2163 protocol buffer. 
2164 """ 2165 # pylint: enable=line-too-long 2166 if self._c_op: 2167 return self._graph._get_op_def(self.type) 2168 else: 2169 return self._op_def_val 2170 2171 @property 2172 def _op_def(self): 2173 logging.warning("Operation._op_def is private, use Operation.op_def " 2174 "instead. Operation._op_def will eventually be removed.") 2175 return self.op_def 2176 2177 @property 2178 def traceback(self): 2179 """Returns the call stack from when this operation was constructed.""" 2180 return self._graph._convert_stack(self._traceback) # pylint: disable=protected-access 2181 2182 @property 2183 def traceback_with_start_lines(self): 2184 """Same as traceback but includes start line of function definition. 2185 2186 Returns: 2187 A list of 5-tuples (filename, lineno, name, code, func_start_lineno). 2188 """ 2189 return self._graph._convert_stack( # pylint: disable=protected-access 2190 self._traceback, 2191 include_func_start_lineno=True) 2192 2193 def _set_attr(self, attr_name, attr_value): 2194 """Private method used to set an attribute in the node_def.""" 2195 if self._c_op: 2196 buf = c_api.TF_NewBufferFromString( 2197 compat.as_bytes(attr_value.SerializeToString())) 2198 try: 2199 with errors.raise_exception_on_not_ok_status() as status: 2200 # pylint: disable=protected-access 2201 c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, 2202 status) 2203 # pylint: enable=protected-access 2204 finally: 2205 c_api.TF_DeleteBuffer(buf) 2206 else: 2207 self._node_def_val.attr[attr_name].CopyFrom(attr_value) 2208 2209 def get_attr(self, name): 2210 """Returns the value of the attr of this op with the given `name`. 2211 2212 Args: 2213 name: The name of the attr to fetch. 2214 2215 Returns: 2216 The value of the attr, as a Python object. 2217 2218 Raises: 2219 ValueError: If this op does not have an attr with the given `name`. 
2220 """ 2221 fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] 2222 if self._c_op: 2223 try: 2224 with c_api_util.tf_buffer() as buf: 2225 with errors.raise_exception_on_not_ok_status() as status: 2226 c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status) 2227 data = c_api.TF_GetBuffer(buf) 2228 except errors.InvalidArgumentError as e: 2229 # Convert to ValueError for backwards compatibility. 2230 raise ValueError(str(e)) 2231 x = attr_value_pb2.AttrValue() 2232 x.ParseFromString(data) 2233 else: 2234 if name not in self._node_def_val.attr: 2235 raise ValueError( 2236 "No attr named '" + name + "' in " + str(self._node_def_val)) 2237 x = self._node_def_val.attr[name] 2238 2239 # Treat an empty oneof value as an empty list. 2240 if not x.WhichOneof("value"): 2241 return [] 2242 if x.HasField("list"): 2243 for f in fields: 2244 if getattr(x.list, f): 2245 if f == "type": 2246 return [dtypes.as_dtype(x) for x in list(getattr(x.list, f))] 2247 else: 2248 return list(getattr(x.list, f)) 2249 return [] 2250 else: 2251 for f in fields: 2252 if x.HasField(f): 2253 if f == "type": 2254 return dtypes.as_dtype(getattr(x, f)) 2255 else: 2256 return getattr(x, f) 2257 assert False, "Unsupported field type in " + str(x) 2258 2259 def run(self, feed_dict=None, session=None): 2260 """Runs this operation in a `Session`. 2261 2262 Calling this method will execute all preceding operations that 2263 produce the inputs needed for this operation. 2264 2265 *N.B.* Before invoking `Operation.run()`, its graph must have been 2266 launched in a session, and either a default session must be 2267 available, or `session` must be specified explicitly. 2268 2269 Args: 2270 feed_dict: A dictionary that maps `Tensor` objects to feed values. 2271 See @{tf.Session.run} 2272 for a description of the valid feed values. 2273 session: (Optional.) The `Session` to be used to run to this operation. If 2274 none, the default session will be used. 
2275 """ 2276 _run_using_default_session(self, feed_dict, self.graph, session) 2277 2278 _gradient_registry = registry.Registry("gradient") 2279 2280 2281 @tf_export("RegisterGradient") 2282 class RegisterGradient(object): 2283 """A decorator for registering the gradient function for an op type. 2284 2285 This decorator is only used when defining a new op type. For an op 2286 with `m` inputs and `n` outputs, the gradient function is a function 2287 that takes the original `Operation` and `n` `Tensor` objects 2288 (representing the gradients with respect to each output of the op), 2289 and returns `m` `Tensor` objects (representing the partial gradients 2290 with respect to each input of the op). 2291 2292 For example, assuming that operations of type `"Sub"` take two 2293 inputs `x` and `y`, and return a single output `x - y`, the 2294 following gradient function would be registered: 2295 2296 ```python 2297 @tf.RegisterGradient("Sub") 2298 def _sub_grad(unused_op, grad): 2299 return grad, tf.negative(grad) 2300 ``` 2301 2302 The decorator argument `op_type` is the string type of an 2303 operation. This corresponds to the `OpDef.name` field for the proto 2304 that defines the operation. 2305 """ 2306 2307 def __init__(self, op_type): 2308 """Creates a new decorator with `op_type` as the Operation type. 2309 2310 Args: 2311 op_type: The string type of an operation. This corresponds to the 2312 `OpDef.name` field for the proto that defines the operation. 2313 """ 2314 if not isinstance(op_type, six.string_types): 2315 raise TypeError("op_type must be a string") 2316 self._op_type = op_type 2317 2318 def __call__(self, f): 2319 """Registers the function `f` as gradient function for `op_type`.""" 2320 _gradient_registry.register(f, self._op_type) 2321 return f 2322 2323 2324 @tf_export("NoGradient", "NotDifferentiable") 2325 def NotDifferentiable(op_type): 2326 """Specifies that ops of type `op_type` is not differentiable. 
2327 2328 This function should *not* be used for operations that have a 2329 well-defined gradient that is not yet implemented. 2330 2331 This function is only used when defining a new op type. It may be 2332 used for ops such as `tf.size()` that are not differentiable. For 2333 example: 2334 2335 ```python 2336 tf.NotDifferentiable("Size") 2337 ``` 2338 2339 The gradient computed for 'op_type' will then propagate zeros. 2340 2341 For ops that have a well-defined gradient but are not yet implemented, 2342 no declaration should be made, and an error *must* be thrown if 2343 an attempt to request its gradient is made. 2344 2345 Args: 2346 op_type: The string type of an operation. This corresponds to the 2347 `OpDef.name` field for the proto that defines the operation. 2348 2349 Raises: 2350 TypeError: If `op_type` is not a string. 2351 2352 """ 2353 if not isinstance(op_type, six.string_types): 2354 raise TypeError("op_type must be a string") 2355 _gradient_registry.register(None, op_type) 2356 2357 2358 # Alias for the old name, will be eventually removed. 2359 NoGradient = NotDifferentiable 2360 2361 2362 def get_gradient_function(op): 2363 """Returns the function that computes gradients for "op".""" 2364 if not op.inputs: 2365 return None 2366 try: 2367 op_type = op.get_attr("_gradient_op_type") 2368 except ValueError: 2369 op_type = op.type 2370 return _gradient_registry.lookup(op_type) 2371 2372 2373 _shape_registry = registry.Registry("shape functions") 2374 _default_shape_function_registry = registry.Registry("default shape functions") 2375 2376 # These are set to common_shapes.call_cpp_shape_fn by op generated code 2377 # (generated by python_op_gen.cc). 2378 # It is set outside ops.py to avoid a circular dependency. 
2379 _call_cpp_shape_fn = None 2380 _call_cpp_shape_fn_and_require_op = None 2381 2382 2383 def _set_call_cpp_shape_fn(call_cpp_shape_fn): 2384 """Sets default shape fns from passed common_shapes.call_cpp_shape_fn.""" 2385 global _call_cpp_shape_fn, _call_cpp_shape_fn_and_require_op 2386 if _call_cpp_shape_fn: 2387 return # already registered 2388 2389 def call_without_requiring(op): 2390 return call_cpp_shape_fn(op, require_shape_fn=False) 2391 2392 _call_cpp_shape_fn = call_without_requiring 2393 2394 def call_with_requiring(op): 2395 return call_cpp_shape_fn(op, require_shape_fn=True) 2396 2397 _call_cpp_shape_fn_and_require_op = call_with_requiring 2398 2399 2400 class RegisterShape(object): 2401 """No longer used. Was: A decorator for registering a shape function. 2402 2403 Shape functions must now be registered via the SetShapeFn on the 2404 original Op specification in C++. 2405 2406 """ 2407 2408 def __init__(self, op_type): 2409 """Saves the `op_type` as the `Operation` type.""" 2410 if not isinstance(op_type, six.string_types): 2411 raise TypeError("op_type must be a string") 2412 self._op_type = op_type 2413 2414 def __call__(self, f): 2415 """Registers "f" as the shape function for "op_type".""" 2416 if f is None: 2417 assert _call_cpp_shape_fn 2418 2419 # None is a special "weak" value that provides a default shape function, 2420 # and can be overridden by a non-None registration. 2421 try: 2422 _default_shape_function_registry.register(_call_cpp_shape_fn, 2423 self._op_type) 2424 except KeyError: 2425 # Ignore duplicate registrations of the weak value. This can 2426 # occur if the op library input to wrapper generation 2427 # inadvertently links in one or more of the standard op 2428 # libraries. 
2429 pass 2430 else: 2431 _shape_registry.register(f, self._op_type) 2432 return f 2433 2434 2435 def _set_shapes_for_outputs_c_api(op): 2436 """set_shapes_for_outputs implementation when C API is enabled.""" 2437 # The C API computes the shapes when the TF_Operation is created. Fetch the 2438 # output shapes from the C object. 2439 for output in op.outputs: 2440 with errors.raise_exception_on_not_ok_status() as status: 2441 # pylint: disable=protected-access 2442 shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper( 2443 op._graph._c_graph, output._as_tf_output(), status) 2444 # pylint: enable=protected-access 2445 if unknown_shape: 2446 output.set_shape(tensor_shape.unknown_shape()) 2447 elif not shape_vector: 2448 output.set_shape(tensor_shape.scalar()) 2449 else: 2450 shape_vector = [None if d == -1 else d for d in shape_vector] 2451 output.set_shape(tensor_shape.TensorShape(shape_vector)) 2452 2453 2454 # TODO(skyewm): remove this when _USE_C_API flag is removed. 2455 def _set_shapes_for_outputs(op): 2456 """set_shapes_for_outputs implementation when C API is disabled.""" 2457 try: 2458 shape_func = _shape_registry.lookup(op.type) 2459 except LookupError: 2460 try: 2461 shape_func = _default_shape_function_registry.lookup(op.type) 2462 except LookupError: 2463 shape_func = _call_cpp_shape_fn_and_require_op 2464 2465 shapes = shape_func(op) 2466 if shapes is None: 2467 raise RuntimeError( 2468 "Shape function for op %s did not return any shapes" % op) 2469 elif isinstance(shapes, dict): 2470 # Returned by call_cpp_shape_fn 2471 shapes_dict = shapes 2472 shapes = shapes_dict["shapes"] 2473 handle_datas = shapes_dict["handle_data"] 2474 for output, handle_data in zip(op.outputs, handle_datas): 2475 # pylint: disable=protected-access 2476 output._handle_data = handle_data 2477 # pylint: enable=protected-access 2478 2479 if len(op.outputs) != len(shapes): 2480 raise RuntimeError( 2481 "Shape function for op %s returned %d shapes but expected %d %s %s" % 
2482 (op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes))) 2483 for output, s in zip(op.outputs, shapes): 2484 output.set_shape(s) 2485 2486 2487 def set_shapes_for_outputs(op): 2488 """Set the shapes for op's outputs.""" 2489 if op._c_op: # pylint: disable=protected-access 2490 return _set_shapes_for_outputs_c_api(op) 2491 else: 2492 return _set_shapes_for_outputs(op) 2493 2494 2495 class OpStats(object): 2496 """A holder for statistics about an operator. 2497 2498 This class holds information about the resource requirements for an op, 2499 including the size of its weight parameters on-disk and how many FLOPS it 2500 requires to execute forward inference. 2501 2502 If you define a new operation, you can create a function that will return a 2503 set of information about its usage of the CPU and disk space when serialized. 2504 The function itself takes a Graph object that's been set up so you can call 2505 methods like get_tensor_by_name to help calculate the results, and a NodeDef 2506 argument. 2507 2508 """ 2509 2510 def __init__(self, statistic_type, value=None): 2511 """Sets up the initial placeholders for the statistics.""" 2512 self.statistic_type = statistic_type 2513 self.value = value 2514 2515 @property 2516 def statistic_type(self): 2517 return self._statistic_type 2518 2519 @statistic_type.setter 2520 def statistic_type(self, statistic_type): 2521 self._statistic_type = statistic_type 2522 2523 @property 2524 def value(self): 2525 return self._value 2526 2527 @value.setter 2528 def value(self, value): 2529 self._value = value 2530 2531 def __iadd__(self, other): 2532 if other.statistic_type != self.statistic_type: 2533 raise ValueError("Can't add an OpStat of type %s to one of %s." 
% 2534 (self.statistic_type, other.statistic_type)) 2535 if self.value is None: 2536 self.value = other.value 2537 elif other.value is not None: 2538 self._value += other.value 2539 return self 2540 2541 2542 _stats_registry = registry.Registry("statistical functions") 2543 2544 2545 class RegisterStatistics(object): 2546 """A decorator for registering the statistics function for an op type. 2547 2548 This decorator can be defined for an op type so that it gives a 2549 report on the resources used by an instance of an operator, in the 2550 form of an OpStats object. 2551 2552 Well-known types of statistics include these so far: 2553 2554 - flops: When running a graph, the bulk of the computation happens doing 2555 numerical calculations like matrix multiplications. This type allows a node 2556 to return how many floating-point operations it takes to complete. The 2557 total number of FLOPs for a graph is a good guide to its expected latency. 2558 2559 You can add your own statistics just by picking a new type string, registering 2560 functions for the ops you care about, and then calling get_stats_for_node_def. 2561 2562 If a statistic for an op is registered multiple times, a KeyError will be 2563 raised. 2564 2565 Since the statistics is counted on a per-op basis. It is not suitable for 2566 model parameters (capacity), which is expected to be counted only once, even 2567 if it is shared by multiple ops. (e.g. 
RNN) 2568 2569 For example, you can define a new metric called doohickey for a Foo operation 2570 by placing this in your code: 2571 2572 ```python 2573 @ops.RegisterStatistics("Foo", "doohickey") 2574 def _calc_foo_bojangles(unused_graph, unused_node_def): 2575 return ops.OpStats("doohickey", 20) 2576 ``` 2577 2578 Then in client code you can retrieve the value by making this call: 2579 2580 ```python 2581 doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey") 2582 ``` 2583 2584 If the NodeDef is for an op with a registered doohickey function, you'll get 2585 back the calculated amount in doohickey.value, or None if it's not defined. 2586 2587 """ 2588 2589 def __init__(self, op_type, statistic_type): 2590 """Saves the `op_type` as the `Operation` type.""" 2591 if not isinstance(op_type, six.string_types): 2592 raise TypeError("op_type must be a string.") 2593 if "," in op_type: 2594 raise TypeError("op_type must not contain a comma.") 2595 self._op_type = op_type 2596 if not isinstance(statistic_type, six.string_types): 2597 raise TypeError("statistic_type must be a string.") 2598 if "," in statistic_type: 2599 raise TypeError("statistic_type must not contain a comma.") 2600 self._statistic_type = statistic_type 2601 2602 def __call__(self, f): 2603 """Registers "f" as the statistics function for "op_type".""" 2604 _stats_registry.register(f, self._op_type + "," + self._statistic_type) 2605 return f 2606 2607 2608 def get_stats_for_node_def(graph, node, statistic_type): 2609 """Looks up the node's statistics function in the registry and calls it. 2610 2611 This function takes a Graph object and a NodeDef from a GraphDef, and if 2612 there's an associated statistics method, calls it and returns a result. If no 2613 function has been registered for the particular node type, it returns an empty 2614 statistics object. 2615 2616 Args: 2617 graph: A Graph object that's been set up with the node's graph. 2618 node: A NodeDef describing the operator. 
2619 statistic_type: A string identifying the statistic we're interested in. 2620 Returns: 2621 An OpStats object containing information about resource usage. 2622 """ 2623 2624 try: 2625 stats_func = _stats_registry.lookup(node.op + "," + statistic_type) 2626 result = stats_func(graph, node) 2627 except LookupError: 2628 result = OpStats(statistic_type) 2629 return result 2630 2631 2632 def _name_from_scope_name(name): 2633 """Returns the name of an op given the name of its scope. 2634 2635 Args: 2636 name: the name of the scope. 2637 2638 Returns: 2639 the name of the op (equal to scope name minus any trailing slash). 2640 """ 2641 return name[:-1] if (name and name[-1] == "/") else name 2642 2643 2644 @tf_export("Graph") 2645 class Graph(object): 2646 """A TensorFlow computation, represented as a dataflow graph. 2647 2648 A `Graph` contains a set of 2649 @{tf.Operation} objects, 2650 which represent units of computation; and 2651 @{tf.Tensor} objects, which represent 2652 the units of data that flow between operations. 2653 2654 A default `Graph` is always registered, and accessible by calling 2655 @{tf.get_default_graph}. 2656 To add an operation to the default graph, simply call one of the functions 2657 that defines a new `Operation`: 2658 2659 ```python 2660 c = tf.constant(4.0) 2661 assert c.graph is tf.get_default_graph() 2662 ``` 2663 2664 Another typical usage involves the 2665 @{tf.Graph.as_default} 2666 context manager, which overrides the current default graph for the 2667 lifetime of the context: 2668 2669 ```python 2670 g = tf.Graph() 2671 with g.as_default(): 2672 # Define operations and tensors in `g`. 2673 c = tf.constant(30.0) 2674 assert c.graph is g 2675 ``` 2676 2677 Important note: This class *is not* thread-safe for graph construction. All 2678 operations should be created from a single thread, or external 2679 synchronization must be provided. Unless otherwise specified, all methods 2680 are not thread-safe. 
2681 2682 A `Graph` instance supports an arbitrary number of "collections" 2683 that are identified by name. For convenience when building a large 2684 graph, collections can store groups of related objects: for 2685 example, the `tf.Variable` uses a collection (named 2686 @{tf.GraphKeys.GLOBAL_VARIABLES}) for 2687 all variables that are created during the construction of a graph. The caller 2688 may define additional collections by specifying a new name. 2689 """ 2690 2691 def __init__(self): 2692 """Creates a new, empty Graph.""" 2693 # Protects the core state that may be accessed by multiple readers. 2694 # Only state that can be returned via public accessors (`as_graph_def()`, 2695 # `get_operations()`, `as_graph_element()`, `get_collection()`, and 2696 # `get_collection_ref()`) is by the lock. Thread-safety is provided on a 2697 # best-effort basis to support buggy programs, and is not guaranteed by the 2698 # public `tf.Graph` API. 2699 # NOTE(mrry): This does not protect the various stacks. A warning will 2700 # be reported if these are used from multiple threads 2701 self._lock = threading.Lock() 2702 self._nodes_by_id = dict() # GUARDED_BY(self._lock) 2703 self._next_id_counter = 0 # GUARDED_BY(self._lock) 2704 self._nodes_by_name = dict() # GUARDED_BY(self._lock) 2705 self._version = 0 # GUARDED_BY(self._lock) 2706 # Current name stack: uniquified names 2707 self._name_stack = "" 2708 # Maps a name used in the graph to the next id to use for that name. 2709 self._names_in_use = {} 2710 # Functions that will be applied to choose a device if none is specified. 2711 self._device_function_stack = [] 2712 # Default original_op applied to new ops. 2713 self._default_original_op = None 2714 # Current control flow context. It could be either CondContext or 2715 # WhileContext defined in ops/control_flow_ops.py 2716 self._control_flow_context = None 2717 # A new node will depend of the union of all of the nodes in the stack. 
2718 self._control_dependencies_stack = [] 2719 # Arbitrary collections of objects. 2720 self._collections = {} 2721 # The graph-level random seed 2722 self._seed = None 2723 # A dictionary of attributes that should be applied to all ops. 2724 self._attr_scope_map = {} 2725 # A map from op type to the kernel label that should be used. 2726 self._op_to_kernel_label_map = {} 2727 # A map from op type to an alternative op type that should be used when 2728 # computing gradients. 2729 self._gradient_override_map = {} 2730 # True if the graph is considered "finalized". In that case no 2731 # new operations can be added. 2732 self._finalized = False 2733 # Functions defined in the graph 2734 self._functions = collections.OrderedDict() 2735 # Default GraphDef versions 2736 self._graph_def_versions = versions_pb2.VersionDef( 2737 producer=versions.GRAPH_DEF_VERSION, 2738 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER) 2739 self._building_function = False 2740 # Stack of colocate_with ops 2741 self._colocation_stack = [] 2742 # Set of tensors that are dangerous to feed! 2743 self._unfeedable_tensors = set() 2744 # Set of operations that are dangerous to fetch! 2745 self._unfetchable_ops = set() 2746 # A map of tensor handle placeholder to tensor dtype. 2747 self._handle_feeders = {} 2748 # A map from tensor handle to its read op. 2749 self._handle_readers = {} 2750 # A map from tensor handle to its move op. 2751 self._handle_movers = {} 2752 # A map from tensor handle to its delete op. 2753 self._handle_deleters = {} 2754 # Allow optimizers and other objects to pseudo-uniquely key graphs (this key 2755 # will be shared when defining function graphs, for example, so optimizers 2756 # being called inside function definitions behave as if they were seeing the 2757 # actual outside graph). 
2758 self._graph_key = "grap-key-%d/" % (uid(),) 2759 self._container = "" 2760 self._registered_ops = op_def_registry.get_registered_ops() 2761 2762 # TODO(skyewm): fold as much of the above as possible into the C 2763 # implementation 2764 if _USE_C_API or self._use_c_api_hack(): 2765 self._scoped_c_graph = c_api_util.ScopedTFGraph() 2766 else: 2767 self._scoped_c_graph = None 2768 self._variable_creator_stack = [] 2769 2770 # TODO(apassos) remove once the C API is used by default. 2771 def _use_c_api_hack(self): 2772 """Temporary hack; can be overridden to force C API usage.""" 2773 return False 2774 2775 def _convert_stack(self, stack, include_func_start_lineno=False): 2776 """Converts a stack extracted using _extract_stack() to a traceback stack. 2777 2778 Args: 2779 stack: A list of n 5-tuples, 2780 (filename, lineno, name, frame_globals, func_start_lineno). 2781 include_func_start_lineno: True if function start line number should be 2782 included as the 5th entry in return tuples. 2783 2784 Returns: 2785 A list of n 4-tuples or 5-tuples 2786 (filename, lineno, name, code, [optional: func_start_lineno]), where the 2787 code tuple element is calculated from the corresponding elements of the 2788 input tuple. 2789 """ 2790 ret = [] 2791 for (filename, lineno, name, frame_globals, func_start_lineno, 2792 unused_frame_info) in stack: 2793 linecache.checkcache(filename) 2794 line = linecache.getline(filename, lineno, frame_globals) 2795 if line: 2796 line = line.strip() 2797 else: 2798 line = None 2799 if include_func_start_lineno: 2800 ret.append((filename, lineno, name, line, func_start_lineno)) 2801 else: 2802 ret.append((filename, lineno, name, line)) 2803 return ret 2804 2805 # Note: this method is private because the API of tf.Graph() is public and 2806 # frozen, and this functionality is still not ready for public visibility. 
@tf_contextlib.contextmanager
def _variable_creator_scope(self, creator):
  """Context manager that temporarily pushes `creator` onto the creator stack.

  On entry `creator` is appended to `self._variable_creator_stack`. On exit
  (normal or via an exception) the stack is replaced by a copy of its
  contents taken at entry.

  Args:
    creator: the variable creator to push for the duration of the context.

  Yields:
    None.
  """
  old = list(self._variable_creator_stack)
  self._variable_creator_stack.append(creator)
  try:
    yield
  finally:
    # Restore the entry snapshot instead of popping, so the stack is exact
    # even if the body itself modified it.
    self._variable_creator_stack = old

# Note: this method is private because the API of tf.Graph() is public and
# frozen, and this functionality is still not ready for public visibility.
def _get_variable_creator_stack(self):
  """Returns a copy (a new list) of the current variable creator stack."""
  return list(self._variable_creator_stack)

def _extract_stack(self):
  """A lightweight, extensible re-implementation of traceback.extract_stack.

  NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
  each stack frame using linecache, which results in an abundance of stat()
  calls. This implementation does not retrieve the code, and any consumer
  should apply _convert_stack to the result to obtain a traceback that can
  be formatted etc. using traceback methods.

  Derived classes can implement _extract_frame_info() to add extra information
  to the traceback.

  Returns:
    A list of 6-tuples
    (filename, lineno, name, frame_globals, func_start_lineno, custom_info)
    corresponding to the call stack of the current thread, ordered from the
    outermost frame to the frame that called this method.
  """
  # Start at the caller's frame. This is the same frame the former
  # raise/catch-ZeroDivisionError trick reached via tb_frame.f_back, but
  # without creating and unwinding an exception on every call.
  f = sys._getframe(1)  # pylint: disable=protected-access
  ret = []
  while f is not None:
    lineno = f.f_lineno
    co = f.f_code
    filename = co.co_filename
    name = co.co_name
    frame_globals = f.f_globals
    func_start_lineno = co.co_firstlineno
    frame_info = self._extract_frame_info(f)
    ret.append((filename, lineno, name, frame_globals, func_start_lineno,
                frame_info))
    f = f.f_back
  # Frames were collected innermost-first; callers expect outermost-first.
  ret.reverse()
  return ret

def _extract_frame_info(self, frame):  # pylint: disable=unused-argument
  """Extracts custom information from a frame in an op traceback.

  The base implementation returns None. Subclasses may override it; the
  returned value becomes the 6th element of each `_extract_stack` tuple.
  """
  return None

def _check_not_finalized(self):
  """Check if the graph is finalized.

  Raises:
    RuntimeError: If the graph is finalized.
  """
  if self._finalized:
    raise RuntimeError("Graph is finalized and cannot be modified.")

def _add_op(self, op):
  """Adds 'op' to the graph.

  Args:
    op: the Operation or Tensor to add.

  Raises:
    RuntimeError: if the graph is finalized.
    TypeError: if op is not an Operation or Tensor.
    ValueError: if the op.name or op._id are already used.
  """
  self._check_not_finalized()
  if not isinstance(op, (Tensor, Operation)):
    raise TypeError("op must be a Tensor or Operation: %s" % op)
  with self._lock:
    # pylint: disable=protected-access
    if op._id in self._nodes_by_id:
      raise ValueError("cannot add an op with id %d as it already "
                       "exists in the graph" % op._id)
    if op.name in self._nodes_by_name:
      raise ValueError("cannot add op with name %s as that name "
                       "is already used" % op.name)
    self._nodes_by_id[op._id] = op
    self._nodes_by_name[op.name] = op
    # `version` reports the largest op id added so far.
    self._version = max(self._version, op._id)
    # pylint: enable=protected-access

@property
def _c_graph(self):
  """The wrapped C API graph, or None when no C graph was created."""
  if self._scoped_c_graph:
    return self._scoped_c_graph.graph
  return None

@property
def version(self):
  """Returns a version number that increases as ops are added to the graph.

  Note that this is unrelated to the
  @{tf.Graph.graph_def_versions}.

  Returns:
     An integer version that increases as ops are added to the graph.
  """
  # A finalized graph cannot gain new ops, so the lock can be skipped.
  if self._finalized:
    return self._version

  with self._lock:
    return self._version

@property
def graph_def_versions(self):
  # pylint: disable=line-too-long
  """The GraphDef version information of this graph.

  For details on the meaning of each version, see
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).

  Returns:
    A `VersionDef`.
  """
  # pylint: enable=line-too-long
  if self._c_graph:
    # With a C graph, the versions are read back from the C API each time.
    with c_api_util.tf_buffer() as buf:
      with errors.raise_exception_on_not_ok_status() as status:
        c_api.TF_GraphVersions(self._c_graph, buf, status)
      data = c_api.TF_GetBuffer(buf)
    version_def = versions_pb2.VersionDef()
    version_def.ParseFromString(compat.as_bytes(data))
    return version_def
  else:
    return self._graph_def_versions

@property
def seed(self):
  """The graph-level random seed of this graph."""
  return self._seed

@seed.setter
def seed(self, seed):
  """Sets the graph-level random seed of this graph."""
  self._seed = seed

@property
def finalized(self):
  """True if this graph has been finalized."""
  return self._finalized

def finalize(self):
  """Finalizes this graph, making it read-only.

  After calling `g.finalize()`, no new operations can be added to
  `g`. This method is used to ensure that no operations are added
  to a graph when it is shared between multiple threads, for example
  when using a @{tf.train.QueueRunner}.
  """
  self._finalized = True

def _unsafe_unfinalize(self):
  """Opposite of `finalize`. Internal interface.

  NOTE: Unfinalizing a graph could have negative impact on performance,
  especially in a multi-threaded environment.  Unfinalizing a graph
  when it is in use by a Session may lead to undefined behavior. Ensure
  that all sessions using a graph are closed before calling this method.
  """
  self._finalized = False

def _get_control_flow_context(self):
  """Returns the current control flow context.

  Returns:
    A context object.
  """
  return self._control_flow_context

def _set_control_flow_context(self, ctx):
  """Sets the current control flow context.

  Args:
    ctx: a context object.
  """
  self._control_flow_context = ctx

def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
  """If this graph contains functions, copy them to `graph_def`.

  Args:
    graph_def: the `GraphDef` whose `library` receives the function
      definitions (and gradient definitions, where present).
    starting_bytesize: byte size already accumulated for `graph_def`; used
      to enforce the 2GB limit across nodes and functions together.

  Raises:
    ValueError: If the accumulated size would reach 2GB.
  """
  bytesize = starting_bytesize
  for f in self._functions.values():
    # Evaluate the `definition` property once per function.
    fdef = f.definition
    bytesize += fdef.ByteSize()
    if bytesize >= (1 << 31) or bytesize < 0:
      raise ValueError("GraphDef cannot be larger than 2GB.")
    graph_def.library.function.extend([fdef])
    if f.grad_func_name:
      grad_def = function_pb2.GradientDef()
      grad_def.function_name = f.name
      grad_def.gradient_func = f.grad_func_name
      graph_def.library.gradient.extend([grad_def])

def _as_graph_def(self, from_version=None, add_shapes=False):
  # pylint: disable=line-too-long
  """Returns a serialized `GraphDef` representation of this graph.

  The serialized `GraphDef` can be imported into another `Graph`
  (using @{tf.import_graph_def}) or used with the
  [C++ Session API](../../../../api_docs/cc/index.md).

  This method is thread-safe.

  Args:
    from_version: Optional.  If this is set, returns a `GraphDef`
      containing only the nodes that were added to this graph since
      its `version` property had the given value. NOTE: the `_USE_C_API`
      branch below does not consult this argument and always serializes
      every node.
    add_shapes: If true, adds an "_output_shapes" list attr to each
      node with the inferred shapes of each of its outputs.

  Returns:
    A tuple containing a
    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
    protocol buffer, and the version of the graph to which that
    `GraphDef` corresponds.

  Raises:
    ValueError: If the `graph_def` would be too large.

  """
  # pylint: enable=line-too-long
  if _USE_C_API:
    with self._lock:
      with c_api_util.tf_buffer() as buf:
        with errors.raise_exception_on_not_ok_status() as status:
          c_api.TF_GraphToGraphDef(self._c_graph, buf, status)
        data = c_api.TF_GetBuffer(buf)
      graph = graph_pb2.GraphDef()
      graph.ParseFromString(compat.as_bytes(data))
      # Strip the experimental library field iff it's empty.
      if not graph.library.function:
        graph.ClearField("library")

      if add_shapes:
        for node in graph.node:
          op = self._nodes_by_name[node.name]
          if op.outputs:
            node.attr["_output_shapes"].list.shape.extend(
                [output.get_shape().as_proto() for output in op.outputs])
  else:
    with self._lock:
      graph = graph_pb2.GraphDef()
      graph.versions.CopyFrom(self._graph_def_versions)
      bytesize = 0
      # Serialize ops in id order, i.e. in the order they were added.
      for op_id in sorted(self._nodes_by_id):
        op = self._nodes_by_id[op_id]
        if from_version is None or op_id > from_version:
          # Evaluate the `node_def` property once per op.
          node_def = op.node_def
          graph.node.extend([node_def])
          if op.outputs and add_shapes:
            assert "_output_shapes" not in graph.node[-1].attr
            graph.node[-1].attr["_output_shapes"].list.shape.extend(
                [output.get_shape().as_proto() for output in op.outputs])
          bytesize += node_def.ByteSize()
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")
      self._copy_functions_to_graph_def(graph, bytesize)
  return graph, self._version

def as_graph_def(self, from_version=None, add_shapes=False):
  # pylint: disable=line-too-long
  """Returns a serialized `GraphDef` representation of this graph.

  The serialized `GraphDef` can be imported into another `Graph`
  (using @{tf.import_graph_def}) or used with the
  [C++ Session API](../../api_docs/cc/index.md).

  This method is thread-safe.

  Args:
    from_version: Optional.  If this is set, returns a `GraphDef`
      containing only the nodes that were added to this graph since
      its `version` property had the given value.
    add_shapes: If true, adds an "_output_shapes" list attr to each
      node with the inferred shapes of each of its outputs.

  Returns:
    A
    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
    protocol buffer.

  Raises:
    ValueError: If the `graph_def` would be too large.
  """
  # pylint: enable=line-too-long
  result, _ = self._as_graph_def(from_version, add_shapes)
  return result

def _is_function(self, name):
  """Tests whether 'name' is registered in this graph's function library.

  Args:
    name: string op name.
  Returns:
    bool indicating whether or not 'name' is registered in function library.
  """
  return name in self._functions

def _get_function(self, name):
  """Returns the function definition for 'name'.

  Args:
    name: string function name.
  Returns:
    The function registered under 'name' by `_add_function`, or None.
  """
  return self._functions.get(name, None)

def _add_function(self, function):
  """Adds a function to the graph.

  After the function has been added, you can call to the function by
  passing the function name in place of an op name to
  `Graph.create_op()`.

  Args:
    function: A `_DefinedFunction` object.

  Raises:
    ValueError: if both a gradient function name and a Python gradient
      function are set on `function`, or (when no C graph is in use) if
      another function with a different hash is defined with the same name.
      Failures reported by the C API are raised through
      `errors.raise_exception_on_not_ok_status`.
  """
  name = function.name
  # Sanity checks on gradient definition.
  if (function.grad_func_name is not None) and (function.python_grad_func is
                                                not None):
    raise ValueError("Gradient defined twice for function %s" % name)

  # Add function to graph
  # pylint: disable=protected-access
  if self._c_graph:
    # Handle functions created without using the C API. TODO(apassos,skyewm)
    # remove this when all functions are generated using the C API by default
    # as this will be unnecessary.
    if not function._c_func:
      with errors.raise_exception_on_not_ok_status() as status:
        serialized = function.definition.SerializeToString()
        function._c_func = c_api.TF_FunctionImportFunctionDef(
            serialized, status)
    with errors.raise_exception_on_not_ok_status() as status:
      gradient = function._grad_func._c_func if function._grad_func else None
      c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient,
                                 status)
  else:
    # If there is already a function with the same name, raise an error
    # if bodies are different. Else, do nothing. The C API version above
    # has the same behavior.
    previous = self._functions.get(name, None)
    if previous:
      # This check is not ideal as we can have a hash collision with only
      # 32 bits in the hash, but the non C API mode is being deprecated.
      # Don't bother changing it now.
      if previous._hash_str == function._hash_str:
        return
      else:
        raise ValueError("Cannot add function (%s, hash %s) to graph (%s). "
                         "Another function (%s, hash %s) is already defined "
                         "with that name (%s)" % (
                             function, function._hash_str, self,
                             previous, previous._hash_str, name))
  # pylint: enable=protected-access

  self._functions[name] = function

  # Need a new-enough consumer to support the functions we add to the graph.
  if self._graph_def_versions.min_consumer < 12:
    self._graph_def_versions.min_consumer = 12

@property
def building_function(self):
  """Returns True iff this graph represents a function."""
  return self._building_function

# Helper functions to create operations.
3185 def create_op( 3186 self, 3187 op_type, 3188 inputs, 3189 dtypes, # pylint: disable=redefined-outer-name 3190 input_types=None, 3191 name=None, 3192 attrs=None, 3193 op_def=None, 3194 compute_shapes=True, 3195 compute_device=True): 3196 """Creates an `Operation` in this graph. 3197 3198 This is a low-level interface for creating an `Operation`. Most 3199 programs will not call this method directly, and instead use the 3200 Python op constructors, such as `tf.constant()`, which add ops to 3201 the default graph. 3202 3203 Args: 3204 op_type: The `Operation` type to create. This corresponds to the 3205 `OpDef.name` field for the proto that defines the operation. 3206 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 3207 dtypes: A list of `DType` objects that will be the types of the tensors 3208 that the operation produces. 3209 input_types: (Optional.) A list of `DType`s that will be the types of 3210 the tensors that the operation consumes. By default, uses the base 3211 `DType` of each input in `inputs`. Operations that expect 3212 reference-typed inputs must specify `input_types` explicitly. 3213 name: (Optional.) A string name for the operation. If not specified, a 3214 name is generated based on `op_type`. 3215 attrs: (Optional.) A dictionary where the key is the attribute name (a 3216 string) and the value is the respective `attr` attribute of the 3217 `NodeDef` proto that will represent the operation (an `AttrValue` 3218 proto). 3219 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 3220 the operation will have. 3221 compute_shapes: (Optional.) If True, shape inference will be performed 3222 to compute the shapes of the outputs. 3223 compute_device: (Optional.) If True, device functions will be executed 3224 to compute the device property of the Operation. 3225 3226 Raises: 3227 TypeError: if any of the inputs is not a `Tensor`. 3228 ValueError: if colocation conflicts with existing device assignment. 
3229 3230 Returns: 3231 An `Operation` object. 3232 3233 """ 3234 self._check_not_finalized() 3235 for idx, a in enumerate(inputs): 3236 if not isinstance(a, Tensor): 3237 raise TypeError("Input #%d is not a tensor: %s" % (idx, a)) 3238 if name is None: 3239 name = op_type 3240 # If a names ends with a '/' it is a "name scope" and we use it as-is, 3241 # after removing the trailing '/'. 3242 if name and name[-1] == "/": 3243 name = _name_from_scope_name(name) 3244 else: 3245 name = self.unique_name(name) 3246 3247 node_def = _NodeDef(op_type, name, device=None, attrs=attrs) 3248 3249 input_ops = set([t.op for t in inputs]) 3250 control_inputs = self._control_dependencies_for_inputs(input_ops) 3251 ret = Operation( 3252 node_def, 3253 self, 3254 inputs=inputs, 3255 output_types=dtypes, 3256 control_inputs=control_inputs, 3257 input_types=input_types, 3258 original_op=self._default_original_op, 3259 op_def=op_def) 3260 self._create_op_helper(ret, compute_shapes=compute_shapes, 3261 compute_device=compute_device) 3262 return ret 3263 3264 def _create_op_from_tf_operation(self, c_op, compute_device=True): 3265 """Creates an `Operation` in this graph from the supplied TF_Operation. 3266 3267 This method is like create_op() except the new Operation is constructed 3268 using `c_op`. The returned Operation will have `c_op` as its _c_op 3269 field. This is used to create Operation objects around TF_Operations created 3270 indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile). 3271 3272 This function does not call Operation._control_flow_post_processing or 3273 Graph._control_dependencies_for_inputs (since the inputs may not be 3274 available yet). The caller is responsible for calling these methods. 3275 3276 Args: 3277 c_op: a wrapped TF_Operation 3278 compute_device: (Optional.) If True, device functions will be executed 3279 to compute the device property of the Operation. 3280 3281 Returns: 3282 An `Operation` object. 
3283 """ 3284 self._check_not_finalized() 3285 ret = Operation(c_op, self) 3286 assert ret.name not in self._names_in_use 3287 self._names_in_use[ret.name] = 1 3288 self._create_op_helper(ret, compute_device=compute_device) 3289 return ret 3290 3291 def _create_op_helper(self, op, compute_shapes=True, compute_device=True): 3292 """Common logic for creating an op in this graph.""" 3293 # TODO(vrv): Instead of eagerly filling in shape property for every op, only 3294 # populate the shape when requested. 3295 # 3296 # TODO(skyewm): unlike in the original Python implementation, the C API 3297 # always computes shape information (even for function calls, which the 3298 # original Python shape inference code doesn't handle). Deprecate the 3299 # compute_shapes argument. 3300 if op._c_op or compute_shapes: # pylint: disable=protected-access 3301 set_shapes_for_outputs(op) 3302 # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed. 3303 self._add_op(op) 3304 3305 # Apply any additional attributes requested. Do not overwrite any existing 3306 # attributes. 3307 for key, value in self._attr_scope_map.items(): 3308 try: 3309 op.get_attr(key) 3310 except ValueError: 3311 if callable(value): 3312 value = value(op.node_def) 3313 if not isinstance(value, (type(None), attr_value_pb2.AttrValue)): 3314 raise TypeError( 3315 "Callable for scope map key '%s' must return either None or " 3316 "an AttrValue protocol buffer; but it returned: %s" % (key, 3317 value)) 3318 if value: 3319 op._set_attr(key, value) # pylint: disable=protected-access 3320 3321 # Apply a kernel label if one has been specified for this op type. 3322 try: 3323 kernel_label = self._op_to_kernel_label_map[op.type] 3324 op._set_attr("_kernel", # pylint: disable=protected-access 3325 attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label))) 3326 except KeyError: 3327 pass 3328 3329 # Apply the overriding op type for gradients if one has been specified for 3330 # this op type. 
3331 try: 3332 mapped_op_type = self._gradient_override_map[op.type] 3333 op._set_attr("_gradient_op_type", # pylint: disable=protected-access 3334 attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type))) 3335 except KeyError: 3336 pass 3337 3338 self._record_op_seen_by_control_dependencies(op) 3339 3340 if compute_device: 3341 self._apply_device_functions(op) 3342 3343 if self._colocation_stack: 3344 all_colocation_groups = [] 3345 for colocation_op in self._colocation_stack: 3346 all_colocation_groups.extend(colocation_op.colocation_groups()) 3347 if colocation_op.device: 3348 # Make this device match the device of the colocated op, to provide 3349 # consistency between the device and the colocation property. 3350 if (op.device and pydev.canonical_name(op.device) != 3351 pydev.canonical_name(colocation_op.device)): 3352 logging.warning("Tried to colocate %s with an op %s that had " 3353 "a different device: %s vs %s. " 3354 "Ignoring colocation property.", op.name, 3355 colocation_op.name, op.device, 3356 colocation_op.device) 3357 else: 3358 op._set_device(colocation_op.device) # pylint: disable=protected-access 3359 3360 all_colocation_groups = sorted(set(all_colocation_groups)) 3361 # pylint: disable=protected-access 3362 op._set_attr("_class", attr_value_pb2.AttrValue( 3363 list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) 3364 # pylint: enable=protected-access 3365 3366 # Sets "container" attribute if 3367 # (1) self._container is not None 3368 # (2) "is_stateful" is set in OpDef 3369 # (3) "container" attribute is in OpDef 3370 # (4) "container" attribute is None 3371 # TODO(skyewm): remove op.op_def check when _USE_C_API is removed. 
3372 if self._container and op.op_def and op.op_def.is_stateful: 3373 try: 3374 container_attr = op.get_attr("container") 3375 except ValueError: 3376 # "container" attribute is not in OpDef 3377 pass 3378 else: 3379 if not container_attr: 3380 op._set_attr("container", attr_value_pb2.AttrValue( # pylint: disable=protected-access 3381 s=compat.as_bytes(self._container))) 3382 3383 def _add_new_tf_operations(self, compute_devices=True): 3384 """Creates `Operations` in this graph for any new TF_Operations. 3385 3386 This is useful for when TF_Operations are indirectly created by the C API 3387 outside of the Operation constructor (e.g. by TF_ImportGraphDef, 3388 TF_FinishWhile). This ensures there are corresponding Operations for all 3389 TF_Operations in the underlying TF_Graph. 3390 3391 Args: 3392 compute_devices: (Optional.) If True, device functions will be executed 3393 to compute the device properties of each new Operation. 3394 3395 Returns: 3396 A list of the new `Operation` objects. 3397 """ 3398 # Create all Operation objects before accessing their inputs since an op may 3399 # be created before its inputs. 3400 new_ops = [ 3401 self._create_op_from_tf_operation(c_op, compute_device=compute_devices) 3402 for c_op in c_api_util.new_tf_operations(self) 3403 ] 3404 3405 for op in new_ops: 3406 new_control_inputs = self._control_dependencies_for_inputs(op.inputs) 3407 # pylint: disable=protected-access 3408 op._add_control_inputs(new_control_inputs) 3409 op._control_flow_post_processing() 3410 # pylint: enable=protected-access 3411 3412 return new_ops 3413 3414 def as_graph_element(self, obj, allow_tensor=True, allow_operation=True): 3415 """Returns the object referred to by `obj`, as an `Operation` or `Tensor`. 3416 3417 This function validates that `obj` represents an element of this 3418 graph, and gives an informative error message if it is not. 
3419 3420 This function is the canonical way to get/validate an object of 3421 one of the allowed types from an external argument reference in the 3422 Session API. 3423 3424 This method may be called concurrently from multiple threads. 3425 3426 Args: 3427 obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. 3428 Can also be any object with an `_as_graph_element()` method that returns 3429 a value of one of these types. 3430 allow_tensor: If true, `obj` may refer to a `Tensor`. 3431 allow_operation: If true, `obj` may refer to an `Operation`. 3432 3433 Returns: 3434 The `Tensor` or `Operation` in the Graph corresponding to `obj`. 3435 3436 Raises: 3437 TypeError: If `obj` is not a type we support attempting to convert 3438 to types. 3439 ValueError: If `obj` is of an appropriate type but invalid. For 3440 example, an invalid string. 3441 KeyError: If `obj` is not an object in the graph. 3442 """ 3443 if self._finalized: 3444 return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 3445 3446 with self._lock: 3447 return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 3448 3449 def _as_graph_element_locked(self, obj, allow_tensor, allow_operation): 3450 """See `Graph.as_graph_element()` for details.""" 3451 # The vast majority of this function is figuring 3452 # out what an API user might be doing wrong, so 3453 # that we can give helpful error messages. 3454 # 3455 # Ideally, it would be nice to split it up, but we 3456 # need context to generate nice error messages. 3457 3458 if allow_tensor and allow_operation: 3459 types_str = "Tensor or Operation" 3460 elif allow_tensor: 3461 types_str = "Tensor" 3462 elif allow_operation: 3463 types_str = "Operation" 3464 else: 3465 raise ValueError("allow_tensor and allow_operation can't both be False.") 3466 3467 temp_obj = _as_graph_element(obj) 3468 if temp_obj is not None: 3469 obj = temp_obj 3470 3471 # If obj appears to be a name... 
3472 if isinstance(obj, compat.bytes_or_text_types): 3473 name = compat.as_str(obj) 3474 3475 if ":" in name and allow_tensor: 3476 # Looks like a Tensor name and can be a Tensor. 3477 try: 3478 op_name, out_n = name.split(":") 3479 out_n = int(out_n) 3480 except: 3481 raise ValueError("The name %s looks a like a Tensor name, but is " 3482 "not a valid one. Tensor names must be of the " 3483 "form \"<op_name>:<output_index>\"." % repr(name)) 3484 if op_name in self._nodes_by_name: 3485 op = self._nodes_by_name[op_name] 3486 else: 3487 raise KeyError("The name %s refers to a Tensor which does not " 3488 "exist. The operation, %s, does not exist in the " 3489 "graph." % (repr(name), repr(op_name))) 3490 try: 3491 return op.outputs[out_n] 3492 except: 3493 raise KeyError("The name %s refers to a Tensor which does not " 3494 "exist. The operation, %s, exists but only has " 3495 "%s outputs." % (repr(name), repr(op_name), 3496 len(op.outputs))) 3497 3498 elif ":" in name and not allow_tensor: 3499 # Looks like a Tensor name but can't be a Tensor. 3500 raise ValueError("Name %s appears to refer to a Tensor, not a %s." % 3501 (repr(name), types_str)) 3502 3503 elif ":" not in name and allow_operation: 3504 # Looks like an Operation name and can be an Operation. 3505 if name not in self._nodes_by_name: 3506 raise KeyError("The name %s refers to an Operation not in the " 3507 "graph." % repr(name)) 3508 return self._nodes_by_name[name] 3509 3510 elif ":" not in name and not allow_operation: 3511 # Looks like an Operation name but can't be an Operation. 3512 if name in self._nodes_by_name: 3513 # Yep, it's an Operation name 3514 err_msg = ("The name %s refers to an Operation, not a %s." % 3515 (repr(name), types_str)) 3516 else: 3517 err_msg = ("The name %s looks like an (invalid) Operation name, " 3518 "not a %s." 
% (repr(name), types_str)) 3519 err_msg += (" Tensor names must be of the form " 3520 "\"<op_name>:<output_index>\".") 3521 raise ValueError(err_msg) 3522 3523 elif isinstance(obj, Tensor) and allow_tensor: 3524 # Actually obj is just the object it's referring to. 3525 if obj.graph is not self: 3526 raise ValueError("Tensor %s is not an element of this graph." % obj) 3527 return obj 3528 elif isinstance(obj, Operation) and allow_operation: 3529 # Actually obj is just the object it's referring to. 3530 if obj.graph is not self: 3531 raise ValueError("Operation %s is not an element of this graph." % obj) 3532 return obj 3533 else: 3534 # We give up! 3535 raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__, 3536 types_str)) 3537 3538 def get_operations(self): 3539 """Return the list of operations in the graph. 3540 3541 You can modify the operations in place, but modifications 3542 to the list such as inserts/delete have no effect on the 3543 list of operations known to the graph. 3544 3545 This method may be called concurrently from multiple threads. 3546 3547 Returns: 3548 A list of Operations. 3549 """ 3550 if self._finalized: 3551 return list(self._nodes_by_id.values()) 3552 3553 with self._lock: 3554 return list(self._nodes_by_id.values()) 3555 3556 def get_operation_by_name(self, name): 3557 """Returns the `Operation` with the given `name`. 3558 3559 This method may be called concurrently from multiple threads. 3560 3561 Args: 3562 name: The name of the `Operation` to return. 3563 3564 Returns: 3565 The `Operation` with the given `name`. 3566 3567 Raises: 3568 TypeError: If `name` is not a string. 3569 KeyError: If `name` does not correspond to an operation in this graph. 3570 """ 3571 3572 if not isinstance(name, six.string_types): 3573 raise TypeError("Operation names are strings (or similar), not %s." 
% 3574 type(name).__name__) 3575 return self.as_graph_element(name, allow_tensor=False, allow_operation=True) 3576 3577 def _get_operation_by_name_unsafe(self, name): 3578 """Returns the `Operation` with the given `name`. 3579 3580 This is a internal unsafe version of get_operation_by_name. It skips many 3581 checks and does not have user friedly error messages but runs considerably 3582 faster. This method may be called concurrently from multiple threads. 3583 3584 Args: 3585 name: The name of the `Operation` to return. 3586 3587 Returns: 3588 The `Operation` with the given `name`. 3589 3590 Raises: 3591 KeyError: If `name` does not correspond to an operation in this graph. 3592 """ 3593 3594 if self._finalized: 3595 return self._nodes_by_name[name] 3596 3597 with self._lock: 3598 return self._nodes_by_name[name] 3599 3600 def _get_operation_by_tf_operation(self, tf_oper): 3601 op_name = c_api.TF_OperationName(tf_oper) 3602 return self._get_operation_by_name_unsafe(op_name) 3603 3604 def get_tensor_by_name(self, name): 3605 """Returns the `Tensor` with the given `name`. 3606 3607 This method may be called concurrently from multiple threads. 3608 3609 Args: 3610 name: The name of the `Tensor` to return. 3611 3612 Returns: 3613 The `Tensor` with the given `name`. 3614 3615 Raises: 3616 TypeError: If `name` is not a string. 3617 KeyError: If `name` does not correspond to a tensor in this graph. 3618 """ 3619 # Names should be strings. 3620 if not isinstance(name, six.string_types): 3621 raise TypeError("Tensor names are strings (or similar), not %s." % 3622 type(name).__name__) 3623 return self.as_graph_element(name, allow_tensor=True, allow_operation=False) 3624 3625 def _get_tensor_by_tf_output(self, tf_output): 3626 """Returns the `Tensor` representing `tf_output`. 3627 3628 Note that there is only one such `Tensor`, i.e. multiple calls to this 3629 function with the same TF_Output value will always return the same `Tensor` 3630 object. 
3631 3632 Args: 3633 tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`). 3634 3635 Returns: 3636 The `Tensor` that represents `tf_output`. 3637 """ 3638 op = self._get_operation_by_tf_operation(tf_output.oper) 3639 return op.outputs[tf_output.index] 3640 3641 def _next_id(self): 3642 """Id for next Operation instance. Also increments the internal id.""" 3643 self._check_not_finalized() 3644 with self._lock: 3645 self._next_id_counter += 1 3646 return self._next_id_counter 3647 3648 @property 3649 def _last_id(self): 3650 return self._next_id_counter 3651 3652 def _get_op_def(self, type): # pylint: disable=redefined-builtin 3653 """Returns the `OpDef` proto for `type`. `type` is a string.""" 3654 if self._c_graph: 3655 with c_api_util.tf_buffer() as buf: 3656 with errors.raise_exception_on_not_ok_status() as status: 3657 # pylint: disable=protected-access 3658 c_api.TF_GraphGetOpDef(self._c_graph, 3659 compat.as_bytes(type), buf, status) 3660 # pylint: enable=protected-access 3661 data = c_api.TF_GetBuffer(buf) 3662 op_def = op_def_pb2.OpDef() 3663 op_def.ParseFromString(compat.as_bytes(data)) 3664 return op_def 3665 else: 3666 return self._registered_ops[type] 3667 3668 def as_default(self): 3669 """Returns a context manager that makes this `Graph` the default graph. 3670 3671 This method should be used if you want to create multiple graphs 3672 in the same process. For convenience, a global default graph is 3673 provided, and all ops will be added to this graph if you do not 3674 create a new graph explicitly. Use this method with the `with` keyword 3675 to specify that ops created within the scope of a block should be 3676 added to this graph. 3677 3678 The default graph is a property of the current thread. If you 3679 create a new thread, and wish to use the default graph in that 3680 thread, you must explicitly add a `with g.as_default():` in that 3681 thread's function. 
3682 3683 The following code examples are equivalent: 3684 3685 ```python 3686 # 1. Using Graph.as_default(): 3687 g = tf.Graph() 3688 with g.as_default(): 3689 c = tf.constant(5.0) 3690 assert c.graph is g 3691 3692 # 2. Constructing and making default: 3693 with tf.Graph().as_default() as g: 3694 c = tf.constant(5.0) 3695 assert c.graph is g 3696 ``` 3697 3698 Returns: 3699 A context manager for using this graph as the default graph. 3700 """ 3701 return _default_graph_stack.get_controller(self) 3702 3703 @property 3704 def collections(self): 3705 """Returns the names of the collections known to this graph.""" 3706 return list(self._collections) 3707 3708 def add_to_collection(self, name, value): 3709 """Stores `value` in the collection with the given `name`. 3710 3711 Note that collections are not sets, so it is possible to add a value to 3712 a collection several times. 3713 3714 Args: 3715 name: The key for the collection. The `GraphKeys` class 3716 contains many standard names for collections. 3717 value: The value to add to the collection. 3718 """ # pylint: disable=g-doc-exception 3719 _assert_collection_is_ok(name) 3720 self._check_not_finalized() 3721 with self._lock: 3722 if name not in self._collections: 3723 self._collections[name] = [value] 3724 else: 3725 self._collections[name].append(value) 3726 3727 def add_to_collections(self, names, value): 3728 """Stores `value` in the collections given by `names`. 3729 3730 Note that collections are not sets, so it is possible to add a value to 3731 a collection several times. This function makes sure that duplicates in 3732 `names` are ignored, but it will not check for pre-existing membership of 3733 `value` in any of the collections in `names`. 3734 3735 `names` can be any iterable, but if `names` is a string, it is treated as a 3736 single collection name. 3737 3738 Args: 3739 names: The keys for the collections to add to. The `GraphKeys` class 3740 contains many standard names for collections. 
3741 value: The value to add to the collections. 3742 """ 3743 # Make sure names are unique, but treat strings as a single collection name 3744 names = (names,) if isinstance(names, six.string_types) else set(names) 3745 for name in names: 3746 self.add_to_collection(name, value) 3747 3748 def get_collection_ref(self, name): 3749 """Returns a list of values in the collection with the given `name`. 3750 3751 If the collection exists, this returns the list itself, which can 3752 be modified in place to change the collection. If the collection does 3753 not exist, it is created as an empty list and the list is returned. 3754 3755 This is different from `get_collection()` which always returns a copy of 3756 the collection list if it exists and never creates an empty collection. 3757 3758 Args: 3759 name: The key for the collection. For example, the `GraphKeys` class 3760 contains many standard names for collections. 3761 3762 Returns: 3763 The list of values in the collection with the given `name`, or an empty 3764 list if no value has been added to that collection. 3765 """ # pylint: disable=g-doc-exception 3766 _assert_collection_is_ok(name) 3767 with self._lock: 3768 coll_list = self._collections.get(name, None) 3769 if coll_list is None: 3770 coll_list = [] 3771 self._collections[name] = coll_list 3772 return coll_list 3773 3774 def get_collection(self, name, scope=None): 3775 """Returns a list of values in the collection with the given `name`. 3776 3777 This is different from `get_collection_ref()` which always returns the 3778 actual collection list if it exists in that it returns a new list each time 3779 it is called. 3780 3781 Args: 3782 name: The key for the collection. For example, the `GraphKeys` class 3783 contains many standard names for collections. 3784 scope: (Optional.) A string. If supplied, the resulting list is filtered 3785 to include only items whose `name` attribute matches `scope` using 3786 `re.match`. 
Items without a `name` attribute are never returned if a 3787 scope is supplied. The choice of `re.match` means that a `scope` without 3788 special tokens filters by prefix. 3789 3790 Returns: 3791 The list of values in the collection with the given `name`, or 3792 an empty list if no value has been added to that collection. The 3793 list contains the values in the order under which they were 3794 collected. 3795 """ # pylint: disable=g-doc-exception 3796 _assert_collection_is_ok(name) 3797 with self._lock: 3798 collection = self._collections.get(name, None) 3799 if collection is None: 3800 return [] 3801 if scope is None: 3802 return list(collection) 3803 else: 3804 c = [] 3805 regex = re.compile(scope) 3806 for item in collection: 3807 if hasattr(item, "name") and regex.match(item.name): 3808 c.append(item) 3809 return c 3810 3811 def get_all_collection_keys(self): 3812 """Returns a list of collections used in this graph.""" 3813 with self._lock: 3814 return [x for x in self._collections if isinstance(x, six.string_types)] 3815 3816 def clear_collection(self, name): 3817 """Clears all values in a collection. 3818 3819 Args: 3820 name: The key for the collection. The `GraphKeys` class contains many 3821 standard names for collections. 3822 """ 3823 self._check_not_finalized() 3824 with self._lock: 3825 if name in self._collections: 3826 del self._collections[name] 3827 3828 @tf_contextlib.contextmanager 3829 def _original_op(self, op): 3830 """Python 'with' handler to help annotate ops with their originator. 3831 3832 An op may have an 'original_op' property that indicates the op on which 3833 it was based. For example a replica op is based on the op that was 3834 replicated and a gradient op is based on the op that was differentiated. 3835 3836 All ops created in the scope of this 'with' handler will have 3837 the given 'op' as their original op. 3838 3839 Args: 3840 op: The Operation that all ops created in this scope will have as their 3841 original op. 
3842 3843 Yields: 3844 Nothing. 3845 """ 3846 old_original_op = self._default_original_op 3847 try: 3848 self._default_original_op = op 3849 yield 3850 finally: 3851 self._default_original_op = old_original_op 3852 3853 # pylint: disable=g-doc-return-or-yield,line-too-long 3854 @tf_contextlib.contextmanager 3855 def name_scope(self, name): 3856 r"""Returns a context manager that creates hierarchical names for operations. 3857 3858 A graph maintains a stack of name scopes. A `with name_scope(...):` 3859 statement pushes a new name onto the stack for the lifetime of the context. 3860 3861 The `name` argument will be interpreted as follows: 3862 3863 * A string (not ending with '/') will create a new name scope, in which 3864 `name` is appended to the prefix of all operations created in the 3865 context. If `name` has been used before, it will be made unique by 3866 calling `self.unique_name(name)`. 3867 * A scope previously captured from a `with g.name_scope(...) as 3868 scope:` statement will be treated as an "absolute" name scope, which 3869 makes it possible to re-enter existing scopes. 3870 * A value of `None` or the empty string will reset the current name scope 3871 to the top-level (empty) name scope. 3872 3873 For example: 3874 3875 ```python 3876 with tf.Graph().as_default() as g: 3877 c = tf.constant(5.0, name="c") 3878 assert c.op.name == "c" 3879 c_1 = tf.constant(6.0, name="c") 3880 assert c_1.op.name == "c_1" 3881 3882 # Creates a scope called "nested" 3883 with g.name_scope("nested") as scope: 3884 nested_c = tf.constant(10.0, name="c") 3885 assert nested_c.op.name == "nested/c" 3886 3887 # Creates a nested scope called "inner". 3888 with g.name_scope("inner"): 3889 nested_inner_c = tf.constant(20.0, name="c") 3890 assert nested_inner_c.op.name == "nested/inner/c" 3891 3892 # Create a nested scope called "inner_1". 
3893 with g.name_scope("inner"): 3894 nested_inner_1_c = tf.constant(30.0, name="c") 3895 assert nested_inner_1_c.op.name == "nested/inner_1/c" 3896 3897 # Treats `scope` as an absolute name scope, and 3898 # switches to the "nested/" scope. 3899 with g.name_scope(scope): 3900 nested_d = tf.constant(40.0, name="d") 3901 assert nested_d.op.name == "nested/d" 3902 3903 with g.name_scope(""): 3904 e = tf.constant(50.0, name="e") 3905 assert e.op.name == "e" 3906 ``` 3907 3908 The name of the scope itself can be captured by `with 3909 g.name_scope(...) as scope:`, which stores the name of the scope 3910 in the variable `scope`. This value can be used to name an 3911 operation that represents the overall result of executing the ops 3912 in a scope. For example: 3913 3914 ```python 3915 inputs = tf.constant(...) 3916 with g.name_scope('my_layer') as scope: 3917 weights = tf.Variable(..., name="weights") 3918 biases = tf.Variable(..., name="biases") 3919 affine = tf.matmul(inputs, weights) + biases 3920 output = tf.nn.relu(affine, name=scope) 3921 ``` 3922 3923 NOTE: This constructor validates the given `name`. Valid scope 3924 names match one of the following regular expressions: 3925 3926 [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root) 3927 [A-Za-z0-9_.\\-/]* (for other scopes) 3928 3929 Args: 3930 name: A name for the scope. 3931 3932 Returns: 3933 A context manager that installs `name` as a new name scope. 3934 3935 Raises: 3936 ValueError: If `name` is not a valid scope name, according to the rules 3937 above. 3938 """ 3939 if name: 3940 if isinstance(name, compat.bytes_or_text_types): 3941 name = compat.as_str(name) 3942 3943 if self._name_stack: 3944 # Scopes created in a nested scope may have initial characters 3945 # that are illegal as the initial character of an op name 3946 # (viz. '-', '\', '/', and '_'). 
3947 if not _VALID_SCOPE_NAME_REGEX.match(name): 3948 raise ValueError("'%s' is not a valid scope name" % name) 3949 else: 3950 # Scopes created in the root must match the more restrictive 3951 # op name regex, which constrains the initial character. 3952 if not _VALID_OP_NAME_REGEX.match(name): 3953 raise ValueError("'%s' is not a valid scope name" % name) 3954 try: 3955 old_stack = self._name_stack 3956 if not name: # Both for name=None and name="" we re-set to empty scope. 3957 new_stack = None 3958 elif name[-1] == "/": 3959 new_stack = _name_from_scope_name(name) 3960 else: 3961 new_stack = self.unique_name(name) 3962 self._name_stack = new_stack 3963 yield "" if new_stack is None else new_stack + "/" 3964 finally: 3965 self._name_stack = old_stack 3966 3967 # pylint: enable=g-doc-return-or-yield,line-too-long 3968 3969 def unique_name(self, name, mark_as_used=True): 3970 """Return a unique operation name for `name`. 3971 3972 Note: You rarely need to call `unique_name()` directly. Most of 3973 the time you just need to create `with g.name_scope()` blocks to 3974 generate structured names. 3975 3976 `unique_name` is used to generate structured names, separated by 3977 `"/"`, to help identify operations when debugging a graph. 3978 Operation names are displayed in error messages reported by the 3979 TensorFlow runtime, and in various visualization tools such as 3980 TensorBoard. 3981 3982 If `mark_as_used` is set to `True`, which is the default, a new 3983 unique name is created and marked as in use. If it's set to `False`, 3984 the unique name is returned without actually being marked as used. 3985 This is useful when the caller simply wants to know what the name 3986 to be created will be. 3987 3988 Args: 3989 name: The name for an operation. 3990 mark_as_used: Whether to mark this name as being used. 3991 3992 Returns: 3993 A string to be passed to `create_op()` that will be used 3994 to name the operation being created. 
3995 """ 3996 if self._name_stack: 3997 name = self._name_stack + "/" + name 3998 i = self._names_in_use.get(name, 0) 3999 # Increment the number for "name". 4000 if mark_as_used: 4001 self._names_in_use[name] = i + 1 4002 if i > 0: 4003 base_name = name 4004 # Make sure the composed name is not already used. 4005 while name in self._names_in_use: 4006 name = "%s_%d" % (base_name, i) 4007 i += 1 4008 # Mark the composed name as used in case someone wants 4009 # to call unique_name("name_1"). 4010 if mark_as_used: 4011 self._names_in_use[name] = 1 4012 return name 4013 4014 def get_name_scope(self): 4015 """Returns the current name scope. 4016 4017 For example: 4018 4019 ```python 4020 with tf.name_scope('scope1'): 4021 with tf.name_scope('scope2'): 4022 print(tf.get_default_graph().get_name_scope()) 4023 ``` 4024 would print the string `scope1/scope2`. 4025 4026 Returns: 4027 A string representing the current name scope. 4028 """ 4029 return self._name_stack 4030 4031 @tf_contextlib.contextmanager 4032 def colocate_with(self, op, ignore_existing=False): 4033 """Returns a context manager that specifies an op to colocate with. 4034 4035 Note: this function is not for public use, only for internal libraries. 4036 4037 For example: 4038 4039 ```python 4040 a = tf.Variable([1.0]) 4041 with g.colocate_with(a): 4042 b = tf.constant(1.0) 4043 c = tf.add(a, b) 4044 ``` 4045 4046 `b` and `c` will always be colocated with `a`, no matter where `a` 4047 is eventually placed. 4048 4049 **NOTE** Using a colocation scope resets any existing device constraints. 4050 4051 If `op` is `None` then `ignore_existing` must be `True` and the new 4052 scope resets all colocation and device constraints. 4053 4054 Args: 4055 op: The op to colocate all created ops with, or `None`. 4056 ignore_existing: If true, only applies colocation of this op within 4057 the context, rather than applying all colocation properties 4058 on the stack. If `op` is `None`, this value must be `True`. 
4059 4060 Raises: 4061 ValueError: if op is None but ignore_existing is False. 4062 4063 Yields: 4064 A context manager that specifies the op with which to colocate 4065 newly created ops. 4066 4067 """ 4068 if op is None and not ignore_existing: 4069 raise ValueError("Trying to reset colocation (op is None) but " 4070 "ignore_existing is not True") 4071 4072 if op is not None and not isinstance(op, Operation): 4073 # We always want to colocate with the reference op. 4074 op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op 4075 4076 # By default, colocate_with resets the device function stack, 4077 # since colocate_with is typically used in specific internal 4078 # library functions where colocation is intended to be "stronger" 4079 # than device functions. 4080 # 4081 # In the future, a caller may specify that device_functions win 4082 # over colocation, in which case we can add support. 4083 device_fn_tmp = self._device_function_stack 4084 self._device_function_stack = [] 4085 4086 if ignore_existing: 4087 current_stack = self._colocation_stack 4088 self._colocation_stack = [] 4089 4090 if op is not None: 4091 self._colocation_stack.append(op) 4092 4093 try: 4094 yield 4095 finally: 4096 # Restore device function stack 4097 self._device_function_stack = device_fn_tmp 4098 if op is not None: 4099 self._colocation_stack.pop() 4100 4101 # Reset the colocation stack if requested. 4102 if ignore_existing: 4103 self._colocation_stack = current_stack 4104 4105 @tf_contextlib.contextmanager 4106 def device(self, device_name_or_function): 4107 # pylint: disable=line-too-long 4108 """Returns a context manager that specifies the default device to use. 4109 4110 The `device_name_or_function` argument may either be a device name 4111 string, a device function, or None: 4112 4113 * If it is a device name string, all operations constructed in 4114 this context will be assigned to the device with that name, unless 4115 overridden by a nested `device()` context. 
4116 * If it is a function, it will be treated as a function from 4117 Operation objects to device name strings, and invoked each time 4118 a new Operation is created. The Operation will be assigned to 4119 the device with the returned name. 4120 * If it is None, all `device()` invocations from the enclosing context 4121 will be ignored. 4122 4123 For information about the valid syntax of device name strings, see 4124 the documentation in 4125 [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h). 4126 4127 For example: 4128 4129 ```python 4130 with g.device('/device:GPU:0'): 4131 # All operations constructed in this context will be placed 4132 # on GPU 0. 4133 with g.device(None): 4134 # All operations constructed in this context will have no 4135 # assigned device. 4136 4137 # Defines a function from `Operation` to device string. 4138 def matmul_on_gpu(n): 4139 if n.type == "MatMul": 4140 return "/device:GPU:0" 4141 else: 4142 return "/cpu:0" 4143 4144 with g.device(matmul_on_gpu): 4145 # All operations of type "MatMul" constructed in this context 4146 # will be placed on GPU 0; all other operations will be placed 4147 # on CPU 0. 4148 ``` 4149 4150 **N.B.** The device scope may be overridden by op wrappers or 4151 other library code. For example, a variable assignment op 4152 `v.assign()` must be colocated with the `tf.Variable` `v`, and 4153 incompatible device scopes will be ignored. 4154 4155 Args: 4156 device_name_or_function: The device name or function to use in 4157 the context. 4158 4159 Yields: 4160 A context manager that specifies the default device to use for newly 4161 created ops. 
4162 4163 """ 4164 # pylint: enable=line-too-long 4165 if (device_name_or_function is not None and 4166 not callable(device_name_or_function)): 4167 device_function = pydev.merge_device(device_name_or_function) 4168 else: 4169 device_function = device_name_or_function 4170 4171 try: 4172 self._device_function_stack.append(device_function) 4173 yield 4174 finally: 4175 self._device_function_stack.pop() 4176 4177 def _apply_device_functions(self, op): 4178 """Applies the current device function stack to the given operation.""" 4179 # Apply any device functions in reverse order, so that the most recently 4180 # pushed function has the first chance to apply a device to the op. 4181 # We apply here because the result can depend on the Operation's 4182 # signature, which is computed in the Operation constructor. 4183 for device_function in reversed(self._device_function_stack): 4184 if device_function is None: 4185 break 4186 op._set_device(device_function(op)) # pylint: disable=protected-access 4187 4188 # pylint: disable=g-doc-return-or-yield 4189 @tf_contextlib.contextmanager 4190 def container(self, container_name): 4191 """Returns a context manager that specifies the resource container to use. 4192 4193 Stateful operations, such as variables and queues, can maintain their 4194 states on devices so that they can be shared by multiple processes. 4195 A resource container is a string name under which these stateful 4196 operations are tracked. These resources can be released or cleared 4197 with `tf.Session.reset()`. 4198 4199 For example: 4200 4201 ```python 4202 with g.container('experiment0'): 4203 # All stateful Operations constructed in this context will be placed 4204 # in resource container "experiment0". 4205 v1 = tf.Variable([1.0]) 4206 v2 = tf.Variable([2.0]) 4207 with g.container("experiment1"): 4208 # All stateful Operations constructed in this context will be 4209 # placed in resource container "experiment1". 
4210 v3 = tf.Variable([3.0]) 4211 q1 = tf.FIFOQueue(10, tf.float32) 4212 # All stateful Operations constructed in this context will be 4213 # be created in the "experiment0". 4214 v4 = tf.Variable([4.0]) 4215 q1 = tf.FIFOQueue(20, tf.float32) 4216 with g.container(""): 4217 # All stateful Operations constructed in this context will be 4218 # be placed in the default resource container. 4219 v5 = tf.Variable([5.0]) 4220 q3 = tf.FIFOQueue(30, tf.float32) 4221 4222 # Resets container "experiment0", after which the state of v1, v2, v4, q1 4223 # will become undefined (such as uninitialized). 4224 tf.Session.reset(target, ["experiment0"]) 4225 ``` 4226 4227 Args: 4228 container_name: container name string. 4229 4230 Returns: 4231 A context manager for defining resource containers for stateful ops, 4232 yields the container name. 4233 """ 4234 original_container = self._container 4235 try: 4236 self._container = container_name 4237 yield self._container 4238 finally: 4239 self._container = original_container 4240 4241 # pylint: enable=g-doc-return-or-yield 4242 4243 class _ControlDependenciesController(object): 4244 """Context manager for `control_dependencies()`.""" 4245 4246 def __init__(self, graph, control_inputs): 4247 """Create a new `_ControlDependenciesController`. 4248 4249 A `_ControlDependenciesController` is the context manager for 4250 `with tf.control_dependencies()` blocks. These normally nest, 4251 as described in the documentation for `control_dependencies()`. 4252 4253 The `control_inputs` argument list control dependencies that must be 4254 added to the current set of control dependencies. Because of 4255 uniquification the set can be empty even if the caller passed a list of 4256 ops. The special value `None` indicates that we want to start a new 4257 empty set of control dependencies instead of extending the current set. 
4258 4259 In that case we also clear the current control flow context, which is an 4260 additional mechanism to add control dependencies. 4261 4262 Args: 4263 graph: The graph that this controller is managing. 4264 control_inputs: List of ops to use as control inputs in addition 4265 to the current control dependencies. None to indicate that 4266 the dependencies should be cleared. 4267 """ 4268 self._graph = graph 4269 if control_inputs is None: 4270 self._control_inputs_val = [] 4271 self._new_stack = True 4272 else: 4273 self._control_inputs_val = control_inputs 4274 self._new_stack = False 4275 self._seen_nodes = set() 4276 self._old_stack = None 4277 self._old_control_flow_context = None 4278 4279 # pylint: disable=protected-access 4280 4281 def __enter__(self): 4282 if self._new_stack: 4283 # Clear the control_dependencies graph. 4284 self._old_stack = self._graph._control_dependencies_stack 4285 self._graph._control_dependencies_stack = [] 4286 # Clear the control_flow_context too. 
4287 self._old_control_flow_context = self._graph._get_control_flow_context() 4288 self._graph._set_control_flow_context(None) 4289 self._graph._push_control_dependencies_controller(self) 4290 4291 def __exit__(self, unused_type, unused_value, unused_traceback): 4292 self._graph._pop_control_dependencies_controller(self) 4293 if self._new_stack: 4294 self._graph._control_dependencies_stack = self._old_stack 4295 self._graph._set_control_flow_context(self._old_control_flow_context) 4296 4297 # pylint: enable=protected-access 4298 4299 @property 4300 def control_inputs(self): 4301 return self._control_inputs_val 4302 4303 def add_op(self, op): 4304 self._seen_nodes.add(op) 4305 4306 def op_in_group(self, op): 4307 return op in self._seen_nodes 4308 4309 def _push_control_dependencies_controller(self, controller): 4310 self._control_dependencies_stack.append(controller) 4311 4312 def _pop_control_dependencies_controller(self, controller): 4313 assert self._control_dependencies_stack[-1] is controller 4314 self._control_dependencies_stack.pop() 4315 4316 def _current_control_dependencies(self): 4317 ret = set() 4318 for controller in self._control_dependencies_stack: 4319 for op in controller.control_inputs: 4320 ret.add(op) 4321 return ret 4322 4323 def _control_dependencies_for_inputs(self, input_ops): 4324 """For an op that takes `input_ops` as inputs, compute control inputs. 4325 4326 The returned control dependencies should yield an execution that 4327 is equivalent to adding all control inputs in 4328 self._control_dependencies_stack to a newly created op. However, 4329 this function attempts to prune the returned control dependencies 4330 by observing that nodes created within the same `with 4331 control_dependencies(...):` block may have data dependencies that make 4332 the explicit approach redundant. 4333 4334 Args: 4335 input_ops: The data input ops for an op to be created. 4336 4337 Returns: 4338 A list of control inputs for the op to be created. 
4339 """ 4340 ret = [] 4341 for controller in self._control_dependencies_stack: 4342 # If any of the input_ops already depends on the inputs from controller, 4343 # we say that the new op is dominated (by that input), and we therefore 4344 # do not need to add control dependencies for this controller's inputs. 4345 dominated = False 4346 for op in input_ops: 4347 if controller.op_in_group(op): 4348 dominated = True 4349 break 4350 if not dominated: 4351 # Don't add a control input if we already have a data dependency on i. 4352 # NOTE(mrry): We do not currently track transitive data dependencies, 4353 # so we may add redundant control inputs. 4354 ret.extend([c for c in controller.control_inputs if c not in input_ops]) 4355 return ret 4356 4357 def _record_op_seen_by_control_dependencies(self, op): 4358 """Record that the given op depends on all registered control dependencies. 4359 4360 Args: 4361 op: An Operation. 4362 """ 4363 for controller in self._control_dependencies_stack: 4364 controller.add_op(op) 4365 4366 def control_dependencies(self, control_inputs): 4367 """Returns a context manager that specifies control dependencies. 4368 4369 Use with the `with` keyword to specify that all operations constructed 4370 within the context should have control dependencies on 4371 `control_inputs`. For example: 4372 4373 ```python 4374 with g.control_dependencies([a, b, c]): 4375 # `d` and `e` will only run after `a`, `b`, and `c` have executed. 4376 d = ... 4377 e = ... 4378 ``` 4379 4380 Multiple calls to `control_dependencies()` can be nested, and in 4381 that case a new `Operation` will have control dependencies on the union 4382 of `control_inputs` from all active contexts. 4383 4384 ```python 4385 with g.control_dependencies([a, b]): 4386 # Ops constructed here run after `a` and `b`. 4387 with g.control_dependencies([c, d]): 4388 # Ops constructed here run after `a`, `b`, `c`, and `d`. 
4389 ``` 4390 4391 You can pass None to clear the control dependencies: 4392 4393 ```python 4394 with g.control_dependencies([a, b]): 4395 # Ops constructed here run after `a` and `b`. 4396 with g.control_dependencies(None): 4397 # Ops constructed here run normally, not waiting for either `a` or `b`. 4398 with g.control_dependencies([c, d]): 4399 # Ops constructed here run after `c` and `d`, also not waiting 4400 # for either `a` or `b`. 4401 ``` 4402 4403 *N.B.* The control dependencies context applies *only* to ops that 4404 are constructed within the context. Merely using an op or tensor 4405 in the context does not add a control dependency. The following 4406 example illustrates this point: 4407 4408 ```python 4409 # WRONG 4410 def my_func(pred, tensor): 4411 t = tf.matmul(tensor, tensor) 4412 with tf.control_dependencies([pred]): 4413 # The matmul op is created outside the context, so no control 4414 # dependency will be added. 4415 return t 4416 4417 # RIGHT 4418 def my_func(pred, tensor): 4419 with tf.control_dependencies([pred]): 4420 # The matmul op is created in the context, so a control dependency 4421 # will be added. 4422 return tf.matmul(tensor, tensor) 4423 ``` 4424 4425 Args: 4426 control_inputs: A list of `Operation` or `Tensor` objects which 4427 must be executed or computed before running the operations 4428 defined in the context. Can also be `None` to clear the control 4429 dependencies. 4430 4431 Returns: 4432 A context manager that specifies control dependencies for all 4433 operations constructed within the context. 4434 4435 Raises: 4436 TypeError: If `control_inputs` is not a list of `Operation` or 4437 `Tensor` objects. 4438 """ 4439 if control_inputs is None: 4440 return self._ControlDependenciesController(self, None) 4441 # First convert the inputs to ops, and deduplicate them. 
4442 # NOTE(mrry): Other than deduplication, we do not currently track direct 4443 # or indirect dependencies between control_inputs, which may result in 4444 # redundant control inputs. 4445 control_ops = [] 4446 current = self._current_control_dependencies() 4447 for c in control_inputs: 4448 if isinstance(c, IndexedSlices): 4449 c = c.op 4450 c = self.as_graph_element(c) 4451 if isinstance(c, Tensor): 4452 c = c.op 4453 elif not isinstance(c, Operation): 4454 raise TypeError("Control input must be Operation or Tensor: %s" % c) 4455 if c not in current: 4456 control_ops.append(c) 4457 current.add(c) 4458 return self._ControlDependenciesController(self, control_ops) 4459 4460 # pylint: disable=g-doc-return-or-yield 4461 @tf_contextlib.contextmanager 4462 def _attr_scope(self, attr_map): 4463 """EXPERIMENTAL: A context manager for setting attributes on operators. 4464 4465 This context manager can be used to add additional 4466 attributes to operators within the scope of the context. 4467 4468 For example: 4469 4470 with ops.Graph().as_default() as g: 4471 f_1 = Foo() # No extra attributes 4472 with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}): 4473 f_2 = Foo() # Additional attribute _a=False 4474 with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}): 4475 f_3 = Foo() # Additional attribute _a=False 4476 with g._attr_scope({"_a": None}): 4477 f_4 = Foo() # No additional attributes. 4478 4479 Args: 4480 attr_map: A dictionary mapping attr name strings to 4481 AttrValue protocol buffers or None. 4482 4483 Returns: 4484 A context manager that sets the kernel label to be used for one or more 4485 ops created in that context. 4486 4487 Raises: 4488 TypeError: If attr_map is not a dictionary mapping 4489 strings to AttrValue protobufs. 
4490 """ 4491 if not isinstance(attr_map, dict): 4492 raise TypeError("attr_map must be a dictionary mapping " 4493 "strings to AttrValue protocol buffers") 4494 # The saved_attrs dictionary stores any currently-set labels that 4495 # will be overridden by this context manager. 4496 saved_attrs = {} 4497 # Install the given attribute 4498 for name, attr in attr_map.items(): 4499 if not (isinstance(name, six.string_types) and 4500 (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or 4501 callable(attr))): 4502 raise TypeError("attr_map must be a dictionary mapping " 4503 "strings to AttrValue protocol buffers or " 4504 "callables that emit AttrValue protocol buffers") 4505 try: 4506 saved_attrs[name] = self._attr_scope_map[name] 4507 except KeyError: 4508 pass 4509 if attr is None: 4510 del self._attr_scope_map[name] 4511 else: 4512 self._attr_scope_map[name] = attr 4513 try: 4514 yield # The code within the context runs here. 4515 finally: 4516 # Remove the attributes set for this context, and restore any saved 4517 # attributes. 4518 for name, attr in attr_map.items(): 4519 try: 4520 self._attr_scope_map[name] = saved_attrs[name] 4521 except KeyError: 4522 del self._attr_scope_map[name] 4523 4524 # pylint: enable=g-doc-return-or-yield 4525 4526 # pylint: disable=g-doc-return-or-yield 4527 @tf_contextlib.contextmanager 4528 def _kernel_label_map(self, op_to_kernel_label_map): 4529 """EXPERIMENTAL: A context manager for setting kernel labels. 4530 4531 This context manager can be used to select particular 4532 implementations of kernels within the scope of the context. 4533 4534 For example: 4535 4536 with ops.Graph().as_default() as g: 4537 f_1 = Foo() # Uses the default registered kernel for the Foo op. 4538 with g.kernel_label_map({"Foo": "v_2"}): 4539 f_2 = Foo() # Uses the registered kernel with label "v_2" 4540 # for the Foo op. 
4541 with g.kernel_label_map({"Foo": "v_3"}): 4542 f_3 = Foo() # Uses the registered kernel with label "v_3" 4543 # for the Foo op. 4544 with g.kernel_label_map({"Foo": ""}): 4545 f_4 = Foo() # Uses the default registered kernel 4546 # for the Foo op. 4547 4548 Args: 4549 op_to_kernel_label_map: A dictionary mapping op type strings to 4550 kernel label strings. 4551 4552 Returns: 4553 A context manager that sets the kernel label to be used for one or more 4554 ops created in that context. 4555 4556 Raises: 4557 TypeError: If op_to_kernel_label_map is not a dictionary mapping 4558 strings to strings. 4559 """ 4560 if not isinstance(op_to_kernel_label_map, dict): 4561 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 4562 "strings to strings") 4563 # The saved_labels dictionary stores any currently-set labels that 4564 # will be overridden by this context manager. 4565 saved_labels = {} 4566 # Install the given label 4567 for op_type, label in op_to_kernel_label_map.items(): 4568 if not (isinstance(op_type, six.string_types) and 4569 isinstance(label, six.string_types)): 4570 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 4571 "strings to strings") 4572 try: 4573 saved_labels[op_type] = self._op_to_kernel_label_map[op_type] 4574 except KeyError: 4575 pass 4576 self._op_to_kernel_label_map[op_type] = label 4577 try: 4578 yield # The code within the context runs here. 4579 finally: 4580 # Remove the labels set for this context, and restore any saved labels. 4581 for op_type, label in op_to_kernel_label_map.items(): 4582 try: 4583 self._op_to_kernel_label_map[op_type] = saved_labels[op_type] 4584 except KeyError: 4585 del self._op_to_kernel_label_map[op_type] 4586 4587 # pylint: enable=g-doc-return-or-yield 4588 4589 # pylint: disable=g-doc-return-or-yield 4590 @tf_contextlib.contextmanager 4591 def gradient_override_map(self, op_type_map): 4592 """EXPERIMENTAL: A context manager for overriding gradient functions. 
4593 4594 This context manager can be used to override the gradient function 4595 that will be used for ops within the scope of the context. 4596 4597 For example: 4598 4599 ```python 4600 @tf.RegisterGradient("CustomSquare") 4601 def _custom_square_grad(op, grad): 4602 # ... 4603 4604 with tf.Graph().as_default() as g: 4605 c = tf.constant(5.0) 4606 s_1 = tf.square(c) # Uses the default gradient for tf.square. 4607 with g.gradient_override_map({"Square": "CustomSquare"}): 4608 s_2 = tf.square(s_2) # Uses _custom_square_grad to compute the 4609 # gradient of s_2. 4610 ``` 4611 4612 Args: 4613 op_type_map: A dictionary mapping op type strings to alternative op 4614 type strings. 4615 4616 Returns: 4617 A context manager that sets the alternative op type to be used for one 4618 or more ops created in that context. 4619 4620 Raises: 4621 TypeError: If `op_type_map` is not a dictionary mapping strings to 4622 strings. 4623 """ 4624 if not isinstance(op_type_map, dict): 4625 raise TypeError("op_type_map must be a dictionary mapping " 4626 "strings to strings") 4627 # The saved_mappings dictionary stores any currently-set mappings that 4628 # will be overridden by this context manager. 4629 saved_mappings = {} 4630 # Install the given label 4631 for op_type, mapped_op_type in op_type_map.items(): 4632 if not (isinstance(op_type, six.string_types) and 4633 isinstance(mapped_op_type, six.string_types)): 4634 raise TypeError("op_type_map must be a dictionary mapping " 4635 "strings to strings") 4636 try: 4637 saved_mappings[op_type] = self._gradient_override_map[op_type] 4638 except KeyError: 4639 pass 4640 self._gradient_override_map[op_type] = mapped_op_type 4641 try: 4642 yield # The code within the context runs here. 4643 finally: 4644 # Remove the labels set for this context, and restore any saved labels. 
4645 for op_type, mapped_op_type in op_type_map.items(): 4646 try: 4647 self._gradient_override_map[op_type] = saved_mappings[op_type] 4648 except KeyError: 4649 del self._gradient_override_map[op_type] 4650 4651 # pylint: enable=g-doc-return-or-yield 4652 4653 def prevent_feeding(self, tensor): 4654 """Marks the given `tensor` as unfeedable in this graph.""" 4655 self._unfeedable_tensors.add(tensor) 4656 4657 def is_feedable(self, tensor): 4658 """Returns `True` if and only if `tensor` is feedable.""" 4659 return tensor not in self._unfeedable_tensors 4660 4661 def prevent_fetching(self, op): 4662 """Marks the given `op` as unfetchable in this graph.""" 4663 self._unfetchable_ops.add(op) 4664 4665 def is_fetchable(self, tensor_or_op): 4666 """Returns `True` if and only if `tensor_or_op` is fetchable.""" 4667 if isinstance(tensor_or_op, Tensor): 4668 return tensor_or_op.op not in self._unfetchable_ops 4669 else: 4670 return tensor_or_op not in self._unfetchable_ops 4671 4672 4673 # TODO(agarwal): currently device directives in an outer eager scope will not 4674 # apply to inner graph mode code. Fix that. 4675 4676 4677 @tf_export("device") 4678 def device(device_name_or_function): 4679 """Wrapper for `Graph.device()` using the default graph. 4680 4681 See 4682 @{tf.Graph.device} 4683 for more details. 4684 4685 Args: 4686 device_name_or_function: The device name or function to use in 4687 the context. 4688 4689 Returns: 4690 A context manager that specifies the default device to use for newly 4691 created ops. 4692 4693 Raises: 4694 RuntimeError: If eager execution is enabled and a function is passed in. 4695 """ 4696 if context.in_graph_mode(): 4697 return get_default_graph().device(device_name_or_function) 4698 else: 4699 # TODO(agarwal): support device functions in EAGER mode. 
4700 if callable(device_name_or_function): 4701 raise RuntimeError( 4702 "tf.device does not support functions when eager execution " 4703 "is enabled.") 4704 return context.device(device_name_or_function) 4705 4706 4707 @tf_export("container") 4708 def container(container_name): 4709 """Wrapper for `Graph.container()` using the default graph. 4710 4711 Args: 4712 container_name: The container string to use in the context. 4713 4714 Returns: 4715 A context manager that specifies the default container to use for newly 4716 created stateful ops. 4717 """ 4718 return get_default_graph().container(container_name) 4719 4720 4721 @tf_export("colocate_with") 4722 def colocate_with(op, ignore_existing=False): 4723 if context.in_graph_mode(): 4724 return get_default_graph().colocate_with(op, ignore_existing) 4725 else: 4726 if op is not None: 4727 return device(op.device) 4728 else: 4729 return _NullContextmanager() 4730 4731 4732 @tf_export("control_dependencies") 4733 def control_dependencies(control_inputs): 4734 """Wrapper for `Graph.control_dependencies()` using the default graph. 4735 4736 See @{tf.Graph.control_dependencies} 4737 for more details. 4738 4739 Args: 4740 control_inputs: A list of `Operation` or `Tensor` objects which 4741 must be executed or computed before running the operations 4742 defined in the context. Can also be `None` to clear the control 4743 dependencies. 4744 4745 Returns: 4746 A context manager that specifies control dependencies for all 4747 operations constructed within the context. 
4748 """ 4749 if context.in_graph_mode(): 4750 return get_default_graph().control_dependencies(control_inputs) 4751 else: 4752 return _NullContextmanager() 4753 4754 4755 class _DefaultStack(threading.local): 4756 """A thread-local stack of objects for providing implicit defaults.""" 4757 4758 def __init__(self): 4759 super(_DefaultStack, self).__init__() 4760 self._enforce_nesting = True 4761 self.stack = [] 4762 4763 def get_default(self): 4764 return self.stack[-1] if len(self.stack) >= 1 else None 4765 4766 def reset(self): 4767 self.stack = [] 4768 4769 def is_cleared(self): 4770 return not self.stack 4771 4772 @property 4773 def enforce_nesting(self): 4774 return self._enforce_nesting 4775 4776 @enforce_nesting.setter 4777 def enforce_nesting(self, value): 4778 self._enforce_nesting = value 4779 4780 @tf_contextlib.contextmanager 4781 def get_controller(self, default): 4782 """A context manager for manipulating a default stack.""" 4783 try: 4784 self.stack.append(default) 4785 yield default 4786 finally: 4787 # stack may be empty if reset() was called 4788 if self.stack: 4789 if self._enforce_nesting: 4790 if self.stack[-1] is not default: 4791 raise AssertionError( 4792 "Nesting violated for default stack of %s objects" % 4793 type(default)) 4794 self.stack.pop() 4795 else: 4796 self.stack.remove(default) 4797 4798 4799 _default_session_stack = _DefaultStack() # pylint: disable=protected-access 4800 4801 4802 def default_session(session): 4803 """Python "with" handler for defining a default session. 4804 4805 This function provides a means of registering a session for handling 4806 Tensor.eval() and Operation.run() calls. It is primarily intended for use 4807 by session.Session, but can be used with any object that implements 4808 the Session.run() interface. 4809 4810 Use with the "with" keyword to specify that Tensor.eval() and Operation.run() 4811 invocations within the scope of a block should be executed by a particular 4812 session. 
4813 4814 The default session applies to the current thread only, so it is always 4815 possible to inspect the call stack and determine the scope of a default 4816 session. If you create a new thread, and wish to use the default session 4817 in that thread, you must explicitly add a "with ops.default_session(sess):" 4818 block in that thread's function. 4819 4820 Example: 4821 The following code examples are equivalent: 4822 4823 # 1. Using the Session object directly: 4824 sess = ... 4825 c = tf.constant(5.0) 4826 sess.run(c) 4827 4828 # 2. Using default_session(): 4829 sess = ... 4830 with ops.default_session(sess): 4831 c = tf.constant(5.0) 4832 result = c.eval() 4833 4834 # 3. Overriding default_session(): 4835 sess = ... 4836 with ops.default_session(sess): 4837 c = tf.constant(5.0) 4838 with ops.default_session(...): 4839 c.eval(session=sess) 4840 4841 Args: 4842 session: The session to be installed as the default session. 4843 4844 Returns: 4845 A context manager for the default session. 4846 """ 4847 return _default_session_stack.get_controller(session) 4848 4849 4850 @tf_export("get_default_session") 4851 def get_default_session(): 4852 """Returns the default session for the current thread. 4853 4854 The returned `Session` will be the innermost session on which a 4855 `Session` or `Session.as_default()` context has been entered. 4856 4857 NOTE: The default session is a property of the current thread. If you 4858 create a new thread, and wish to use the default session in that 4859 thread, you must explicitly add a `with sess.as_default():` in that 4860 thread's function. 4861 4862 Returns: 4863 The default `Session` being used in the current thread. 4864 """ 4865 return _default_session_stack.get_default() 4866 4867 4868 def _eval_using_default_session(tensors, feed_dict, graph, session=None): 4869 """Uses the default session to evaluate one or more tensors. 4870 4871 Args: 4872 tensors: A single Tensor, or a list of Tensor objects. 
4873 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 4874 numpy ndarrays, TensorProtos, or strings. 4875 graph: The graph in which the tensors are defined. 4876 session: (Optional) A different session to use to evaluate "tensors". 4877 4878 Returns: 4879 Either a single numpy ndarray if "tensors" is a single tensor; or a list 4880 of numpy ndarrays that each correspond to the respective element in 4881 "tensors". 4882 4883 Raises: 4884 ValueError: If no default session is available; the default session 4885 does not have "graph" as its graph; or if "session" is specified, 4886 and it does not have "graph" as its graph. 4887 """ 4888 if session is None: 4889 session = get_default_session() 4890 if session is None: 4891 raise ValueError("Cannot evaluate tensor using `eval()`: No default " 4892 "session is registered. Use `with " 4893 "sess.as_default()` or pass an explicit session to " 4894 "`eval(session=sess)`") 4895 if session.graph is not graph: 4896 raise ValueError("Cannot use the default session to evaluate tensor: " 4897 "the tensor's graph is different from the session's " 4898 "graph. Pass an explicit session to " 4899 "`eval(session=sess)`.") 4900 else: 4901 if session.graph is not graph: 4902 raise ValueError("Cannot use the given session to evaluate tensor: " 4903 "the tensor's graph is different from the session's " 4904 "graph.") 4905 return session.run(tensors, feed_dict) 4906 4907 4908 def _run_using_default_session(operation, feed_dict, graph, session=None): 4909 """Uses the default session to run "operation". 4910 4911 Args: 4912 operation: The Operation to be run. 4913 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 4914 numpy ndarrays, TensorProtos, or strings. 4915 graph: The graph in which "operation" is defined. 4916 session: (Optional) A different session to use to run "operation". 
4917 4918 Raises: 4919 ValueError: If no default session is available; the default session 4920 does not have "graph" as its graph; or if "session" is specified, 4921 and it does not have "graph" as its graph. 4922 """ 4923 if session is None: 4924 session = get_default_session() 4925 if session is None: 4926 raise ValueError("Cannot execute operation using `run()`: No default " 4927 "session is registered. Use `with " 4928 "sess.as_default():` or pass an explicit session to " 4929 "`run(session=sess)`") 4930 if session.graph is not graph: 4931 raise ValueError("Cannot use the default session to execute operation: " 4932 "the operation's graph is different from the " 4933 "session's graph. Pass an explicit session to " 4934 "run(session=sess).") 4935 else: 4936 if session.graph is not graph: 4937 raise ValueError("Cannot use the given session to execute operation: " 4938 "the operation's graph is different from the session's " 4939 "graph.") 4940 session.run(operation, feed_dict) 4941 4942 4943 class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access 4944 """A thread-local stack of objects for providing an implicit default graph.""" 4945 4946 def __init__(self): 4947 super(_DefaultGraphStack, self).__init__() 4948 self._global_default_graph = None 4949 4950 def get_default(self): 4951 """Override that returns a global default if the stack is empty.""" 4952 ret = super(_DefaultGraphStack, self).get_default() 4953 if ret is None: 4954 ret = self._GetGlobalDefaultGraph() 4955 return ret 4956 4957 def _GetGlobalDefaultGraph(self): 4958 if self._global_default_graph is None: 4959 # TODO(mrry): Perhaps log that the default graph is being used, or set 4960 # provide some other feedback to prevent confusion when a mixture of 4961 # the global default graph and an explicit graph are combined in the 4962 # same process. 
4963 self._global_default_graph = Graph() 4964 return self._global_default_graph 4965 4966 def reset(self): 4967 super(_DefaultGraphStack, self).reset() 4968 self._global_default_graph = None 4969 4970 @tf_contextlib.contextmanager 4971 def get_controller(self, default): 4972 try: 4973 context.context_stack.push(default.building_function, default.as_default) 4974 with super(_DefaultGraphStack, self).get_controller(default) as g: 4975 yield g 4976 finally: 4977 context.context_stack.pop() 4978 4979 4980 _default_graph_stack = _DefaultGraphStack() 4981 4982 4983 # pylint: disable=g-doc-return-or-yield,line-too-long 4984 @tf_contextlib.contextmanager 4985 def init_scope(): 4986 """A context manager that lifts ops out of control-flow scopes and function-building graphs. 4987 4988 There is often a need to lift variable initialization ops out of control-flow 4989 scopes, function-building graphs, and gradient tapes. Entering an 4990 `init_scope` is a mechanism for satisfying these desiderata. In particular, 4991 entering an `init_scope` has three effects: 4992 4993 (1) All control dependencies are cleared the moment the scope is entered; 4994 this is equivalent to entering the context manager returned from 4995 `control_dependencies(None)`, which has the side-effect of exiting 4996 control-flow scopes like `tf.cond` and `tf.while_loop`. 4997 4998 (2) All operations that are created while the scope is active are lifted 4999 into the lowest context on the `context_stack` that is not building a 5000 graph function. Here, a context is defined as either a graph or an eager 5001 context. Every context switch, i.e., every installation of a graph as 5002 the default graph and every switch into eager mode, is logged in a 5003 thread-local stack called the `context_stack`; the log entry for a 5004 context switch is popped from the stack when the context is exited. 
5005 Entering an `init_scope` is equivalent to crawling up the 5006 `context_stack`, finding the first context that is not building a graph 5007 function, and entering it. A caveat is that if graph mode is enabled 5008 but the default graph stack is empty, then entering an `init_scope` 5009 will simply install a fresh graph as the default one. 5010 5011 (3) The gradient tape is paused while the scope is active. 5012 """ 5013 # pylint: enable=g-doc-return-or-yield,line-too-long 5014 5015 in_graph_mode = context.in_graph_mode() 5016 # Retrieve the active name scope: entering an `init_scope` preserves 5017 # the name scope of the current context. 5018 if in_graph_mode: 5019 default_graph = get_default_graph() 5020 scope = default_graph.get_name_scope() 5021 else: 5022 scope = context.context().scope_name 5023 if scope and scope[-1] != '/': 5024 # Names that end with trailing slashes are treated by `name_scope` as 5025 # absolute. 5026 scope = scope + '/' 5027 5028 outer_context = None 5029 if in_graph_mode and not _default_graph_stack.stack: 5030 outer_context = default_graph.as_default 5031 else: 5032 for stack_entry in reversed(context.context_stack.stack): 5033 if not stack_entry.is_building_function: 5034 outer_context = stack_entry.enter_context_fn 5035 break 5036 5037 if outer_context is None: 5038 raise AssertionError("All graphs are building functions, and no " 5039 "eager context was previously active.") 5040 5041 try: 5042 with outer_context(), name_scope(scope), control_dependencies( 5043 None), tape.stop_recording(): 5044 yield 5045 finally: 5046 pass 5047 5048 5049 def enable_eager_execution(config=None, device_policy=None): 5050 """Enables, for the rest of the lifetime of this program, eager execution. 5051 5052 If not called immediately on startup risks creating breakage and bugs. 
5053 5054 Example: 5055 ```python 5056 tfe.enable_eager_execution() 5057 5058 # After eager execution is enabled, operations are executed as they are 5059 # defined and `Tensor`s hold concrete values, which can be accessed as 5060 # `numpy.ndarray`s through the `numpy()` method. 5061 assert tf.multiply(6, 7).numpy() == 42 5062 ``` 5063 5064 Args: 5065 config: (Optional.) A `ConfigProto` protocol buffer with configuration 5066 options for the Context. Note that a lot of these options may be 5067 currently unimplemented or irrelevant when eager execution is enabled. 5068 device_policy: (Optional.) What policy to use when trying to run an 5069 operation on a device with inputs which are not on that device. 5070 Valid values: 5071 tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not 5072 correct. 5073 tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the 5074 right device but raises a warning. 5075 tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might 5076 hide performance problems. 5077 tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, 5078 raising errors on the other ones. 5079 5080 Raises: 5081 ValueError: If trying to create a context after using graph operations 5082 or if trying to create a context with nontrivial options which differ 5083 from those of the existing context. 
5084 """ 5085 if config is not None and not isinstance(config, config_pb2.ConfigProto): 5086 raise TypeError( 5087 "config must be a tf.ConfigProto, but got %s" % type(config)) 5088 if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT, 5089 context.DEVICE_PLACEMENT_WARN, 5090 context.DEVICE_PLACEMENT_SILENT, 5091 context.DEVICE_PLACEMENT_SILENT_FOR_INT32): 5092 raise ValueError( 5093 "device_policy must be one of None, tfe.DEVICE_PLACEMENT_*" 5094 ) 5095 # pylint: disable=protected-access 5096 if context._default_mode == context.GRAPH_MODE: 5097 graph_mode_has_been_used = ( 5098 _default_session_stack.stack or 5099 _default_graph_stack._global_default_graph is not None) 5100 if graph_mode_has_been_used: 5101 raise ValueError( 5102 "tfe.enable_eager_execution has to be called at program startup.") 5103 context._default_mode = context.EAGER_MODE 5104 if context._context is None: 5105 context._context = context.Context(config=config, 5106 device_policy=device_policy) 5107 if context.context_stack.stack: 5108 raise AssertionError("Invariant violated: The context stack must " 5109 "be empty when eager execution is enabled.") 5110 # Log that eager execution has been enabled by pushing an entry onto the 5111 # context stack; this entry won't ever be popped, as it's impossible to 5112 # disable eager execution 5113 context.context_stack.push(False, context.eager_mode) 5114 elif ((config is not None and config is not context._context._config) 5115 or (device_policy is not None 5116 and device_policy is not context._context._device_policy)): 5117 raise ValueError("Trying to change the options of an active eager" 5118 " execution. Context config: %s, specified config:" 5119 " %s. Context device policy: %s; specified device" 5120 " policy: %s." 
% (config, context._context._config, 5121 device_policy, 5122 context._context._device_policy)) 5123 else: 5124 raise ValueError( 5125 "tfe.enable_eager_execution has to be called at program startup.") 5126 5127 5128 def eager_run(main=None, argv=None): 5129 """Runs the program with an optional main function and argv list. 5130 5131 The program will run with eager execution enabled. 5132 5133 Example: 5134 ```python 5135 import tensorflow as tf 5136 # Import subject to future changes: 5137 from tensorflow.contrib.eager.python import tfe 5138 5139 def main(_): 5140 u = tf.constant(6.0) 5141 v = tf.constant(7.0) 5142 print(u * v) 5143 5144 if __name__ == "__main__": 5145 tfe.run() 5146 ``` 5147 5148 Args: 5149 main: the main function to run. 5150 argv: the arguments to pass to it. 5151 """ 5152 enable_eager_execution() 5153 app.run(main, argv) 5154 5155 5156 @tf_export("reset_default_graph") 5157 def reset_default_graph(): 5158 """Clears the default graph stack and resets the global default graph. 5159 5160 NOTE: The default graph is a property of the current thread. This 5161 function applies only to the current thread. Calling this function while 5162 a `tf.Session` or `tf.InteractiveSession` is active will result in undefined 5163 behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects 5164 after calling this function will result in undefined behavior. 5165 Raises: 5166 AssertionError: If this function is called within a nested graph. 5167 """ 5168 if not _default_graph_stack.is_cleared(): 5169 raise AssertionError("Do not use tf.reset_default_graph() to clear " 5170 "nested graphs. If you need a cleared graph, " 5171 "exit the nesting and create a new graph.") 5172 _default_graph_stack.reset() 5173 5174 5175 @tf_export("get_default_graph") 5176 def get_default_graph(): 5177 """Returns the default graph for the current thread. 
5178 5179 The returned graph will be the innermost graph on which a 5180 `Graph.as_default()` context has been entered, or a global default 5181 graph if none has been explicitly created. 5182 5183 NOTE: The default graph is a property of the current thread. If you 5184 create a new thread, and wish to use the default graph in that 5185 thread, you must explicitly add a `with g.as_default():` in that 5186 thread's function. 5187 5188 Returns: 5189 The default `Graph` being used in the current thread. 5190 """ 5191 return _default_graph_stack.get_default() 5192 5193 5194 def get_name_scope(): 5195 """Returns the current name scope in the default_graph. 5196 5197 For example: 5198 5199 ```python 5200 with tf.name_scope('scope1'): 5201 with tf.name_scope('scope2'): 5202 print(tf.get_name_scope()) 5203 ``` 5204 would print the string `scope1/scope2`. 5205 5206 Returns: 5207 A string representing the current name scope. 5208 """ 5209 return get_default_graph().get_name_scope() 5210 5211 5212 def _assert_same_graph(original_item, item): 5213 """Fail if the 2 items are from different graphs. 5214 5215 Args: 5216 original_item: Original item to check against. 5217 item: Item to check. 5218 5219 Raises: 5220 ValueError: if graphs do not match. 5221 """ 5222 if original_item.graph is not item.graph: 5223 raise ValueError("%s must be from the same graph as %s." % (item, 5224 original_item)) 5225 5226 5227 def _get_graph_from_inputs(op_input_list, graph=None): 5228 """Returns the appropriate graph to use for the given inputs. 5229 5230 This library method provides a consistent algorithm for choosing the graph 5231 in which an Operation should be constructed: 5232 5233 1. If the default graph is being used to construct a function, we 5234 use the default graph. 5235 2. If the "graph" is specified explicitly, we validate that all of the inputs 5236 in "op_input_list" are compatible with that graph. 5237 3. 
Otherwise, we attempt to select a graph from the first Operation- 5238 or Tensor-valued input in "op_input_list", and validate that all other 5239 such inputs are in the same graph. 5240 4. If the graph was not specified and it could not be inferred from 5241 "op_input_list", we attempt to use the default graph. 5242 5243 Args: 5244 op_input_list: A list of inputs to an operation, which may include `Tensor`, 5245 `Operation`, and other objects that may be converted to a graph element. 5246 graph: (Optional) The explicit graph to use. 5247 5248 Raises: 5249 TypeError: If op_input_list is not a list or tuple, or if graph is not a 5250 Graph. 5251 ValueError: If a graph is explicitly passed and not all inputs are from it, 5252 or if the inputs are from multiple graphs, or we could not find a graph 5253 and there was no default graph. 5254 5255 Returns: 5256 The appropriate graph to use for the given inputs. 5257 5258 """ 5259 if get_default_graph().building_function: 5260 return get_default_graph() 5261 5262 op_input_list = tuple(op_input_list) # Handle generators correctly 5263 if graph and not isinstance(graph, Graph): 5264 raise TypeError("Input graph needs to be a Graph: %s" % graph) 5265 5266 # 1. We validate that all of the inputs are from the same graph. This is 5267 # either the supplied graph parameter, or the first one selected from one 5268 # the graph-element-valued inputs. In the latter case, we hold onto 5269 # that input in original_graph_element so we can provide a more 5270 # informative error if a mismatch is found. 5271 original_graph_element = None 5272 for op_input in op_input_list: 5273 # Determine if this is a valid graph_element. 5274 # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this 5275 # up. 
5276 graph_element = None 5277 if (isinstance(op_input, (Operation, _TensorLike)) and 5278 ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)): # pylint: disable=unidiomatic-typecheck 5279 graph_element = op_input 5280 else: 5281 graph_element = _as_graph_element(op_input) 5282 5283 if graph_element is not None: 5284 if not graph: 5285 original_graph_element = graph_element 5286 graph = graph_element.graph 5287 elif original_graph_element is not None: 5288 _assert_same_graph(original_graph_element, graph_element) 5289 elif graph_element.graph is not graph: 5290 raise ValueError("%s is not from the passed-in graph." % graph_element) 5291 5292 # 2. If all else fails, we use the default graph, which is always there. 5293 return graph or get_default_graph() 5294 5295 5296 @tf_export("GraphKeys") 5297 class GraphKeys(object): 5298 """Standard names to use for graph collections. 5299 5300 The standard library uses various well-known names to collect and 5301 retrieve values associated with a graph. For example, the 5302 `tf.Optimizer` subclasses default to optimizing the variables 5303 collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is 5304 specified, but it is also possible to pass an explicit list of 5305 variables. 5306 5307 The following standard keys are defined: 5308 5309 * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared 5310 across distributed environment (model variables are subset of these). See 5311 @{tf.global_variables} 5312 for more details. 5313 Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`, 5314 and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`. 5315 * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each 5316 machine. Usually used for temporarily variables, like counters. 5317 Note: use `tf.contrib.framework.local_variable` to add to this collection. 
5318 * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the 5319 model for inference (feed forward). Note: use 5320 `tf.contrib.framework.model_variable` to add to this collection. 5321 * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will 5322 be trained by an optimizer. See 5323 @{tf.trainable_variables} 5324 for more details. 5325 * `SUMMARIES`: the summary `Tensor` objects that have been created in the 5326 graph. See 5327 @{tf.summary.merge_all} 5328 for more details. 5329 * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to 5330 produce input for a computation. See 5331 @{tf.train.start_queue_runners} 5332 for more details. 5333 * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also 5334 keep moving averages. See 5335 @{tf.moving_average_variables} 5336 for more details. 5337 * `REGULARIZATION_LOSSES`: regularization losses collected during graph 5338 construction. 5339 5340 The following standard keys are _defined_, but their collections are **not** 5341 automatically populated as many of the others are: 5342 5343 * `WEIGHTS` 5344 * `BIASES` 5345 * `ACTIVATIONS` 5346 """ 5347 5348 # Key to collect Variable objects that are global (shared across machines). 5349 # Default collection for all variables, except local ones. 5350 GLOBAL_VARIABLES = "variables" 5351 # Key to collect local variables that are local to the machine and are not 5352 # saved/restored. 5353 LOCAL_VARIABLES = "local_variables" 5354 # Key to collect local variables which are used to accumulate interal state 5355 # to be used in tf.metrics.*. 5356 METRIC_VARIABLES = "metric_variables" 5357 # Key to collect model variables defined by layers. 5358 MODEL_VARIABLES = "model_variables" 5359 # Key to collect Variable objects that will be trained by the 5360 # optimizers. 5361 TRAINABLE_VARIABLES = "trainable_variables" 5362 # Key to collect summaries. 5363 SUMMARIES = "summaries" 5364 # Key to collect QueueRunners. 
5365 QUEUE_RUNNERS = "queue_runners" 5366 # Key to collect table initializers. 5367 TABLE_INITIALIZERS = "table_initializer" 5368 # Key to collect asset filepaths. An asset represents an external resource 5369 # like a vocabulary file. 5370 ASSET_FILEPATHS = "asset_filepaths" 5371 # Key to collect Variable objects that keep moving averages. 5372 MOVING_AVERAGE_VARIABLES = "moving_average_variables" 5373 # Key to collect regularization losses at graph construction. 5374 REGULARIZATION_LOSSES = "regularization_losses" 5375 # Key to collect concatenated sharded variables. 5376 CONCATENATED_VARIABLES = "concatenated_variables" 5377 # Key to collect savers. 5378 SAVERS = "savers" 5379 # Key to collect weights 5380 WEIGHTS = "weights" 5381 # Key to collect biases 5382 BIASES = "biases" 5383 # Key to collect activations 5384 ACTIVATIONS = "activations" 5385 # Key to collect update_ops 5386 UPDATE_OPS = "update_ops" 5387 # Key to collect losses 5388 LOSSES = "losses" 5389 # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. 5390 SAVEABLE_OBJECTS = "saveable_objects" 5391 # Key to collect all shared resources used by the graph which need to be 5392 # initialized once per cluster. 5393 RESOURCES = "resources" 5394 # Key to collect all shared resources used in this graph which need to be 5395 # initialized once per session. 5396 LOCAL_RESOURCES = "local_resources" 5397 # Trainable resource-style variables. 5398 TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables" 5399 5400 # Key to indicate various ops. 5401 INIT_OP = "init_op" 5402 LOCAL_INIT_OP = "local_init_op" 5403 READY_OP = "ready_op" 5404 READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op" 5405 SUMMARY_OP = "summary_op" 5406 GLOBAL_STEP = "global_step" 5407 5408 # Used to count the number of evaluations performed during a single evaluation 5409 # run. 5410 EVAL_STEP = "eval_step" 5411 TRAIN_OP = "train_op" 5412 5413 # Key for control flow context. 
5414 COND_CONTEXT = "cond_context" 5415 WHILE_CONTEXT = "while_context" 5416 5417 # Used to store v2 summary names. 5418 _SUMMARY_COLLECTION = "_SUMMARY_V2" 5419 5420 # List of all collections that keep track of variables. 5421 _VARIABLE_COLLECTIONS = [ 5422 GLOBAL_VARIABLES, 5423 LOCAL_VARIABLES, 5424 METRIC_VARIABLES, 5425 MODEL_VARIABLES, 5426 TRAINABLE_VARIABLES, 5427 MOVING_AVERAGE_VARIABLES, 5428 CONCATENATED_VARIABLES, 5429 TRAINABLE_RESOURCE_VARIABLES, 5430 ] 5431 5432 # Key for streaming model ports. 5433 # NOTE(yuanbyu): internal and experimental. 5434 _STREAMING_MODEL_PORTS = "streaming_model_ports" 5435 5436 @decorator_utils.classproperty 5437 def VARIABLES(cls): # pylint: disable=no-self-argument 5438 logging.log_first_n(logging.WARN, 5439 "VARIABLES collection name is deprecated, please use " 5440 "GLOBAL_VARIABLES instead; VARIABLES will be removed " 5441 "after 2017-03-02.", 1) 5442 return cls.GLOBAL_VARIABLES 5443 5444 5445 @tf_export("add_to_collection") 5446 def add_to_collection(name, value): 5447 """Wrapper for `Graph.add_to_collection()` using the default graph. 5448 5449 See @{tf.Graph.add_to_collection} 5450 for more details. 5451 5452 Args: 5453 name: The key for the collection. For example, the `GraphKeys` class 5454 contains many standard names for collections. 5455 value: The value to add to the collection. 5456 5457 @compatibility(eager) 5458 Collections are not supported when eager execution is enabled. 5459 @end_compatibility 5460 """ 5461 get_default_graph().add_to_collection(name, value) 5462 5463 5464 def add_to_collections(names, value): 5465 """Wrapper for `Graph.add_to_collections()` using the default graph. 5466 5467 See @{tf.Graph.add_to_collections} 5468 for more details. 5469 5470 Args: 5471 names: The key for the collections. The `GraphKeys` class 5472 contains many standard names for collections. 5473 value: The value to add to the collections. 
5474 5475 @compatibility(eager) 5476 Collections are not supported when eager execution is enabled. 5477 @end_compatibility 5478 """ 5479 get_default_graph().add_to_collections(names, value) 5480 5481 5482 @tf_export("get_collection_ref") 5483 def get_collection_ref(key): 5484 """Wrapper for `Graph.get_collection_ref()` using the default graph. 5485 5486 See @{tf.Graph.get_collection_ref} 5487 for more details. 5488 5489 Args: 5490 key: The key for the collection. For example, the `GraphKeys` class 5491 contains many standard names for collections. 5492 5493 Returns: 5494 The list of values in the collection with the given `name`, or an empty 5495 list if no value has been added to that collection. Note that this returns 5496 the collection list itself, which can be modified in place to change the 5497 collection. 5498 5499 @compatibility(eager) 5500 Collections are not supported when eager execution is enabled. 5501 @end_compatibility 5502 """ 5503 return get_default_graph().get_collection_ref(key) 5504 5505 5506 @tf_export("get_collection") 5507 def get_collection(key, scope=None): 5508 """Wrapper for `Graph.get_collection()` using the default graph. 5509 5510 See @{tf.Graph.get_collection} 5511 for more details. 5512 5513 Args: 5514 key: The key for the collection. For example, the `GraphKeys` class 5515 contains many standard names for collections. 5516 scope: (Optional.) If supplied, the resulting list is filtered to include 5517 only items whose `name` attribute matches using `re.match`. Items 5518 without a `name` attribute are never returned if a scope is supplied and 5519 the choice or `re.match` means that a `scope` without special tokens 5520 filters by prefix. 5521 5522 Returns: 5523 The list of values in the collection with the given `name`, or 5524 an empty list if no value has been added to that collection. The 5525 list contains the values in the order under which they were 5526 collected. 
5527 5528 @compatibility(eager) 5529 Collections are not supported when eager execution is enabled. 5530 @end_compatibility 5531 """ 5532 return get_default_graph().get_collection(key, scope) 5533 5534 5535 def get_all_collection_keys(): 5536 """Returns a list of collections used in the default graph.""" 5537 return get_default_graph().get_all_collection_keys() 5538 5539 5540 name_scope_cache = {} 5541 5542 5543 # Named like a function for backwards compatibility with the 5544 # @tf_contextlib.contextmanager version, which was switched to a class to avoid 5545 # some object creation overhead. 5546 @tf_export("name_scope", "keras.backend.name_scope") 5547 class name_scope(object): # pylint: disable=invalid-name 5548 """A context manager for use when defining a Python op. 5549 5550 This context manager validates that the given `values` are from the 5551 same graph, makes that graph the default graph, and pushes a 5552 name scope in that graph (see 5553 @{tf.Graph.name_scope} 5554 for more details on that). 5555 5556 For example, to define a new Python op called `my_op`: 5557 5558 ```python 5559 def my_op(a, b, c, name=None): 5560 with tf.name_scope(name, "MyOp", [a, b, c]) as scope: 5561 a = tf.convert_to_tensor(a, name="a") 5562 b = tf.convert_to_tensor(b, name="b") 5563 c = tf.convert_to_tensor(c, name="c") 5564 # Define some computation that uses `a`, `b`, and `c`. 5565 return foo_op(..., name=scope) 5566 ``` 5567 """ 5568 5569 @property 5570 def name(self): 5571 return self._name 5572 5573 def __init__(self, name, default_name=None, values=None): 5574 """Initialize the context manager. 5575 5576 Args: 5577 name: The name argument that is passed to the op function. 5578 default_name: The default name to use if the `name` argument is `None`. 5579 values: The list of `Tensor` arguments that are passed to the op function. 
5580 """ 5581 self._name = default_name if name is None else name 5582 self._default_name = default_name 5583 self._values = values 5584 self._ctx = context.context() 5585 self._in_eager_mode = self._ctx.in_eager_mode() 5586 5587 def __enter__(self): 5588 """Start the scope block. 5589 5590 Returns: 5591 The scope name. 5592 5593 Raises: 5594 ValueError: if neither `name` nor `default_name` is provided 5595 but `values` are. 5596 """ 5597 if self._in_eager_mode: 5598 self._old_name = self._ctx.scope_name 5599 if not self._name: 5600 scope_name = "" 5601 else: 5602 cache_key = self._name, self._old_name, self._default_name 5603 if cache_key in name_scope_cache: 5604 self._ctx.scope_name = name_scope_cache[cache_key] 5605 return self._ctx.scope_name 5606 elif self._name[-1] == "/": 5607 # A trailing slash breaks out of nested name scopes, indicating a 5608 # fully specified scope name, for compatibility with Graph.name_scope. 5609 scope_name = self._name 5610 else: 5611 name_with_trailing_slash = self._name + "/" 5612 scope_name = ( 5613 self._old_name + name_with_trailing_slash 5614 if self._old_name else name_with_trailing_slash) 5615 name_scope_cache[cache_key] = scope_name 5616 self._ctx.scope_name = scope_name 5617 return scope_name 5618 else: 5619 if self._name is None and self._values is not None: 5620 # We only raise an error if values is not None (provided) because 5621 # currently tf.name_scope(None) (values=None then) is sometimes used as 5622 # an idiom to reset to top scope. 5623 raise ValueError( 5624 "At least one of name (%s) and default_name (%s) must be provided." 
5625 % (self._name, self._default_name)) 5626 if self._values is None: 5627 self._values = [] 5628 g = _get_graph_from_inputs(self._values) 5629 self._g_manager = g.as_default() 5630 self._g_manager.__enter__() 5631 try: 5632 self._name_scope = g.name_scope(self._name) 5633 return self._name_scope.__enter__() 5634 except: 5635 self._g_manager.__exit__(*sys.exc_info()) 5636 raise 5637 5638 def __exit__(self, type_arg, value_arg, traceback_arg): 5639 if self._in_eager_mode: 5640 self._ctx.scope_name = self._old_name 5641 else: 5642 self._name_scope.__exit__(type_arg, value_arg, traceback_arg) 5643 self._g_manager.__exit__(type_arg, value_arg, traceback_arg) 5644 return False # False values do not suppress exceptions 5645 5646 5647 def strip_name_scope(name, export_scope): 5648 """Removes name scope from a name. 5649 5650 Args: 5651 name: A `string` name. 5652 export_scope: Optional `string`. Name scope to remove. 5653 5654 Returns: 5655 Name with name scope removed, or the original name if export_scope 5656 is None. 5657 """ 5658 if export_scope: 5659 try: 5660 # Strips export_scope/, export_scope///, 5661 # ^export_scope/, loc:@export_scope/. 5662 str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)" 5663 return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1) 5664 except TypeError as e: 5665 # If the name is not of a type we can process, simply return it. 5666 logging.warning(e) 5667 return name 5668 else: 5669 return name 5670 5671 5672 def prepend_name_scope(name, import_scope): 5673 """Prepends name scope to a name. 5674 5675 Args: 5676 name: A `string` name. 5677 import_scope: Optional `string`. Name scope to add. 5678 5679 Returns: 5680 Name with name scope added, or the original name if import_scope 5681 is None. 
5682 """ 5683 if import_scope: 5684 try: 5685 str_to_replace = r"([\^]|loc:@|^)(.*)" 5686 return re.sub(str_to_replace, r"\1" + import_scope + r"/\2", 5687 compat.as_str(name)) 5688 except TypeError as e: 5689 # If the name is not of a type we can process, simply return it. 5690 logging.warning(e) 5691 return name 5692 else: 5693 return name 5694 5695 5696 # pylint: disable=g-doc-return-or-yield 5697 # pylint: disable=not-context-manager 5698 @tf_export("op_scope") 5699 @tf_contextlib.contextmanager 5700 def op_scope(values, name, default_name=None): 5701 """DEPRECATED. Same as name_scope above, just different argument order.""" 5702 logging.warn("tf.op_scope(values, name, default_name) is deprecated," 5703 " use tf.name_scope(name, default_name, values)") 5704 with name_scope(name, default_name=default_name, values=values) as scope: 5705 yield scope 5706 5707 5708 _proto_function_registry = registry.Registry("proto functions") 5709 5710 5711 def register_proto_function(collection_name, 5712 proto_type=None, 5713 to_proto=None, 5714 from_proto=None): 5715 """Registers `to_proto` and `from_proto` functions for collection_name. 5716 5717 `to_proto` function converts a Python object to the corresponding protocol 5718 buffer, and returns the protocol buffer. 5719 5720 `from_proto` function converts protocol buffer into a Python object, and 5721 returns the object.. 5722 5723 Args: 5724 collection_name: Name of the collection. 5725 proto_type: Protobuf type, such as `saver_pb2.SaverDef`, 5726 `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`.. 5727 to_proto: Function that implements Python object to protobuf conversion. 5728 from_proto: Function that implements protobuf to Python object conversion. 
5729 """ 5730 if to_proto and not callable(to_proto): 5731 raise TypeError("to_proto must be callable.") 5732 if from_proto and not callable(from_proto): 5733 raise TypeError("from_proto must be callable.") 5734 5735 _proto_function_registry.register((proto_type, to_proto, from_proto), 5736 collection_name) 5737 5738 5739 def get_collection_proto_type(collection_name): 5740 """Returns the proto_type for collection_name.""" 5741 try: 5742 return _proto_function_registry.lookup(collection_name)[0] 5743 except LookupError: 5744 return None 5745 5746 5747 def get_to_proto_function(collection_name): 5748 """Returns the to_proto function for collection_name.""" 5749 try: 5750 return _proto_function_registry.lookup(collection_name)[1] 5751 except LookupError: 5752 return None 5753 5754 5755 def get_from_proto_function(collection_name): 5756 """Returns the from_proto function for collection_name.""" 5757 try: 5758 return _proto_function_registry.lookup(collection_name)[2] 5759 except LookupError: 5760 return None 5761 5762 5763 def _assert_collection_is_ok(collection_name): 5764 if context.in_eager_mode(): 5765 if collection_name in GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access 5766 raise ValueError("When Eager Execution is enabled, variable " 5767 "collections are not supported.") 5768 5769 5770 def _operation_conversion_error(op, dtype=None, name=None, as_ref=False): 5771 """Produce a nice error if someone converts an Operation to a Tensor.""" 5772 raise TypeError(("Can't convert Operation '%s' to Tensor " 5773 "(target dtype=%r, name=%r, as_ref=%r)") % (op.name, dtype, 5774 name, as_ref)) 5775 5776 5777 register_tensor_conversion_function(Operation, _operation_conversion_error) 5778