Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Classes and functions used to construct graphs."""
     16 # pylint: disable=g-bad-name
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import copy
     23 import linecache
     24 import os
     25 import re
     26 import sys
     27 import threading
     28 
     29 import numpy as np
     30 import six
     31 from six.moves import xrange  # pylint: disable=redefined-builtin
     32 
     33 from tensorflow.core.framework import attr_value_pb2
     34 from tensorflow.core.framework import function_pb2
     35 from tensorflow.core.framework import graph_pb2
     36 from tensorflow.core.framework import node_def_pb2
     37 from tensorflow.core.framework import op_def_pb2
     38 from tensorflow.core.framework import versions_pb2
     39 from tensorflow.core.protobuf import config_pb2
     40 from tensorflow.python import pywrap_tensorflow as c_api
     41 from tensorflow.python.eager import context
     42 from tensorflow.python.eager import core
     43 from tensorflow.python.eager import tape
     44 from tensorflow.python.framework import c_api_util
     45 from tensorflow.python.framework import device as pydev
     46 from tensorflow.python.framework import dtypes
     47 from tensorflow.python.framework import errors
     48 from tensorflow.python.framework import op_def_registry
     49 from tensorflow.python.framework import registry
     50 from tensorflow.python.framework import tensor_shape
     51 from tensorflow.python.framework import versions
     52 from tensorflow.python.ops import control_flow_util
     53 from tensorflow.python.platform import app
     54 from tensorflow.python.platform import tf_logging as logging
     55 from tensorflow.python.util import compat
     56 from tensorflow.python.util import decorator_utils
     57 from tensorflow.python.util import tf_contextlib
     58 from tensorflow.python.util.tf_export import tf_export
     59 
     60 
     61 # Temporary global switch determining if we should enable the work-in-progress
     62 # calls to the C API. Currently disabled by default but can be manually enabled
     63 # in code or via the environment variable. This will be removed once all
     64 # functionality is supported and there's no performance penalty with it enabled.
     65 _USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "0") is not "0"
     66 
     67 
     68 def tensor_id(tensor):
     69   """Returns a unique identifier for this Tensor."""
     70   return tensor._id  # pylint: disable=protected-access
     71 
     72 
     73 class _NullContextmanager(object):
     74 
     75   def __enter__(self):
     76     pass
     77 
     78   def __exit__(self, type_arg, value_arg, traceback_arg):
     79     return False  # False values do not suppress exceptions
     80 
     81 
     82 def _override_helper(clazz_object, operator, func):
     83   """Overrides (string) operator on Tensors to call func.
     84 
     85   Args:
     86     clazz_object: the class to override for; either Tensor or SparseTensor.
     87     operator: the string name of the operator to override.
     88     func: the function that replaces the overridden operator.
     89 
     90   Raises:
     91     ValueError: If operator has already been overwritten,
     92       or if operator is not allowed to be overwritten.
     93   """
     94   existing = getattr(clazz_object, operator, None)
     95   if existing is not None:
     96     # Check to see if this is a default method-wrapper or slot wrapper which
     97     # will be true for the comparison operators.
     98     if not isinstance(existing, type(object.__lt__)):
     99       raise ValueError("operator %s cannot be overwritten again on class %s." %
    100                        (operator, clazz_object))
    101   if operator not in Tensor.OVERLOADABLE_OPERATORS:
    102     raise ValueError("Overriding %s is disallowed" % operator)
    103   setattr(clazz_object, operator, func)
    104 
    105 
    106 def _as_graph_element(obj):
    107   """Convert `obj` to a graph element if possible, otherwise return `None`.
    108 
    109   Args:
    110     obj: Object to convert.
    111 
    112   Returns:
    113     The result of `obj._as_graph_element()` if that method is available;
    114         otherwise `None`.
    115   """
    116   conv_fn = getattr(obj, "_as_graph_element", None)
    117   if conv_fn and callable(conv_fn):
    118     return conv_fn()
    119   return None
    120 
    121 
    122 _TENSOR_LIKE_TYPES = tuple()
    123 
    124 
    125 def is_dense_tensor_like(t):
    126   """EXPERIMENTAL: Returns true if `t` implements the tensor interface.
    127 
    128   See `register_dense_tensor_like_type()` for the current definition of a
    129   "tensor-like type".
    130 
    131   Args:
    132     t: An object.
    133 
    134   Returns:
    135     True iff `t` is an instance of one of the registered "tensor-like" types.
    136   """
    137   return isinstance(t, _TENSOR_LIKE_TYPES)
    138 
    139 
    140 def register_dense_tensor_like_type(tensor_type):
    141   """EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.
    142 
    143   A "tensor-like type" can represent a single dense tensor, and implements
    144   the `name` and `dtype` properties.
    145 
    146   Args:
    147     tensor_type: A type implementing the tensor interface.
    148 
    149   Raises:
    150     TypeError: If `tensor_type` does not implement the tensor interface.
    151   """
    152   try:
    153     if not isinstance(tensor_type.name, property):
    154       raise TypeError("Type %s does not define a `name` property" %
    155                       tensor_type.__name__)
    156   except AttributeError:
    157     raise TypeError("Type %s does not define a `name` property" %
    158                     tensor_type.__name__)
    159   try:
    160     if not isinstance(tensor_type.dtype, property):
    161       raise TypeError("Type %s does not define a `dtype` property" %
    162                       tensor_type.__name__)
    163   except AttributeError:
    164     raise TypeError("Type %s does not define a `dtype` property" %
    165                     tensor_type.__name__)
    166   # We expect this list to be small, so choose quadratic complexity
    167   # for registration, so that we have a tuple that can be used for
    168   # more efficient `isinstance` checks later.
    169   global _TENSOR_LIKE_TYPES
    170   _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
    171 
    172 
    173 def uid():
    174   """A unique (within this program execution) integer."""
    175   return c_api.TFE_Py_UID()
    176 
    177 
    178 def numpy_text(tensor, is_repr=False):
    179   """Human readable representation of a tensor's numpy value."""
    180   if tensor.dtype.is_numpy_compatible:
    181     text = repr(tensor.numpy()) if is_repr else str(tensor.numpy())
    182   else:
    183     text = "<unprintable>"
    184   if "\n" in text:
    185     text = "\n" + text
    186   return text
    187 
    188 
    189 # NOTE(ebrevdo): Do not subclass this.  If you do, I will break you on purpose.
    190 class _TensorLike(object):
    191   """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
    192   pass
    193 
    194 
    195 @tf_export("Tensor")
    196 class Tensor(_TensorLike):
    197   """Represents one of the outputs of an `Operation`.
    198 
    199   A `Tensor` is a symbolic handle to one of the outputs of an
    200   `Operation`. It does not hold the values of that operation's output,
    201   but instead provides a means of computing those values in a
    202   TensorFlow @{tf.Session}.
    203 
    204   This class has two primary purposes:
    205 
    206   1. A `Tensor` can be passed as an input to another `Operation`.
    207      This builds a dataflow connection between operations, which
    208      enables TensorFlow to execute an entire `Graph` that represents a
    209      large, multi-step computation.
    210 
    211   2. After the graph has been launched in a session, the value of the
    212      `Tensor` can be computed by passing it to
    213      @{tf.Session.run}.
    214      `t.eval()` is a shortcut for calling
    215      `tf.get_default_session().run(t)`.
    216 
    217   In the following example, `c`, `d`, and `e` are symbolic `Tensor`
    218   objects, whereas `result` is a numpy array that stores a concrete
    219   value:
    220 
    221   ```python
    222   # Build a dataflow graph.
    223   c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    224   d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
    225   e = tf.matmul(c, d)
    226 
    227   # Construct a `Session` to execute the graph.
    228   sess = tf.Session()
    229 
    230   # Execute the graph and store the value that `e` represents in `result`.
    231   result = sess.run(e)
    232   ```
    233   """
    234 
    235   # List of Python operators that we allow to override.
    236   OVERLOADABLE_OPERATORS = {
    237       # Binary.
    238       "__add__",
    239       "__radd__",
    240       "__sub__",
    241       "__rsub__",
    242       "__mul__",
    243       "__rmul__",
    244       "__div__",
    245       "__rdiv__",
    246       "__truediv__",
    247       "__rtruediv__",
    248       "__floordiv__",
    249       "__rfloordiv__",
    250       "__mod__",
    251       "__rmod__",
    252       "__lt__",
    253       "__le__",
    254       "__gt__",
    255       "__ge__",
    256       "__and__",
    257       "__rand__",
    258       "__or__",
    259       "__ror__",
    260       "__xor__",
    261       "__rxor__",
    262       "__getitem__",
    263       "__pow__",
    264       "__rpow__",
    265       # Unary.
    266       "__invert__",
    267       "__neg__",
    268       "__abs__",
    269       "__matmul__",
    270       "__rmatmul__"
    271   }
    272 
    273   def __init__(self, op, value_index, dtype):
    274     """Creates a new `Tensor`.
    275 
    276     Args:
    277       op: An `Operation`. `Operation` that computes this tensor.
    278       value_index: An `int`. Index of the operation's endpoint that produces
    279         this tensor.
    280       dtype: A `DType`. Type of elements stored in this tensor.
    281 
    282     Raises:
    283       TypeError: If the op is not an `Operation`.
    284     """
    285     if not isinstance(op, Operation):
    286       raise TypeError("op needs to be an Operation: %s" % op)
    287     self._op = op
    288     self._value_index = value_index
    289     self._dtype = dtypes.as_dtype(dtype)
    290     self._shape_val = tensor_shape.unknown_shape()
    291     # List of operations that use this Tensor as input.  We maintain this list
    292     # to easily navigate a computation graph.
    293     self._consumers = []
    294 
    295     # Attributes used for C++ shape inference. Not inspected, only forwarded.
    296     # If set, will be a HandleData object from cpp_shape_inference.proto.
    297     self._handle_data = None
    298     self._id = uid()
    299 
    300   @property
    301   def op(self):
    302     """The `Operation` that produces this tensor as an output."""
    303     return self._op
    304 
    305   @property
    306   def dtype(self):
    307     """The `DType` of elements in this tensor."""
    308     return self._dtype
    309 
    310   @property
    311   def graph(self):
    312     """The `Graph` that contains this tensor."""
    313     return self._op.graph
    314 
    315   @property
    316   def name(self):
    317     """The string name of this tensor."""
    318     if not self._op.name:
    319       raise ValueError("Operation was not named: %s" % self._op)
    320     return "%s:%d" % (self._op.name, self._value_index)
    321 
    322   @property
    323   def device(self):
    324     """The name of the device on which this tensor will be produced, or None."""
    325     return self._op.device
    326 
    327   @property
    328   def shape(self):
    329     """Returns the `TensorShape` that represents the shape of this tensor.
    330 
    331     The shape is computed using shape inference functions that are
    332     registered in the Op for each `Operation`.  See
    333     @{tf.TensorShape}
    334     for more details of what a shape represents.
    335 
    336     The inferred shape of a tensor is used to provide shape
    337     information without having to launch the graph in a session. This
    338     can be used for debugging, and providing early error messages. For
    339     example:
    340 
    341     ```python
    342     c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    343 
    344     print(c.shape)
    345     ==> TensorShape([Dimension(2), Dimension(3)])
    346 
    347     d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
    348 
    349     print(d.shape)
    350     ==> TensorShape([Dimension(4), Dimension(2)])
    351 
    352     # Raises a ValueError, because `c` and `d` do not have compatible
    353     # inner dimensions.
    354     e = tf.matmul(c, d)
    355 
    356     f = tf.matmul(c, d, transpose_a=True, transpose_b=True)
    357 
    358     print(f.shape)
    359     ==> TensorShape([Dimension(3), Dimension(4)])
    360     ```
    361 
    362     In some cases, the inferred shape may have unknown dimensions. If
    363     the caller has additional information about the values of these
    364     dimensions, `Tensor.set_shape()` can be used to augment the
    365     inferred shape.
    366 
    367     Returns:
    368       A `TensorShape` representing the shape of this tensor.
    369 
    370     """
    371     if _USE_C_API:
    372       graph = self._op._graph._c_graph  # pylint: disable=protected-access
    373       with errors.raise_exception_on_not_ok_status() as status:
    374         num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(),
    375                                                   status)
    376       if num_dims == -1:
    377         dim_list = None
    378       else:
    379         with errors.raise_exception_on_not_ok_status() as status:
    380           dim_list = c_api.TF_GraphGetTensorShape_wrapper(
    381               graph, self._as_tf_output(), num_dims, status)
    382         dim_list = [None if i == -1 else i for i in dim_list]
    383       return tensor_shape.TensorShape(dim_list)
    384     return self._shape_val
    385 
    386   @property
    387   def _shape(self):
    388     logging.warning("Tensor._shape is private, use Tensor.shape "
    389                     "instead. Tensor._shape will eventually be removed.")
    390     return self.shape
    391 
    392   @_shape.setter
    393   def _shape(self, value):
    394     raise ValueError(
    395         "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
    396 
    397   def __iter__(self):
    398     if context.in_graph_mode():
    399       raise TypeError(
    400           "`Tensor` objects are not iterable when eager execution is not "
    401           "enabled. To iterate over this tensor use `tf.map_fn`.")
    402     shape = self._shape_tuple()
    403     if shape is None:
    404       raise TypeError("Cannot iterate over a tensor with unknown shape.")
    405     if not shape:
    406       raise TypeError("Cannot iterate over a scalar tensor.")
    407     if shape[0] is None:
    408       raise TypeError(
    409           "Cannot iterate over a tensor with unknown first dimension.")
    410     for i in xrange(shape[0]):
    411       yield self[i]
    412 
    413   def _shape_as_list(self):
    414     if self.shape.ndims is not None:
    415       return [dim.value for dim in self.shape.dims]
    416     else:
    417       return None
    418 
    419   def _shape_tuple(self):
    420     shape = self._shape_as_list()
    421     if shape is None:
    422       return None
    423     return tuple(shape)
    424 
    425   def _rank(self):
    426     """Integer rank of this Tensor, if known, else None.
    427 
    428     Returns:
    429       Integer rank or None
    430     """
    431     return self.shape.ndims
    432 
    433   def get_shape(self):
    434     """Alias of Tensor.shape."""
    435     return self.shape
    436 
    437   def set_shape(self, shape):
    438     """Updates the shape of this tensor.
    439 
    440     This method can be called multiple times, and will merge the given
    441     `shape` with the current shape of this tensor. It can be used to
    442     provide additional information about the shape of this tensor that
    443     cannot be inferred from the graph alone. For example, this can be used
    444     to provide additional information about the shapes of images:
    445 
    446     ```python
    447     _, image_data = tf.TFRecordReader(...).read(...)
    448     image = tf.image.decode_png(image_data, channels=3)
    449 
    450     # The height and width dimensions of `image` are data dependent, and
    451     # cannot be computed without executing the op.
    452     print(image.shape)
    453     ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])
    454 
    455     # We know that each image in this dataset is 28 x 28 pixels.
    456     image.set_shape([28, 28, 3])
    457     print(image.shape)
    458     ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
    459     ```
    460 
    461     Args:
    462       shape: A `TensorShape` representing the shape of this tensor, a
    463       `TensorShapeProto`, a list, a tuple, or None.
    464 
    465     Raises:
    466       ValueError: If `shape` is not compatible with the current shape of
    467         this tensor.
    468     """
    469     if not _USE_C_API:
    470       self._shape_val = self._shape_val.merge_with(shape)
    471       return
    472     if not isinstance(shape, tensor_shape.TensorShape):
    473       shape = tensor_shape.TensorShape(shape)
    474     dim_list = []
    475     if shape.dims is None:
    476       unknown_shape = True
    477     else:
    478       unknown_shape = False
    479       for dim in shape.dims:
    480         if dim.value is None:
    481           dim_list.append(-1)
    482         else:
    483           dim_list.append(dim.value)
    484     try:
    485       with errors.raise_exception_on_not_ok_status() as status:
    486         c_api.TF_GraphSetTensorShape_wrapper(
    487             self._op._graph._c_graph,  # pylint: disable=protected-access
    488             self._as_tf_output(),
    489             dim_list,
    490             unknown_shape,
    491             status)
    492     except errors.InvalidArgumentError as e:
    493       # Convert to ValueError for backwards compatibility.
    494       raise ValueError(str(e))
    495 
    496   @property
    497   def value_index(self):
    498     """The index of this tensor in the outputs of its `Operation`."""
    499     return self._value_index
    500 
    501   def consumers(self):
    502     """Returns a list of `Operation`s that consume this tensor.
    503 
    504     Returns:
    505       A list of `Operation`s.
    506     """
    507     if self._op._c_op:  # pylint: disable=protected-access
    508       consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
    509           self._as_tf_output())
    510       # pylint: disable=protected-access
    511       return [
    512           self.graph._get_operation_by_name_unsafe(name)
    513           for name in consumer_names
    514       ]
    515       # pylint: enable=protected-access
    516     else:
    517       return self._consumers
    518 
    519   def _add_consumer(self, consumer):
    520     """Add a consumer to this tensor.
    521 
    522     Args:
    523       consumer: an Operation.
    524 
    525     Raises:
    526       TypeError: if the consumer is not an Operation.
    527     """
    528     # pylint: disable=protected-access
    529     assert not self._op._c_op, "Tensor._add_consumer doesn't work with C API"
    530     # pylint: enable=protected-access
    531     if not isinstance(consumer, Operation):
    532       raise TypeError("Consumer must be an Operation: %s" % consumer)
    533     self._consumers.append(consumer)
    534 
    535   def _as_node_def_input(self):
    536     """Return a value to use for the NodeDef "input" attribute.
    537 
    538     The returned string can be used in a NodeDef "input" attribute
    539     to indicate that the NodeDef uses this Tensor as input.
    540 
    541     Raises:
    542       ValueError: if this Tensor's Operation does not have a name.
    543 
    544     Returns:
    545       a string.
    546     """
    547     if not self._op.name:
    548       raise ValueError("Operation was not named: %s" % self._op)
    549     if self._value_index == 0:
    550       return self._op.name
    551     else:
    552       return "%s:%d" % (self._op.name, self._value_index)
    553 
    554   def _as_tf_output(self):
    555     # pylint: disable=protected-access
    556     assert self.op._c_op
    557     return c_api_util.tf_output(self.op._c_op, self.value_index)
    558     # pylint: enable=protected-access
    559 
    560   def __str__(self):
    561     return "Tensor(\"%s\"%s%s%s)" % (
    562         self.name, (", shape=%s" % self.get_shape())
    563         if self.get_shape().ndims is not None else "",
    564         (", dtype=%s" % self._dtype.name)
    565         if self._dtype else "", (", device=%s" % self.device)
    566         if self.device else "")
    567 
    568   def __repr__(self):
    569     return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
    570                                                    self._dtype.name)
    571 
    572   def __hash__(self):
    573     # Necessary to support Python's collection membership operators
    574     return id(self)
    575 
    576   def __eq__(self, other):
    577     # Necessary to support Python's collection membership operators
    578     return id(self) == id(other)
    579 
    580   # NOTE(mrry): This enables the Tensor's overloaded "right" binary
    581   # operators to run when the left operand is an ndarray, because it
    582   # accords the Tensor class higher priority than an ndarray, or a
    583   # numpy matrix.
    584   # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
    585   # mechanism, which allows more control over how Tensors interact
    586   # with ndarrays.
    587   __array_priority__ = 100
    588 
    589   @staticmethod
    590   def _override_operator(operator, func):
    591     _override_helper(Tensor, operator, func)
    592 
    593   def __bool__(self):
    594     """Dummy method to prevent a tensor from being used as a Python `bool`.
    595 
    596     This overload raises a `TypeError` when the user inadvertently
    597     treats a `Tensor` as a boolean (e.g. in an `if` statement). For
    598     example:
    599 
    600     ```python
    601     if tf.constant(True):  # Will raise.
    602       # ...
    603 
    604     if tf.constant(5) < tf.constant(7):  # Will raise.
    605       # ...
    606     ```
    607 
    608     This disallows ambiguities between testing the Python value vs testing the
    609     dynamic condition of the `Tensor`.
    610 
    611     Raises:
    612       `TypeError`.
    613     """
    614     raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
    615                     "Use `if t is not None:` instead of `if t:` to test if a "
    616                     "tensor is defined, and use TensorFlow ops such as "
    617                     "tf.cond to execute subgraphs conditioned on the value of "
    618                     "a tensor.")
    619 
    620   def __nonzero__(self):
    621     """Dummy method to prevent a tensor from being used as a Python `bool`.
    622 
    623     This is the Python 2.x counterpart to `__bool__()` above.
    624 
    625     Raises:
    626       `TypeError`.
    627     """
    628     raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
    629                     "Use `if t is not None:` instead of `if t:` to test if a "
    630                     "tensor is defined, and use TensorFlow ops such as "
    631                     "tf.cond to execute subgraphs conditioned on the value of "
    632                     "a tensor.")
    633 
    634   def eval(self, feed_dict=None, session=None):
    635     """Evaluates this tensor in a `Session`.
    636 
    637     Calling this method will execute all preceding operations that
    638     produce the inputs needed for the operation that produces this
    639     tensor.
    640 
    641     *N.B.* Before invoking `Tensor.eval()`, its graph must have been
    642     launched in a session, and either a default session must be
    643     available, or `session` must be specified explicitly.
    644 
    645     Args:
    646       feed_dict: A dictionary that maps `Tensor` objects to feed values.
    647         See @{tf.Session.run} for a
    648         description of the valid feed values.
    649       session: (Optional.) The `Session` to be used to evaluate this tensor. If
    650         none, the default session will be used.
    651 
    652     Returns:
    653       A numpy array corresponding to the value of this tensor.
    654 
    655     """
    656     return _eval_using_default_session(self, feed_dict, self.graph, session)
    657 
    658 
    659 # TODO(agarwal): consider getting rid of this.
    660 class _EagerTensorBase(Tensor):
    661   """Base class for EagerTensor."""
    662 
    663   @property
    664   def dtype(self):
    665     # Note: using the intern table directly here as this is
    666     # performance-sensitive in some models.
    667     return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access
    668 
    669   def numpy(self):
    670     """Returns a numpy array or a scalar with the same contents as the Tensor.
    671 
    672     TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
    673     buffer but instead always explicitly copy? Note that currently it may or may
    674     not copy based on whether the numpy data is properly aligned or not.
    675 
    676     Returns:
    677       A numpy array or a scalar. Numpy array may share memory with the
    678       Tensor object. Any changes to one may be reflected in the other. A scalar
    679       value is returned when self has rank 0.
    680 
    681     Raises:
    682       ValueError: if the type of this Tensor is not representable in numpy.
    683     """
    684     if self.dtype == dtypes.resource:
    685       raise ValueError("Resource handles are not convertible to numpy.")
    686     return self.cpu()._numpy()  # pylint: disable=protected-access
    687 
    688   # __int__ and  __float__ may copy the tensor to CPU and
    689   # only work for scalars; values are cast as per numpy.
    690   def __int__(self):
    691     return int(self.numpy())
    692 
    693   def __float__(self):
    694     return float(self.numpy())
    695 
    696   def __array__(self, dtype=None):
    697     return np.array(self.numpy(), dtype=dtype)
    698 
    699   def __format__(self, format_spec):
    700     return self.numpy().__format__(format_spec)
    701 
    702   def _numpy(self):
    703     raise NotImplementedError()
    704 
    705   def __copy__(self):
    706     # Eager Tensors are immutable so it's safe to return themselves as a copy.
    707     return self
    708 
    709   def __deepcopy__(self, memo):
    710     # Eager Tensors are immutable so it's safe to return themselves as a copy.
    711     del memo
    712     return self
    713 
    714   def _datatype_enum(self):
    715     raise NotImplementedError()
    716 
    717   def _shape_tuple(self):
    718     """The shape of this Tensor, as a tuple.
    719 
    720     This is more performant than tuple(shape().as_list()) as it avoids
    721     two list and one object creation. Marked private for now as from an API
    722     perspective, it would be better to have a single performant way of
    723     getting a shape rather than exposing shape() and shape_tuple()
    724     (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
    725     but ideally one would work things out and remove the need for this method.
    726 
    727     Returns:
    728       tuple with the shape.
    729     """
    730     raise NotImplementedError()
    731 
    732   def _rank(self):
    733     """Integer rank of this Tensor.
    734 
    735     Unlike regular Tensors, the rank is always known for EagerTensors.
    736 
    737     This is more performant than len(self._shape_tuple())
    738 
    739     Returns:
    740       Integer rank
    741     """
    742     raise NotImplementedError()
    743 
    744   def _copy_to_device(self, context, device):  # pylint: disable=redefined-outer-name
    745     raise NotImplementedError()
    746 
    747   def __str__(self):
    748     return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self),
    749                                                   self.shape,
    750                                                   self.dtype.name)
    751 
    752   def __repr__(self):
    753     return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
    754         self._id, self.shape, self.dtype.name, numpy_text(self, is_repr=True))
    755 
    756   @staticmethod
    757   def _override_operator(name, func):
    758     setattr(_EagerTensorBase, name, func)
    759 
    760   def _copy(self, ctx=None, device_name=None):
    761     """Copies tensor to dest device."""
    762     # pylint: disable=protected-access
    763     # Creates a new tensor on the dest device.
    764     if ctx is None:
    765       ctx = context.context()
    766     if device_name is None:
    767       device_name = ctx.device_name
    768     # pylint: disable=protected-access
    769     try:
    770       new_tensor = self._copy_to_device(context=ctx._handle, device=device_name)
    771     except core._NotOkStatusException as e:
    772       six.raise_from(core._status_to_exception(e.code, e.message), None)
    773 
    774     # Record the copy on tape and define backprop copy as well.
    775     if not context.in_graph_mode():
    776       self_device = self.device
    777       def grad_fun(dresult):
    778         return [dresult._copy(device_name=self_device)]
    779       tape.record_operation("_copy", [new_tensor], [self], grad_fun)
    780     return new_tensor
    781     # pylint: enable=protected-access
    782 
    783   @property
    784   def shape(self):
    785     return tensor_shape.TensorShape(self._shape_tuple())
    786 
    787   def get_shape(self):
    788     """Alias of Tensor.shape."""
    789     return self.shape
    790 
    791   def _shape_as_list(self):
    792     """The shape of the tensor as a list."""
    793     return list(self._shape_tuple())
    794 
    795   @property
    796   def ndim(self):
    797     """Returns the number of Tensor dimensions."""
    798     return self.shape.ndims
    799 
    800   def cpu(self):
    801     """A copy of this Tensor with contents backed by host memory."""
    802     return self._copy(context.context(), "CPU:0")
    803 
    804   def gpu(self, gpu_index=0):
    805     """A copy of this Tensor with contents backed by memory on the GPU.
    806 
    807     Arguments:
    808       gpu_index: Identifies which GPU to place the contents on the returned
    809         Tensor in.
    810 
    811     Returns:
    812       A GPU-memory backed Tensor object initialized with the same contents
    813       as this Tensor.
    814     """
    815     return self._copy(context.context(), "GPU:" + str(gpu_index))
    816 
    817   def __bool__(self):
    818     if self._shape_tuple() != ():  # pylint: disable=g-explicit-bool-comparison
    819       raise ValueError(
    820           "Non-scalar tensor %s cannot be converted to boolean." % repr(self))
    821     if self.dtype != dtypes.bool:
    822       raise ValueError(
    823           "Non-boolean tensor %s cannot be converted to boolean." % repr(self))
    824     return bool(self.cpu().numpy())
    825 
    826   def __nonzero__(self):
    827     return self.__bool__()
    828 
    829   def set_shape(self, shape):
    830     if not self.shape.is_compatible_with(shape):
    831       raise ValueError(
    832           "EagerTensor's shape %s is not compatible with supplied shape %s" %
    833           (self.shape, shape))
    834 
    835   # Methods not supported / implemented for Eager Tensors.
    836   @property
    837   def op(self):
    838     raise AttributeError("op not supported for Eager Tensors.")
    839 
    840   @property
    841   def graph(self):
    842     raise AttributeError("graph not supported for Eager Tensors.")
    843 
    844   @property
    845   def name(self):
    846     raise AttributeError("name not supported for Eager Tensors.")
    847 
    848   @property
    849   def value_index(self):
    850     raise AttributeError("value_index not supported for Eager Tensors.")
    851 
    852   def consumers(self):
    853     raise NotImplementedError("consumers not supported for Eager Tensors.")
    854 
    855   def _add_consumer(self, consumer):
    856     raise NotImplementedError("_add_consumer not supported for Eager Tensors.")
    857 
    858   def _as_node_def_input(self):
    859     raise NotImplementedError(
    860         "_as_node_def_input not supported for Eager Tensors.")
    861 
    862   def _as_tf_output(self):
    863     raise NotImplementedError("_as_tf_output not supported for Eager Tensors.")
    864 
    865   def eval(self, feed_dict=None, session=None):
    866     raise NotImplementedError("eval not supported for Eager Tensors.")
    867 
    868 
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module.
# The class is produced by the C extension (`c_api`, i.e. pywrap_tensorflow).
# NOTE(review): base-class methods above that only raise NotImplementedError
# (e.g. `_copy_to_device`) are presumably implemented by this C type -- confirm.
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
    872 
    873 
    874 def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
    875   _ = name, as_ref
    876   if dtype and not dtype.is_compatible_with(t.dtype):
    877     raise ValueError(
    878         "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
    879         (dtype.name, t.dtype.name, str(t)))
    880   return t
    881 
    882 
# Maps priority -> list of (base_type, conversion_func) pairs in registration
# order. internal_convert_to_tensor walks priorities in ascending order, so
# `Tensor` inputs (priority 0) are handled before anything registered at the
# default priority of 100.
_tensor_conversion_func_registry = {
    0: [(Tensor, _TensorTensorConversionFunction)]
}
# Per-exact-type memo of the applicable (base_type, conversion_func) list.
# It is filled lazily by internal_convert_to_tensor and replaced with a fresh
# dict by register_tensor_conversion_function.
_tensor_conversion_func_cache = {}
# Held while mutating the registry/cache and while building a cache entry.
_tensor_conversion_func_lock = threading.Lock()
# NOTE(review): helper defined elsewhere in this file; presumably marks Tensor
# as a dense-tensor-like type -- confirm.
register_dense_tensor_like_type(Tensor)
    889 
    890 
    891 @tf_export("convert_to_tensor")
    892 def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None):
    893   """Converts the given `value` to a `Tensor`.
    894 
    895   This function converts Python objects of various types to `Tensor`
    896   objects. It accepts `Tensor` objects, numpy arrays, Python lists,
    897   and Python scalars. For example:
    898 
    899   ```python
    900   import numpy as np
    901 
    902   def my_func(arg):
    903     arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    904     return tf.matmul(arg, arg) + arg
    905 
    906   # The following calls are equivalent.
    907   value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
    908   value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
    909   value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
    910   ```
    911 
    912   This function can be useful when composing a new operation in Python
    913   (such as `my_func` in the example above). All standard Python op
    914   constructors apply this function to each of their Tensor-valued
    915   inputs, which allows those ops to accept numpy arrays, Python lists,
    916   and scalars in addition to `Tensor` objects.
    917 
    918   Note: This function diverges from default Numpy behavior for `float` and
    919     `string` types when `None` is present in a Python list or scalar. Rather
    920     than silently converting `None` values, an error will be thrown.
    921 
    922   Args:
    923     value: An object whose type has a registered `Tensor` conversion function.
    924     dtype: Optional element type for the returned tensor. If missing, the
    925       type is inferred from the type of `value`.
    926     name: Optional name to use if a new `Tensor` is created.
    927     preferred_dtype: Optional element type for the returned tensor,
    928       used when dtype is None. In some cases, a caller may not have a
    929       dtype in mind when converting to a tensor, so preferred_dtype
    930       can be used as a soft preference.  If the conversion to
    931       `preferred_dtype` is not possible, this argument has no effect.
    932 
    933   Returns:
    934     An `Output` based on `value`.
    935 
    936   Raises:
    937     TypeError: If no conversion function is registered for `value`.
    938     RuntimeError: If a registered conversion function returns an invalid value.
    939 
    940   """
    941   return internal_convert_to_tensor(
    942       value=value,
    943       dtype=dtype,
    944       name=name,
    945       preferred_dtype=preferred_dtype,
    946       as_ref=False)
    947 
    948 
    949 def _error_prefix(name):
    950   return "" if name is None else "%s: " % name
    951 
    952 
def internal_convert_to_tensor(value,
                               dtype=None,
                               name=None,
                               as_ref=False,
                               preferred_dtype=None,
                               ctx=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars.

  This function can be useful when composing a new operation in Python.
  All standard Python op constructors apply this function to each of their
  Tensor-valued inputs, which allows those ops to accept numpy arrays, Python
  lists, and scalars in addition to `Tensor` objects.

  Registered conversion functions whose base type matches `value` are tried
  in ascending priority order (registration order within a priority). The
  first one that does not return `NotImplemented` determines the result. In
  eager mode an `EagerTensor` is returned as-is, without consulting the
  registry or checking `dtype`.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.
    as_ref: True if we want the mutable view of Variables, if applicable.
    preferred_dtype: Optional element type for the returned tensor,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so preferred_dtype
      can be used as a soft preference.  If the conversion to
      `preferred_dtype` is not possible, this argument has no effect.
    ctx: Optional: The value of context.context().

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value`, or if a
      conversion with `preferred_dtype` succeeds but yields a tensor of a
      different base dtype.
    RuntimeError: If a registered conversion function returns an invalid value.

  """
  if ctx is None: ctx = context.context()
  if ctx.in_eager_mode():
    # Fast path for EagerTensors that don't need any conversion.
    if isinstance(value, EagerTensor):
      # Note that we don't check that value's dtype matches the dtype
      # argument.  We expect that the C runtime will do that checking
      # when we execute the kernel.
      return value

  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  unwrapped_type = type(value)
  # The memo is keyed on the exact type of `value`; registering a new
  # conversion function replaces the whole cache dict.
  conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
  if conversion_func_list is None:
    with _tensor_conversion_func_lock:
      # Collect every registered function whose base_type matches `value`:
      # sorted() orders by the priority key (lowest first), and each
      # priority's list keeps registration order.
      conversion_func_list = []
      for _, funcs_at_priority in sorted(
          _tensor_conversion_func_registry.items()):
        for base_type, conversion_func in funcs_at_priority:
          if isinstance(value, base_type):
            conversion_func_list.append((base_type, conversion_func))
      _tensor_conversion_func_cache[unwrapped_type] = conversion_func_list

  for base_type, conversion_func in conversion_func_list:
    # If dtype is None but preferred_dtype is not None, we try to
    # cast to preferred_dtype first.
    ret = None
    if dtype is None and preferred_dtype is not None:
      try:
        ret = conversion_func(
            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
      except (TypeError, ValueError, errors.UnimplementedError,
              errors.InvalidArgumentError):
        # Could not coerce the conversion to use the preferred dtype.
        ret = None

      # A function that accepted preferred_dtype must honor it exactly.
      if ret is not None and ret is not NotImplemented:
        if (ret.dtype.base_dtype !=
            dtypes.as_dtype(preferred_dtype).base_dtype):
          raise TypeError("convert_to_tensor did not convert to "
                          "the preferred dtype: %s vs %s " %
                          (ret.dtype.base_dtype,
                           dtypes.as_dtype(preferred_dtype).base_dtype))

    # Fall back to the caller's dtype (possibly None). NOTE(review): if the
    # preferred-dtype attempt returned NotImplemented, this fallback is
    # skipped and the next conversion function is tried instead.
    if ret is None:
      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

    if ret is NotImplemented:
      continue

    if not isinstance(ret, Tensor):
      raise RuntimeError(
          "%sConversion function %r for type %s returned non-Tensor: %r" %
          (_error_prefix(name), conversion_func, base_type, ret))
    if dtype and not dtype.is_compatible_with(ret.dtype):
      raise RuntimeError(
          "%sConversion function %r for type %s returned incompatible "
          "dtype: requested = %s, actual = %s" %
          (_error_prefix(name), conversion_func, base_type, dtype.name,
           ret.dtype.name))
    return ret
  raise TypeError("%sCannot convert %r with type %s to Tensor: "
                  "no conversion function registered." %
                  (_error_prefix(name), value, unwrapped_type))
   1055 
   1056 
   1057 def internal_convert_n_to_tensor(values,
   1058                                  dtype=None,
   1059                                  name=None,
   1060                                  as_ref=False,
   1061                                  preferred_dtype=None,
   1062                                  ctx=None):
   1063   """Converts `values` to a list of `Tensor` objects.
   1064 
   1065   Args:
   1066     values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
   1067     dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
   1068     name: (Optional.) A name prefix to used when a new `Tensor` is
   1069       created, in which case element `i` will be given the name `name
   1070       + '_' + i`.
   1071     as_ref: True if the caller wants the results as ref tensors.
   1072     preferred_dtype: Optional element type for the returned tensors,
   1073       used when dtype is None. In some cases, a caller may not have a
   1074       dtype in mind when converting to a tensor, so preferred_dtype
   1075       can be used as a soft preference.  If the conversion to
   1076       `preferred_dtype` is not possible, this argument has no effect.
   1077     ctx: The value of context.context().
   1078 
   1079   Returns:
   1080     A list of `Tensor` and/or `IndexedSlices` objects.
   1081 
   1082   Raises:
   1083     TypeError: If no conversion function is registered for an element in
   1084       `values`.
   1085     RuntimeError: If a registered conversion function returns an invalid
   1086       value.
   1087   """
   1088   if not isinstance(values, collections.Sequence):
   1089     raise TypeError("values must be a list.")
   1090   ret = []
   1091   if ctx is None: ctx = context.context()
   1092   for i, value in enumerate(values):
   1093     n = None if name is None else "%s_%d" % (name, i)
   1094     ret.append(
   1095         internal_convert_to_tensor(
   1096             value,
   1097             dtype=dtype,
   1098             name=n,
   1099             as_ref=as_ref,
   1100             preferred_dtype=preferred_dtype,
   1101             ctx=ctx))
   1102   return ret
   1103 
   1104 
   1105 def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
   1106   """Converts `values` to a list of `Tensor` objects.
   1107 
   1108   Args:
   1109     values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
   1110     dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
   1111     name: (Optional.) A name prefix to used when a new `Tensor` is
   1112       created, in which case element `i` will be given the name `name
   1113       + '_' + i`.
   1114     preferred_dtype: Optional element type for the returned tensors,
   1115       used when dtype is None. In some cases, a caller may not have a
   1116       dtype in mind when converting to a tensor, so preferred_dtype
   1117       can be used as a soft preference.  If the conversion to
   1118       `preferred_dtype` is not possible, this argument has no effect.
   1119 
   1120   Returns:
   1121     A list of `Tensor` and/or `IndexedSlices` objects.
   1122 
   1123   Raises:
   1124     TypeError: If no conversion function is registered for an element in
   1125       `values`.
   1126     RuntimeError: If a registered conversion function returns an invalid
   1127       value.
   1128   """
   1129   return internal_convert_n_to_tensor(
   1130       values=values,
   1131       dtype=dtype,
   1132       name=name,
   1133       preferred_dtype=preferred_dtype,
   1134       as_ref=False)
   1135 
   1136 
   1137 @tf_export("convert_to_tensor_or_indexed_slices")
   1138 def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
   1139   """Converts the given object to a `Tensor` or an `IndexedSlices`.
   1140 
   1141   If `value` is an `IndexedSlices` or `SparseTensor` it is returned
   1142   unmodified. Otherwise, it is converted to a `Tensor` using
   1143   `convert_to_tensor()`.
   1144 
   1145   Args:
   1146     value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
   1147       by `convert_to_tensor()`.
   1148     dtype: (Optional.) The required `DType` of the returned `Tensor` or
   1149       `IndexedSlices`.
   1150     name: (Optional.) A name to use if a new `Tensor` is created.
   1151 
   1152   Returns:
   1153     An `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
   1154 
   1155   Raises:
   1156     ValueError: If `dtype` does not match the element type of `value`.
   1157   """
   1158   return internal_convert_to_tensor_or_indexed_slices(
   1159       value=value, dtype=dtype, name=name, as_ref=False)
   1160 
   1161 
   1162 def internal_convert_to_tensor_or_indexed_slices(value,
   1163                                                  dtype=None,
   1164                                                  name=None,
   1165                                                  as_ref=False):
   1166   """Converts the given object to an `Tensor` or an `IndexedSlices`.
   1167 
   1168   If `value` is an `IndexedSlices` or `SparseTensor` it is returned
   1169   unmodified. Otherwise, it is converted to a `Tensor` using
   1170   `convert_to_tensor()`.
   1171 
   1172   Args:
   1173     value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
   1174       by `convert_to_tensor()`.
   1175     dtype: (Optional.) The required `DType` of the returned `Tensor` or
   1176       `IndexedSlices`.
   1177     name: (Optional.) A name to use if a new `Tensor` is created.
   1178     as_ref: True if the caller wants the results as ref tensors.
   1179 
   1180   Returns:
   1181     An `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
   1182 
   1183   Raises:
   1184     ValueError: If `dtype` does not match the element type of `value`.
   1185   """
   1186   if isinstance(value, _TensorLike):
   1187     if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
   1188       raise ValueError(
   1189           "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
   1190           (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
   1191     return value
   1192   else:
   1193     return internal_convert_to_tensor(
   1194         value, dtype=dtype, name=name, as_ref=as_ref)
   1195 
   1196 
   1197 def internal_convert_n_to_tensor_or_indexed_slices(values,
   1198                                                    dtype=None,
   1199                                                    name=None,
   1200                                                    as_ref=False):
   1201   """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
   1202 
   1203   Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
   1204   unmodified.
   1205 
   1206   Args:
   1207     values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
   1208       can be consumed by `convert_to_tensor()`.
   1209     dtype: (Optional.) The required `DType` of the returned `Tensor`
   1210       `IndexedSlices`.
   1211     name: (Optional.) A name prefix to used when a new `Tensor` is
   1212       created, in which case element `i` will be given the name `name
   1213       + '_' + i`.
   1214     as_ref: True if the caller wants the results as ref tensors.
   1215 
   1216   Returns:
   1217     A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects.
   1218 
   1219   Raises:
   1220     TypeError: If no conversion function is registered for an element in
   1221       `values`.
   1222     RuntimeError: If a registered conversion function returns an invalid
   1223       value.
   1224   """
   1225   if not isinstance(values, collections.Sequence):
   1226     raise TypeError("values must be a list.")
   1227   ret = []
   1228   for i, value in enumerate(values):
   1229     if value is None:
   1230       ret.append(value)
   1231     else:
   1232       n = None if name is None else "%s_%d" % (name, i)
   1233       ret.append(
   1234           internal_convert_to_tensor_or_indexed_slices(
   1235               value, dtype=dtype, name=n, as_ref=as_ref))
   1236   return ret
   1237 
   1238 
   1239 def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
   1240   """Converts `values` to a list of `Output` or `IndexedSlices` objects.
   1241 
   1242   Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
   1243   unmodified.
   1244 
   1245   Args:
   1246     values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
   1247       can be consumed by `convert_to_tensor()`.
   1248     dtype: (Optional.) The required `DType` of the returned `Tensor`
   1249       `IndexedSlices`.
   1250     name: (Optional.) A name prefix to used when a new `Tensor` is
   1251       created, in which case element `i` will be given the name `name
   1252       + '_' + i`.
   1253 
   1254   Returns:
   1255     A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects.
   1256 
   1257   Raises:
   1258     TypeError: If no conversion function is registered for an element in
   1259       `values`.
   1260     RuntimeError: If a registered conversion function returns an invalid
   1261       value.
   1262   """
   1263   return internal_convert_n_to_tensor_or_indexed_slices(
   1264       values=values, dtype=dtype, name=name, as_ref=False)
   1265 
   1266 
   1267 # TODO(josh11b): Add ctx argument to conversion_func() signature.
   1268 @tf_export("register_tensor_conversion_function")
   1269 def register_tensor_conversion_function(base_type,
   1270                                         conversion_func,
   1271                                         priority=100):
   1272   """Registers a function for converting objects of `base_type` to `Tensor`.
   1273 
   1274   The conversion function must have the following signature:
   1275 
   1276   ```python
   1277       def conversion_func(value, dtype=None, name=None, as_ref=False):
   1278         # ...
   1279   ```
   1280 
   1281   It must return a `Tensor` with the given `dtype` if specified. If the
   1282   conversion function creates a new `Tensor`, it should use the given
   1283   `name` if specified. All exceptions will be propagated to the caller.
   1284 
   1285   The conversion function may return `NotImplemented` for some
   1286   inputs. In this case, the conversion process will continue to try
   1287   subsequent conversion functions.
   1288 
   1289   If `as_ref` is true, the function must return a `Tensor` reference,
   1290   such as a `Variable`.
   1291 
   1292   NOTE: The conversion functions will execute in order of priority,
   1293   followed by order of registration. To ensure that a conversion function
   1294   `F` runs before another conversion function `G`, ensure that `F` is
   1295   registered with a smaller priority than `G`.
   1296 
   1297   Args:
   1298     base_type: The base type or tuple of base types for all objects that
   1299       `conversion_func` accepts.
   1300     conversion_func: A function that converts instances of `base_type` to
   1301       `Tensor`.
   1302     priority: Optional integer that indicates the priority for applying this
   1303       conversion function. Conversion functions with smaller priority values
   1304       run earlier than conversion functions with larger priority values.
   1305       Defaults to 100.
   1306 
   1307   Raises:
   1308     TypeError: If the arguments do not have the appropriate type.
   1309 
   1310   """
   1311   global _tensor_conversion_func_cache
   1312   with _tensor_conversion_func_lock:
   1313     if not (isinstance(base_type, type) or
   1314             (isinstance(base_type, tuple) and
   1315              all(isinstance(x, type) for x in base_type))):
   1316       raise TypeError("base_type must be a type or a tuple of types.")
   1317     if not callable(conversion_func):
   1318       raise TypeError("conversion_func must be callable.")
   1319 
   1320     try:
   1321       funcs_at_priority = _tensor_conversion_func_registry[priority]
   1322     except KeyError:
   1323       funcs_at_priority = []
   1324       _tensor_conversion_func_registry[priority] = funcs_at_priority
   1325     funcs_at_priority.append((base_type, conversion_func))
   1326     _tensor_conversion_func_cache = {}
   1327 
   1328 
   1329 @tf_export("IndexedSlices")
   1330 class IndexedSlices(_TensorLike):
   1331   """A sparse representation of a set of tensor slices at given indices.
   1332 
   1333   This class is a simple wrapper for a pair of `Tensor` objects:
   1334 
   1335   * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
   1336   * `indices`: A 1-D integer `Tensor` with shape `[D0]`.
   1337 
   1338   An `IndexedSlices` is typically used to represent a subset of a larger
   1339   tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
   1340   The values in `indices` are the indices in the first dimension of
   1341   the slices that have been extracted from the larger tensor.
   1342 
   1343   The dense tensor `dense` represented by an `IndexedSlices` `slices` has
   1344 
   1345   ```python
   1346   dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
   1347   ```
   1348 
   1349   The `IndexedSlices` class is used principally in the definition of
   1350   gradients for operations that have sparse gradients
   1351   (e.g. @{tf.gather}).
   1352 
   1353   Contrast this representation with
   1354   @{tf.SparseTensor},
   1355   which uses multi-dimensional indices and scalar values.
   1356   """
   1357 
   1358   def __init__(self, values, indices, dense_shape=None):
   1359     """Creates an `IndexedSlices`."""
   1360     _get_graph_from_inputs([values, indices, dense_shape])
   1361     self._values = values
   1362     self._indices = indices
   1363     self._dense_shape = dense_shape
   1364 
   1365   @property
   1366   def values(self):
   1367     """A `Tensor` containing the values of the slices."""
   1368     return self._values
   1369 
   1370   @property
   1371   def indices(self):
   1372     """A 1-D `Tensor` containing the indices of the slices."""
   1373     return self._indices
   1374 
   1375   @property
   1376   def dense_shape(self):
   1377     """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
   1378     return self._dense_shape
   1379 
   1380   @property
   1381   def name(self):
   1382     """The name of this `IndexedSlices`."""
   1383     return self.values.name
   1384 
   1385   @property
   1386   def device(self):
   1387     """The name of the device on which `values` will be produced, or `None`."""
   1388     return self.values.device
   1389 
   1390   @property
   1391   def op(self):
   1392     """The `Operation` that produces `values` as an output."""
   1393     return self.values.op
   1394 
   1395   @property
   1396   def dtype(self):
   1397     """The `DType` of elements in this tensor."""
   1398     return self.values.dtype
   1399 
   1400   @property
   1401   def graph(self):
   1402     """The `Graph` that contains the values, indices, and shape tensors."""
   1403     return self._values.graph
   1404 
   1405   def __str__(self):
   1406     return "IndexedSlices(indices=%s, values=%s%s)" % (
   1407         self._indices, self._values, (", dense_shape=%s" % self._dense_shape)
   1408         if self._dense_shape is not None else "")
   1409 
   1410   def __neg__(self):
   1411     return IndexedSlices(-self.values, self.indices, self.dense_shape)
   1412 
   1413 
   1414 IndexedSlicesValue = collections.namedtuple(
   1415     "IndexedSlicesValue", ["values", "indices", "dense_shape"])
   1416 
   1417 
   1418 def _device_string(dev_spec):
   1419   if isinstance(dev_spec, pydev.DeviceSpec):
   1420     return dev_spec.to_string()
   1421   else:
   1422     return dev_spec
   1423 
   1424 
   1425 def _NodeDef(op_type, name, device=None, attrs=None):  # pylint: disable=redefined-outer-name
   1426   """Create a NodeDef proto.
   1427 
   1428   Args:
   1429     op_type: Value for the "op" attribute of the NodeDef proto.
   1430     name: Value for the "name" attribute of the NodeDef proto.
   1431     device: string, device, or function from NodeDef to string.
   1432       Value for the "device" attribute of the NodeDef proto.
   1433     attrs: Optional dictionary where the key is the attribute name (a string)
   1434       and the value is the respective "attr" attribute of the NodeDef proto (an
   1435       AttrValue).
   1436 
   1437   Returns:
   1438     A node_def_pb2.NodeDef protocol buffer.
   1439   """
   1440   node_def = node_def_pb2.NodeDef()
   1441   node_def.op = compat.as_bytes(op_type)
   1442   node_def.name = compat.as_bytes(name)
   1443   if attrs is not None:
   1444     for k, v in six.iteritems(attrs):
   1445       node_def.attr[k].CopyFrom(v)
   1446   if device is not None:
   1447     if callable(device):
   1448       node_def.device = device(node_def)
   1449     else:
   1450       node_def.device = _device_string(device)
   1451   return node_def
   1452 
   1453 
   1454 # Copied from core/framework/node_def_util.cc
   1455 # TODO(mrry,josh11b): Consolidate this validation in C++ code.
   1456 _VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$")
   1457 _VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$")
   1458 
   1459 
   1460 def _create_c_op(graph, node_def, inputs, control_inputs):
   1461   """Creates a TF_Operation.
   1462 
   1463   Args:
   1464     graph: a `Graph`.
   1465     node_def: `node_def_pb2.NodeDef` for the operation to create.
   1466     inputs: A list of `Tensor`s (corresponding to scalar inputs) and lists of
   1467       `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
   1468       "list(int64)"). The length of the list should be equal to the number of
   1469       inputs specified by this operation's op def.
   1470     control_inputs: A list of `Operation`s to set as control dependencies.
   1471 
   1472   Returns:
   1473     A wrapped TF_Operation*.
   1474   """
   1475   # pylint: disable=protected-access
   1476   op_desc = c_api.TF_NewOperation(graph._c_graph,
   1477                                   compat.as_str(node_def.op),
   1478                                   compat.as_str(node_def.name))
   1479   # Add inputs
   1480   for op_input in inputs:
   1481     if isinstance(op_input, (list, tuple)):
   1482       c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input])
   1483     else:
   1484       c_api.TF_AddInput(op_desc, op_input._as_tf_output())
   1485 
   1486   # Add control inputs
   1487   for control_input in control_inputs:
   1488     c_api.TF_AddControlInput(op_desc, control_input._c_op)
   1489   # pylint: enable=protected-access
   1490 
   1491   # Add attrs
   1492   for name, attr_value in node_def.attr.items():
   1493     serialized = attr_value.SerializeToString()
   1494     # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
   1495     # It might be worth creating a convenient way to re-use the same status.
   1496     with errors.raise_exception_on_not_ok_status() as status:
   1497       c_api.TF_SetAttrValueProto(op_desc,
   1498                                  compat.as_str(name), serialized, status)
   1499 
   1500   try:
   1501     with errors.raise_exception_on_not_ok_status() as status:
   1502       c_op = c_api.TF_FinishOperation(op_desc, status)
   1503   except errors.InvalidArgumentError as e:
   1504     # Convert to ValueError for backwards compatibility.
   1505     raise ValueError(str(e))
   1506 
   1507   return c_op
   1508 
   1509 
   1510 @tf_export("Operation")
   1511 class Operation(object):
   1512   """Represents a graph node that performs computation on tensors.
   1513 
   1514   An `Operation` is a node in a TensorFlow `Graph` that takes zero or
   1515   more `Tensor` objects as input, and produces zero or more `Tensor`
   1516   objects as output. Objects of type `Operation` are created by
   1517   calling a Python op constructor (such as
   1518   @{tf.matmul})
   1519   or @{tf.Graph.create_op}.
   1520 
   1521   For example `c = tf.matmul(a, b)` creates an `Operation` of type
   1522   "MatMul" that takes tensors `a` and `b` as input, and produces `c`
   1523   as output.
   1524 
   1525   After the graph has been launched in a session, an `Operation` can
   1526   be executed by passing it to
   1527   @{tf.Session.run}.
   1528   `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
   1529   """
   1530 
   1531   def __init__(self,
   1532                node_def,
   1533                g,
   1534                inputs=None,
   1535                output_types=None,
   1536                control_inputs=None,
   1537                input_types=None,
   1538                original_op=None,
   1539                op_def=None):
   1540     r"""Creates an `Operation`.
   1541 
   1542     NOTE: This constructor validates the name of the `Operation` (passed
   1543     as `node_def.name`). Valid `Operation` names match the following
   1544     regular expression:
   1545 
   1546         [A-Za-z0-9.][A-Za-z0-9_.\\-/]*
   1547 
   1548     Args:
   1549       node_def: `node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`.
   1550         Used for attributes of `node_def_pb2.NodeDef`, typically `name`,
   1551         `op`, and `device`.  The `input` attribute is irrelevant here
   1552         as it will be computed when generating the model.
   1553       g: `Graph`. The parent graph.
   1554       inputs: list of `Tensor` objects. The inputs to this `Operation`.
   1555       output_types: list of `DType` objects.  List of the types of the
   1556         `Tensors` computed by this operation.  The length of this list indicates
   1557         the number of output endpoints of the `Operation`.
   1558       control_inputs: list of operations or tensors from which to have a
   1559         control dependency.
   1560       input_types: List of `DType` objects representing the
   1561         types of the tensors accepted by the `Operation`.  By default
   1562         uses `[x.dtype.base_dtype for x in inputs]`.  Operations that expect
   1563         reference-typed inputs must specify these explicitly.
   1564       original_op: Optional. Used to associate the new `Operation` with an
   1565         existing `Operation` (for example, a replica with the op that was
   1566         replicated).
   1567       op_def: Optional. The `op_def_pb2.OpDef` proto that describes the
   1568         op type that this `Operation` represents.
   1569 
   1570     Raises:
   1571       TypeError: if control inputs are not Operations or Tensors,
   1572         or if `node_def` is not a `NodeDef`,
   1573         or if `g` is not a `Graph`,
   1574         or if `inputs` are not tensors,
   1575         or if `inputs` and `input_types` are incompatible.
   1576       ValueError: if the `node_def` name is not valid.
   1577     """
   1578     # For internal use only: `node_def` can be set to a TF_Operation to create
   1579     # an Operation for that op. This is useful for creating Operations for ops
   1580     # indirectly created by C API methods, e.g. the ops created by
   1581     # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields
   1582     # should be None.
   1583 
   1584     if isinstance(node_def, node_def_pb2.NodeDef):
   1585       if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
   1586         raise ValueError(
   1587             "Cannot create a tensor proto whose content is larger than 2GB.")
   1588       if not _VALID_OP_NAME_REGEX.match(node_def.name):
   1589         raise ValueError("'%s' is not a valid node name" % node_def.name)
   1590       c_op = None
   1591     elif type(node_def).__name__ == "SwigPyObject":
   1592       assert inputs is None
   1593       assert output_types is None
   1594       assert control_inputs is None
   1595       assert input_types is None
   1596       assert original_op is None
   1597       assert op_def is None
   1598       c_op = node_def
   1599     else:
   1600       raise TypeError("node_def needs to be a NodeDef: %s" % node_def)
   1601 
   1602     if not isinstance(g, Graph):
   1603       raise TypeError("g needs to be a Graph: %s" % g)
   1604     self._graph = g
   1605 
   1606     if inputs is None:
   1607       inputs = []
   1608     elif not isinstance(inputs, list):
   1609       raise TypeError("inputs needs to be a list of Tensors: %s" % inputs)
   1610     for a in inputs:
   1611       if not isinstance(a, Tensor):
   1612         raise TypeError("input needs to be a Tensor: %s" % a)
   1613     if input_types is None:
   1614       input_types = [i.dtype.base_dtype for i in inputs]
   1615     else:
   1616       if not all(
   1617           x.is_compatible_with(i.dtype)
   1618           for i, x in zip(inputs, input_types)):
   1619         raise TypeError("In op '%s', input types (%s) are not compatible "
   1620                         "with expected types (%s)" %
   1621                         (node_def.name, [i.dtype for i in inputs],
   1622                          input_types))
   1623 
   1624     # Build the list of control inputs.
   1625     control_input_ops = []
   1626     if control_inputs:
   1627       for c in control_inputs:
   1628         control_op = None
   1629         if isinstance(c, Operation):
   1630           control_op = c
   1631         elif isinstance(c, (Tensor, IndexedSlices)):
   1632           control_op = c.op
   1633         else:
   1634           raise TypeError("Control input must be an Operation, "
   1635                           "a Tensor, or IndexedSlices: %s" % c)
   1636         control_input_ops.append(control_op)
   1637 
   1638     # Don't set private fields with C API enabled to catch users who need to
   1639     # switch to public API.
   1640     # TODO(skyewm): delete these fields once we remove _USE_C_API
   1641     if not self._graph._c_graph:
   1642       self._inputs_val = list(inputs)  # Defensive copy.
   1643       self._input_types_val = input_types
   1644       self._control_inputs_val = control_input_ops
   1645       self._node_def_val = copy.deepcopy(node_def)
   1646       self._op_def_val = op_def
   1647 
   1648     self._id_value = self._graph._next_id()  # pylint: disable=protected-access
   1649     self._original_op = original_op
   1650     self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access
   1651     self._control_flow_context = self.graph._get_control_flow_context()  # pylint: disable=protected-access
   1652 
   1653     # Initialize self._c_op.
   1654     if c_op:
   1655       # TODO(skyewm): remove this assert when we remove USE_C_API
   1656       assert self._graph._c_graph  # pylint: disable=protected-access
   1657       self._c_op = c_op
   1658     elif self._graph._c_graph:  # pylint: disable=protected-access
   1659       if op_def is None:
   1660         op_def = self._graph._get_op_def(node_def.op)
   1661       # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
   1662       # Refactor so we don't have to do this here.
   1663       grouped_inputs = self._reconstruct_sequence_inputs(
   1664           op_def, inputs, node_def.attr)
   1665       self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
   1666                                 control_input_ops)
   1667     else:
   1668       self._c_op = None
   1669 
   1670     # Mark that we consume the inputs. This is unnecessary and unsupported with
   1671     # the C API enabled, since the C API tracks the tensor consumers instead.
   1672     if not self._c_op:
   1673       for input_tensor in self._inputs_val:
   1674         input_tensor._add_consumer(self)  # pylint: disable=protected-access
   1675 
   1676     # Initialize self._outputs.
   1677     if self._c_op:
   1678       num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
   1679       output_types = [
   1680           c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i))
   1681           for i in range(num_outputs)]
   1682       assert output_types is not None
   1683     elif output_types is None:
   1684       output_types = []
   1685     self._output_types_val = output_types
   1686     self._outputs = [
   1687         Tensor(self, i, output_type)
   1688         for i, output_type in enumerate(output_types)
   1689     ]
   1690 
   1691     if not c_op:
   1692       self._control_flow_post_processing()
   1693 
   1694   def _control_flow_post_processing(self):
   1695     """Add this op to its control flow context.
   1696 
   1697     This may add new ops and change this op's inputs. self.inputs must be
   1698     available before calling this method.
   1699     """
   1700     for input_tensor in self.inputs:
   1701       control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
   1702     if self._control_flow_context is not None:
   1703       self._control_flow_context.AddOp(self)
   1704     self._recompute_node_def()
   1705 
   1706   def _reconstruct_sequence_inputs(self, op_def, inputs, attrs):
   1707     """Regroups a flat list of input tensors into scalar and sequence inputs.
   1708 
   1709     Args:
   1710       op_def: The `op_def_pb2.OpDef` (for knowing the input types)
   1711       inputs: a list of input `Tensor`s to the op.
   1712       attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
   1713         how long each sequence is)
   1714 
   1715     Returns:
   1716       A list of `Tensor`s (corresponding to scalar inputs) and lists of
   1717       `Tensor`s (corresponding to sequence inputs).
   1718     """
   1719     grouped_inputs = []
   1720     i = 0
   1721     for input_arg in op_def.input_arg:
   1722       if input_arg.number_attr:
   1723         input_len = attrs[input_arg.number_attr].i
   1724         is_sequence = True
   1725       elif input_arg.type_list_attr:
   1726         input_len = len(attrs[input_arg.type_list_attr].list.type)
   1727         is_sequence = True
   1728       else:
   1729         input_len = 1
   1730         is_sequence = False
   1731 
   1732       if is_sequence:
   1733         grouped_inputs.append(inputs[i:i + input_len])
   1734       else:
   1735         grouped_inputs.append(inputs[i])
   1736       i += input_len
   1737 
   1738     assert i == len(inputs)
   1739     return grouped_inputs
   1740 
   1741   def colocation_groups(self):
   1742     """Returns the list of colocation groups of the op."""
   1743     default_colocation_group = [
   1744         compat.as_bytes("loc:@%s" % self.name)
   1745     ]
   1746     try:
   1747       class_attr = self.get_attr("_class")
   1748     except ValueError:
   1749       # This op has no explicit colocation group, so it is itself its
   1750       # own root of a colocation group.
   1751       return default_colocation_group
   1752 
   1753     attr_groups = [
   1754         class_name for class_name in class_attr
   1755         if class_name.startswith(b"loc:@")
   1756     ]
   1757 
   1758     # If there are no colocation groups in the explicit _class field,
   1759     # return the default colocation group.
   1760     return attr_groups if attr_groups else default_colocation_group
   1761 
   1762   def values(self):
   1763     """DEPRECATED: Use outputs."""
   1764     return tuple(self.outputs)
   1765 
   1766   def _get_control_flow_context(self):
   1767     """Returns the control flow context of this op.
   1768 
   1769     Returns:
   1770       A context object.
   1771     """
   1772     return self._control_flow_context
   1773 
   1774   def _set_control_flow_context(self, ctx):
   1775     """Sets the current control flow context of this op.
   1776 
   1777     Args:
   1778       ctx: a context object.
   1779     """
   1780     self._control_flow_context = ctx
   1781 
   1782   @property
   1783   def name(self):
   1784     """The full name of this operation."""
   1785     if self._c_op:
   1786       return c_api.TF_OperationName(self._c_op)
   1787     else:
   1788       return self._node_def_val.name
   1789 
   1790   @property
   1791   def _id(self):
   1792     """The unique integer id of this operation."""
   1793     return self._id_value
   1794 
   1795   @property
   1796   def device(self):
   1797     """The name of the device to which this op has been assigned, if any.
   1798 
   1799     Returns:
   1800       The string name of the device to which this op has been
   1801       assigned, or an empty string if it has not been assigned to a
   1802       device.
   1803     """
   1804     if self._c_op:
   1805       return c_api.TF_OperationDevice(self._c_op)
   1806     else:
   1807       return self._node_def_val.device
   1808 
   1809   @property
   1810   def _output_types(self):
   1811     """List this operation's output types.
   1812 
   1813     Returns:
   1814       List of the types of the Tensors computed by this operation.
   1815       Each element in the list is an integer whose value is one of
   1816       the TF_DataType enums defined in c_api.h
   1817       The length of this list indicates the number of output endpoints
   1818       of the operation.
   1819     """
   1820     if self._c_op:
   1821       num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
   1822       output_types = [
   1823           c_api.TF_OperationOutputType(self._tf_output(i))
   1824           for i in xrange(num_outputs)
   1825       ]
   1826       # TODO(iga): Remove this assert after converting to C API by default.
   1827       # Just being a bit paranoid here.
   1828       assert self._output_types_val == output_types
   1829       # In all the tests we have output_types that are passed into
   1830       # Operation.__init__ are a list of ints (which is illegal according
   1831       # to the docstring), but input_types are instances of DType.
   1832       # This extra assert is to catch if we ever use DType for output_types.
   1833       if output_types:
   1834         assert isinstance(output_types[0], int)
   1835       return output_types
   1836     else:
   1837       return self._output_types_val
   1838 
   1839   def _tf_output(self, output_idx):
   1840     """Create and return a new TF_Output for output_idx'th output of this op."""
   1841     assert self._c_op
   1842     tf_output = c_api.TF_Output()
   1843     tf_output.oper = self._c_op
   1844     tf_output.index = output_idx
   1845     return tf_output
   1846 
   1847   def _tf_input(self, input_idx):
   1848     """Create and return a new TF_Input for input_idx'th input of this op."""
   1849     assert self._c_op
   1850     tf_input = c_api.TF_Input()
   1851     tf_input.oper = self._c_op
   1852     tf_input.index = input_idx
   1853     return tf_input
   1854 
   1855   def _set_device(self, device):  # pylint: disable=redefined-outer-name
   1856     """Set the device of this operation.
   1857 
   1858     Args:
   1859       device: string or device..  The device to set.
   1860     """
   1861     if self._c_op:
   1862       c_api.SetRequestedDevice(
   1863           self._graph._c_graph,  # pylint: disable=protected-access
   1864           self._c_op,  # pylint: disable=protected-access
   1865           compat.as_str(_device_string(device)))
   1866     else:
   1867       self._node_def_val.device = _device_string(device)
   1868 
   1869   def _add_input(self, tensor, dtype=None):
   1870     """Add a new input to this operation.
   1871 
   1872     Args:
   1873       tensor: the Tensor to add as an input.
   1874       dtype: tf.DType: type of the input; defaults to
   1875         the tensor's dtype.
   1876 
   1877     Raises:
   1878       TypeError: if tensor is not a Tensor,
   1879         or if input tensor type is not convertible to dtype.
   1880       ValueError: if the Tensor is from a different graph.
   1881     """
   1882     assert not self._c_op, (
   1883         "Operation._add_input doesn't work with C API")
   1884     if not isinstance(tensor, Tensor):
   1885       raise TypeError("tensor must be a Tensor: %s" % tensor)
   1886     _assert_same_graph(self, tensor)
   1887     if dtype is None:
   1888       dtype = tensor.dtype
   1889     else:
   1890       dtype = dtypes.as_dtype(dtype)
   1891       if not dtype.is_compatible_with(tensor.dtype):
   1892         raise TypeError(
   1893             "Cannot convert a tensor of type %s to an input of type %s" %
   1894             (tensor.dtype.name, dtype.name))
   1895     self._inputs_val.append(tensor)
   1896     self._input_types_val.append(dtype)
   1897     tensor._add_consumer(self)  # pylint: disable=protected-access
   1898     self._recompute_node_def()
   1899 
   1900   def _update_input(self, index, tensor):
   1901     """Update the input to this operation at the given index.
   1902 
   1903     NOTE: This is for TF internal use only. Please don't use it.
   1904 
   1905     Args:
   1906       index: the index of the input to update.
   1907       tensor: the Tensor to be used as the input at the given index.
   1908 
   1909     Raises:
   1910       TypeError: if tensor is not a Tensor,
   1911         or if input tensor type is not convertible to dtype.
   1912       ValueError: if the Tensor is from a different graph.
   1913     """
   1914     if not isinstance(tensor, Tensor):
   1915       raise TypeError("tensor must be a Tensor: %s" % tensor)
   1916     _assert_same_graph(self, tensor)
   1917     if self._c_op:
   1918       with errors.raise_exception_on_not_ok_status() as status:
   1919         c_api.UpdateEdge(
   1920             self._graph._c_graph,  # pylint: disable=protected-access
   1921             tensor._as_tf_output(),  # pylint: disable=protected-access
   1922             self._tf_input(index),
   1923             status)
   1924     else:
   1925       self._inputs_val[index].consumers().remove(self)
   1926       self._inputs_val[index] = tensor
   1927       self._input_types_val[index] = tensor.dtype
   1928       tensor._add_consumer(self)  # pylint: disable=protected-access
   1929       self._recompute_node_def()
   1930 
   1931   def _add_control_inputs(self, ops):
   1932     """Add a list of new control inputs to this operation.
   1933 
   1934     Args:
   1935       ops: the list of Operations to add as control input.
   1936 
   1937     Raises:
   1938       TypeError: if ops is not a list of Operations.
   1939       ValueError: if any op in ops is from a different graph.
   1940     """
   1941     if self._c_op:
   1942       for op in ops:
   1943         if not isinstance(op, Operation):
   1944           raise TypeError("op must be an Operation: %s" % op)
   1945         c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op)  # pylint: disable=protected-access
   1946     else:
   1947       if ops:
   1948         for op in ops:
   1949           if not isinstance(op, Operation):
   1950             raise TypeError("op must be an Operation: %s" % op)
   1951           _assert_same_graph(self, op)
   1952           self._control_inputs_val.append(op)
   1953         self._recompute_node_def()
   1954 
   1955   def _add_control_input(self, op):
   1956     """Add a new control input to this operation.
   1957 
   1958     Args:
   1959       op: the Operation to add as control input.
   1960 
   1961     Raises:
   1962       TypeError: if op is not an Operation.
   1963       ValueError: if op is from a different graph.
   1964     """
   1965     if self._c_op:
   1966       if not isinstance(op, Operation):
   1967         raise TypeError("op must be an Operation: %s" % op)
   1968       c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op)  # pylint: disable=protected-access
   1969     else:
   1970       self._add_control_inputs([op])
   1971 
   1972   def _remove_all_control_inputs(self):
   1973     """Removes any control inputs to this operation."""
   1974     if self._c_op:
   1975       c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op)  # pylint: disable=protected-access
   1976     else:
   1977       del self.control_inputs[:]
   1978 
   1979   # Methods below are used when building the NodeDef and Graph proto.
   1980   def _recompute_node_def(self):
   1981     # TODO(skyewm): remove this function when we switch to C API
   1982     if self._c_op: return
   1983 
   1984     del self._node_def_val.input[:]
   1985     # pylint: disable=protected-access
   1986     self._node_def_val.input.extend(
   1987         [t._as_node_def_input() for t in self._inputs_val])
   1988     # pylint: enable=protected-access
   1989     if self._control_inputs_val:
   1990       self._node_def_val.input.extend(
   1991           ["^%s" % op.name for op in self._control_inputs_val])
   1992 
   1993   def __str__(self):
   1994     return str(self.node_def)
   1995 
   1996   def __repr__(self):
   1997     return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
   1998 
   1999   @property
   2000   def outputs(self):
   2001     """The list of `Tensor` objects representing the outputs of this op."""
   2002     return self._outputs
   2003 
   2004 # pylint: disable=protected-access
   2005 
   2006   class _InputList(object):
   2007     """Immutable input list wrapper."""
   2008 
   2009     def __init__(self, inputs):
   2010       self._inputs = inputs
   2011 
   2012     def __iter__(self):
   2013       return iter(self._inputs)
   2014 
   2015     def __len__(self):
   2016       return len(self._inputs)
   2017 
   2018     def __bool__(self):
   2019       return bool(self._inputs)
   2020 
   2021     # Python 3 wants __bool__, Python 2.7 wants __nonzero__
   2022     __nonzero__ = __bool__
   2023 
   2024     def __getitem__(self, i):
   2025       return self._inputs[i]
   2026 
   2027 # pylint: enable=protected-access
   2028 
   2029   @property
   2030   def inputs(self):
   2031     """The list of `Tensor` objects representing the data inputs of this op."""
   2032     if self._c_op:
   2033       tf_outputs = c_api.GetOperationInputs(self._c_op)
   2034       # pylint: disable=protected-access
   2035       retval = [
   2036           self.graph._get_tensor_by_tf_output(tf_output)
   2037           for tf_output in tf_outputs
   2038       ]
   2039       # pylint: enable=protected-access
   2040       return Operation._InputList(retval)
   2041     return Operation._InputList(self._inputs_val)
   2042 
   2043   @property
   2044   def _inputs(self):
   2045     logging.warning("Operation._inputs is private, use Operation.inputs "
   2046                     "instead. Operation._inputs will eventually be removed.")
   2047     return self.inputs
   2048 
   2049   @_inputs.setter
   2050   def _inputs(self, value):
   2051     raise ValueError("Cannot assign _inputs")
   2052 
   2053   @property
   2054   def _input_types(self):
   2055     if self._c_op:
   2056       num_inputs = c_api.TF_OperationNumInputs(self._c_op)
   2057       input_types = [
   2058           dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
   2059           for i in xrange(num_inputs)
   2060       ]
   2061       return input_types
   2062     else:
   2063       return self._input_types_val
   2064 
   2065   @_input_types.setter
   2066   def _input_types(self, value):
   2067     raise ValueError("Cannot assign _input_types")
   2068 
   2069   @property
   2070   def control_inputs(self):
   2071     """The `Operation` objects on which this op has a control dependency.
   2072 
   2073     Before this op is executed, TensorFlow will ensure that the
   2074     operations in `self.control_inputs` have finished executing. This
   2075     mechanism can be used to run ops sequentially for performance
   2076     reasons, or to ensure that the side effects of an op are observed
   2077     in the correct order.
   2078 
   2079     Returns:
   2080       A list of `Operation` objects.
   2081 
   2082     """
   2083     if self._c_op:
   2084       control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
   2085       # pylint: disable=protected-access
   2086       return [
   2087           self.graph._get_operation_by_name_unsafe(
   2088               c_api.TF_OperationName(c_op)) for c_op in control_c_ops
   2089       ]
   2090       # pylint: enable=protected-access
   2091     else:
   2092       return self._control_inputs_val
   2093 
   2094   @property
   2095   def _control_inputs(self):
   2096     logging.warning("Operation._control_inputs is private, use "
   2097                     "Operation.control_inputs instead. "
   2098                     "Operation._control_inputs will eventually be removed.")
   2099     return self.control_inputs
   2100 
   2101   @_control_inputs.setter
   2102   def _control_inputs(self, value):
   2103     logging.warning("Operation._control_inputs is private, use "
   2104                     "Operation.control_inputs instead. "
   2105                     "Operation._control_inputs will eventually be removed.")
   2106     # Copy value because it may be self._control_inputs_val (in particular if
   2107     # this is called from self._control_inputs += ...), and we don't want to
   2108     # clear value below.
   2109     value = copy.copy(value)
   2110     self._remove_all_control_inputs()
   2111     self._add_control_inputs(value)
   2112 
   2113   @property
   2114   def type(self):
   2115     """The type of the op (e.g. `"MatMul"`)."""
   2116     if self._c_op:
   2117       op_type = c_api.TF_OperationOpType(self._c_op)
   2118       return op_type
   2119     else:
   2120       return self._node_def_val.op
   2121 
   2122   @property
   2123   def graph(self):
   2124     """The `Graph` that contains this operation."""
   2125     return self._graph
   2126 
   2127   @property
   2128   def node_def(self):
   2129     # pylint: disable=line-too-long
   2130     """Returns the `NodeDef` representation of this operation.
   2131 
   2132     Returns:
   2133       A
   2134       [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
   2135       protocol buffer.
   2136     """
   2137     # pylint: enable=line-too-long
   2138     if self._c_op:
   2139       with c_api_util.tf_buffer() as buf:
   2140         with errors.raise_exception_on_not_ok_status() as status:
   2141           c_api.TF_OperationToNodeDef(self._c_op, buf, status)
   2142         data = c_api.TF_GetBuffer(buf)
   2143       node_def = node_def_pb2.NodeDef()
   2144       node_def.ParseFromString(compat.as_bytes(data))
   2145       return node_def
   2146     else:
   2147       return self._node_def_val
   2148 
   2149   @property
   2150   def _node_def(self):
   2151     logging.warning("Operation._node_def is private, use Operation.node_def "
   2152                     "instead. Operation._node_def will eventually be removed.")
   2153     return self.node_def
   2154 
   2155   @property
   2156   def op_def(self):
   2157     # pylint: disable=line-too-long
   2158     """Returns the `OpDef` proto that represents the type of this op.
   2159 
   2160     Returns:
   2161       An
   2162       [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
   2163       protocol buffer.
   2164     """
   2165     # pylint: enable=line-too-long
   2166     if self._c_op:
   2167       return self._graph._get_op_def(self.type)
   2168     else:
   2169       return self._op_def_val
   2170 
   2171   @property
   2172   def _op_def(self):
   2173     logging.warning("Operation._op_def is private, use Operation.op_def "
   2174                     "instead. Operation._op_def will eventually be removed.")
   2175     return self.op_def
   2176 
   2177   @property
   2178   def traceback(self):
   2179     """Returns the call stack from when this operation was constructed."""
   2180     return self._graph._convert_stack(self._traceback)  # pylint: disable=protected-access
   2181 
   2182   @property
   2183   def traceback_with_start_lines(self):
   2184     """Same as traceback but includes start line of function definition.
   2185 
   2186     Returns:
   2187       A list of 5-tuples (filename, lineno, name, code, func_start_lineno).
   2188     """
   2189     return self._graph._convert_stack(  # pylint: disable=protected-access
   2190         self._traceback,
   2191         include_func_start_lineno=True)
   2192 
   2193   def _set_attr(self, attr_name, attr_value):
   2194     """Private method used to set an attribute in the node_def."""
   2195     if self._c_op:
   2196       buf = c_api.TF_NewBufferFromString(
   2197           compat.as_bytes(attr_value.SerializeToString()))
   2198       try:
   2199         with errors.raise_exception_on_not_ok_status() as status:
   2200           # pylint: disable=protected-access
   2201           c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf,
   2202                         status)
   2203           # pylint: enable=protected-access
   2204       finally:
   2205         c_api.TF_DeleteBuffer(buf)
   2206     else:
   2207       self._node_def_val.attr[attr_name].CopyFrom(attr_value)
   2208 
   2209   def get_attr(self, name):
   2210     """Returns the value of the attr of this op with the given `name`.
   2211 
   2212     Args:
   2213       name: The name of the attr to fetch.
   2214 
   2215     Returns:
   2216       The value of the attr, as a Python object.
   2217 
   2218     Raises:
   2219       ValueError: If this op does not have an attr with the given `name`.
   2220     """
   2221     fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
   2222     if self._c_op:
   2223       try:
   2224         with c_api_util.tf_buffer() as buf:
   2225           with errors.raise_exception_on_not_ok_status() as status:
   2226             c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status)
   2227           data = c_api.TF_GetBuffer(buf)
   2228       except errors.InvalidArgumentError as e:
   2229         # Convert to ValueError for backwards compatibility.
   2230         raise ValueError(str(e))
   2231       x = attr_value_pb2.AttrValue()
   2232       x.ParseFromString(data)
   2233     else:
   2234       if name not in self._node_def_val.attr:
   2235         raise ValueError(
   2236             "No attr named '" + name + "' in " + str(self._node_def_val))
   2237       x = self._node_def_val.attr[name]
   2238 
   2239     # Treat an empty oneof value as an empty list.
   2240     if not x.WhichOneof("value"):
   2241       return []
   2242     if x.HasField("list"):
   2243       for f in fields:
   2244         if getattr(x.list, f):
   2245           if f == "type":
   2246             return [dtypes.as_dtype(x) for x in list(getattr(x.list, f))]
   2247           else:
   2248             return list(getattr(x.list, f))
   2249       return []
   2250     else:
   2251       for f in fields:
   2252         if x.HasField(f):
   2253           if f == "type":
   2254             return dtypes.as_dtype(getattr(x, f))
   2255           else:
   2256             return getattr(x, f)
   2257       assert False, "Unsupported field type in " + str(x)
   2258 
   2259   def run(self, feed_dict=None, session=None):
   2260     """Runs this operation in a `Session`.
   2261 
   2262     Calling this method will execute all preceding operations that
   2263     produce the inputs needed for this operation.
   2264 
   2265     *N.B.* Before invoking `Operation.run()`, its graph must have been
   2266     launched in a session, and either a default session must be
   2267     available, or `session` must be specified explicitly.
   2268 
   2269     Args:
   2270       feed_dict: A dictionary that maps `Tensor` objects to feed values.
   2271         See @{tf.Session.run}
   2272         for a description of the valid feed values.
   2273       session: (Optional.) The `Session` to be used to run to this operation. If
   2274         none, the default session will be used.
   2275     """
   2276     _run_using_default_session(self, feed_dict, self.graph, session)
   2277 
# Registry from op type string to gradient function.
# - Filled by the `RegisterGradient` decorator.
# - Filled with `None` entries by `NotDifferentiable`.
# - Queried by `get_gradient_function`.
_gradient_registry = registry.Registry("gradient")
   2279 
   2280 
   2281 @tf_export("RegisterGradient")
   2282 class RegisterGradient(object):
   2283   """A decorator for registering the gradient function for an op type.
   2284 
   2285   This decorator is only used when defining a new op type. For an op
   2286   with `m` inputs and `n` outputs, the gradient function is a function
   2287   that takes the original `Operation` and `n` `Tensor` objects
   2288   (representing the gradients with respect to each output of the op),
   2289   and returns `m` `Tensor` objects (representing the partial gradients
   2290   with respect to each input of the op).
   2291 
   2292   For example, assuming that operations of type `"Sub"` take two
   2293   inputs `x` and `y`, and return a single output `x - y`, the
   2294   following gradient function would be registered:
   2295 
   2296   ```python
   2297   @tf.RegisterGradient("Sub")
   2298   def _sub_grad(unused_op, grad):
   2299     return grad, tf.negative(grad)
   2300   ```
   2301 
   2302   The decorator argument `op_type` is the string type of an
   2303   operation. This corresponds to the `OpDef.name` field for the proto
   2304   that defines the operation.
   2305   """
   2306 
   2307   def __init__(self, op_type):
   2308     """Creates a new decorator with `op_type` as the Operation type.
   2309 
   2310     Args:
   2311       op_type: The string type of an operation. This corresponds to the
   2312         `OpDef.name` field for the proto that defines the operation.
   2313     """
   2314     if not isinstance(op_type, six.string_types):
   2315       raise TypeError("op_type must be a string")
   2316     self._op_type = op_type
   2317 
   2318   def __call__(self, f):
   2319     """Registers the function `f` as gradient function for `op_type`."""
   2320     _gradient_registry.register(f, self._op_type)
   2321     return f
   2322 
   2323 
   2324 @tf_export("NoGradient", "NotDifferentiable")
   2325 def NotDifferentiable(op_type):
   2326   """Specifies that ops of type `op_type` is not differentiable.
   2327 
   2328   This function should *not* be used for operations that have a
   2329   well-defined gradient that is not yet implemented.
   2330 
   2331   This function is only used when defining a new op type. It may be
   2332   used for ops such as `tf.size()` that are not differentiable.  For
   2333   example:
   2334 
   2335   ```python
   2336   tf.NotDifferentiable("Size")
   2337   ```
   2338 
   2339   The gradient computed for 'op_type' will then propagate zeros.
   2340 
   2341   For ops that have a well-defined gradient but are not yet implemented,
   2342   no declaration should be made, and an error *must* be thrown if
   2343   an attempt to request its gradient is made.
   2344 
   2345   Args:
   2346     op_type: The string type of an operation. This corresponds to the
   2347       `OpDef.name` field for the proto that defines the operation.
   2348 
   2349   Raises:
   2350     TypeError: If `op_type` is not a string.
   2351 
   2352   """
   2353   if not isinstance(op_type, six.string_types):
   2354     raise TypeError("op_type must be a string")
   2355   _gradient_registry.register(None, op_type)
   2356 
   2357 
# Alias for the old name; it will eventually be removed. `NoGradient` is the
# very same function object as `NotDifferentiable`, which its `tf_export`
# decorator also exports under the "NoGradient" name.
NoGradient = NotDifferentiable
   2360 
   2361 
   2362 def get_gradient_function(op):
   2363   """Returns the function that computes gradients for "op"."""
   2364   if not op.inputs:
   2365     return None
   2366   try:
   2367     op_type = op.get_attr("_gradient_op_type")
   2368   except ValueError:
   2369     op_type = op.type
   2370   return _gradient_registry.lookup(op_type)
   2371 
   2372 
# Shape functions registered through `RegisterShape` with a non-None function.
_shape_registry = registry.Registry("shape functions")
# "Weak" default shape functions. `RegisterShape(op_type)(None)` registers
# `_call_cpp_shape_fn` here. `_set_shapes_for_outputs` consults this registry
# only when `_shape_registry` has no entry for the op type.
_default_shape_function_registry = registry.Registry("default shape functions")

# These are set to common_shapes.call_cpp_shape_fn by op generated code
# (generated by python_op_gen.cc).
# It is set outside ops.py to avoid a circular dependency.
# More precisely, `_set_call_cpp_shape_fn` stores two wrappers around the
# passed function, calling it with require_shape_fn=False and
# require_shape_fn=True respectively. Both stay None until that happens.
_call_cpp_shape_fn = None
_call_cpp_shape_fn_and_require_op = None
   2381 
   2382 
   2383 def _set_call_cpp_shape_fn(call_cpp_shape_fn):
   2384   """Sets default shape fns from passed common_shapes.call_cpp_shape_fn."""
   2385   global _call_cpp_shape_fn, _call_cpp_shape_fn_and_require_op
   2386   if _call_cpp_shape_fn:
   2387     return  # already registered
   2388 
   2389   def call_without_requiring(op):
   2390     return call_cpp_shape_fn(op, require_shape_fn=False)
   2391 
   2392   _call_cpp_shape_fn = call_without_requiring
   2393 
   2394   def call_with_requiring(op):
   2395     return call_cpp_shape_fn(op, require_shape_fn=True)
   2396 
   2397   _call_cpp_shape_fn_and_require_op = call_with_requiring
   2398 
   2399 
   2400 class RegisterShape(object):
   2401   """No longer used.  Was: A decorator for registering a shape function.
   2402 
   2403   Shape functions must now be registered via the SetShapeFn on the
   2404   original Op specification in C++.
   2405 
   2406   """
   2407 
   2408   def __init__(self, op_type):
   2409     """Saves the `op_type` as the `Operation` type."""
   2410     if not isinstance(op_type, six.string_types):
   2411       raise TypeError("op_type must be a string")
   2412     self._op_type = op_type
   2413 
   2414   def __call__(self, f):
   2415     """Registers "f" as the shape function for "op_type"."""
   2416     if f is None:
   2417       assert _call_cpp_shape_fn
   2418 
   2419       # None is a special "weak" value that provides a default shape function,
   2420       # and can be overridden by a non-None registration.
   2421       try:
   2422         _default_shape_function_registry.register(_call_cpp_shape_fn,
   2423                                                   self._op_type)
   2424       except KeyError:
   2425         # Ignore duplicate registrations of the weak value. This can
   2426         # occur if the op library input to wrapper generation
   2427         # inadvertently links in one or more of the standard op
   2428         # libraries.
   2429         pass
   2430     else:
   2431       _shape_registry.register(f, self._op_type)
   2432     return f
   2433 
   2434 
   2435 def _set_shapes_for_outputs_c_api(op):
   2436   """set_shapes_for_outputs implementation when C API is enabled."""
   2437   # The C API computes the shapes when the TF_Operation is created. Fetch the
   2438   # output shapes from the C object.
   2439   for output in op.outputs:
   2440     with errors.raise_exception_on_not_ok_status() as status:
   2441       # pylint: disable=protected-access
   2442       shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
   2443           op._graph._c_graph, output._as_tf_output(), status)
   2444       # pylint: enable=protected-access
   2445     if unknown_shape:
   2446       output.set_shape(tensor_shape.unknown_shape())
   2447     elif not shape_vector:
   2448       output.set_shape(tensor_shape.scalar())
   2449     else:
   2450       shape_vector = [None if d == -1 else d for d in shape_vector]
   2451       output.set_shape(tensor_shape.TensorShape(shape_vector))
   2452 
   2453 
   2454 # TODO(skyewm): remove this when _USE_C_API flag is removed.
   2455 def _set_shapes_for_outputs(op):
   2456   """set_shapes_for_outputs implementation when C API is disabled."""
   2457   try:
   2458     shape_func = _shape_registry.lookup(op.type)
   2459   except LookupError:
   2460     try:
   2461       shape_func = _default_shape_function_registry.lookup(op.type)
   2462     except LookupError:
   2463       shape_func = _call_cpp_shape_fn_and_require_op
   2464 
   2465   shapes = shape_func(op)
   2466   if shapes is None:
   2467     raise RuntimeError(
   2468         "Shape function for op %s did not return any shapes" % op)
   2469   elif isinstance(shapes, dict):
   2470     # Returned by call_cpp_shape_fn
   2471     shapes_dict = shapes
   2472     shapes = shapes_dict["shapes"]
   2473     handle_datas = shapes_dict["handle_data"]
   2474     for output, handle_data in zip(op.outputs, handle_datas):
   2475       # pylint: disable=protected-access
   2476       output._handle_data = handle_data
   2477       # pylint: enable=protected-access
   2478 
   2479   if len(op.outputs) != len(shapes):
   2480     raise RuntimeError(
   2481         "Shape function for op %s returned %d shapes but expected %d %s %s" %
   2482         (op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes)))
   2483   for output, s in zip(op.outputs, shapes):
   2484     output.set_shape(s)
   2485 
   2486 
   2487 def set_shapes_for_outputs(op):
   2488   """Set the shapes for op's outputs."""
   2489   if op._c_op:  # pylint: disable=protected-access
   2490     return _set_shapes_for_outputs_c_api(op)
   2491   else:
   2492     return _set_shapes_for_outputs(op)
   2493 
   2494 
   2495 class OpStats(object):
   2496   """A holder for statistics about an operator.
   2497 
   2498   This class holds information about the resource requirements for an op,
   2499   including the size of its weight parameters on-disk and how many FLOPS it
   2500   requires to execute forward inference.
   2501 
   2502   If you define a new operation, you can create a function that will return a
   2503   set of information about its usage of the CPU and disk space when serialized.
   2504   The function itself takes a Graph object that's been set up so you can call
   2505   methods like get_tensor_by_name to help calculate the results, and a NodeDef
   2506   argument.
   2507 
   2508   """
   2509 
   2510   def __init__(self, statistic_type, value=None):
   2511     """Sets up the initial placeholders for the statistics."""
   2512     self.statistic_type = statistic_type
   2513     self.value = value
   2514 
   2515   @property
   2516   def statistic_type(self):
   2517     return self._statistic_type
   2518 
   2519   @statistic_type.setter
   2520   def statistic_type(self, statistic_type):
   2521     self._statistic_type = statistic_type
   2522 
   2523   @property
   2524   def value(self):
   2525     return self._value
   2526 
   2527   @value.setter
   2528   def value(self, value):
   2529     self._value = value
   2530 
   2531   def __iadd__(self, other):
   2532     if other.statistic_type != self.statistic_type:
   2533       raise ValueError("Can't add an OpStat of type %s to one of %s." %
   2534                        (self.statistic_type, other.statistic_type))
   2535     if self.value is None:
   2536       self.value = other.value
   2537     elif other.value is not None:
   2538       self._value += other.value
   2539     return self
   2540 
   2541 
# Registry of statistics functions, keyed by the string
# "<op_type>,<statistic_type>". Filled by `RegisterStatistics` and read by
# `get_stats_for_node_def`.
_stats_registry = registry.Registry("statistical functions")
   2543 
   2544 
   2545 class RegisterStatistics(object):
   2546   """A decorator for registering the statistics function for an op type.
   2547 
   2548   This decorator can be defined for an op type so that it gives a
   2549   report on the resources used by an instance of an operator, in the
   2550   form of an OpStats object.
   2551 
   2552   Well-known types of statistics include these so far:
   2553 
   2554   - flops: When running a graph, the bulk of the computation happens doing
   2555     numerical calculations like matrix multiplications. This type allows a node
   2556     to return how many floating-point operations it takes to complete. The
   2557     total number of FLOPs for a graph is a good guide to its expected latency.
   2558 
   2559   You can add your own statistics just by picking a new type string, registering
   2560   functions for the ops you care about, and then calling get_stats_for_node_def.
   2561 
   2562   If a statistic for an op is registered multiple times, a KeyError will be
   2563   raised.
   2564 
   2565   Since the statistics is counted on a per-op basis. It is not suitable for
   2566   model parameters (capacity), which is expected to be counted only once, even
   2567   if it is shared by multiple ops. (e.g. RNN)
   2568 
   2569   For example, you can define a new metric called doohickey for a Foo operation
   2570   by placing this in your code:
   2571 
   2572   ```python
   2573   @ops.RegisterStatistics("Foo", "doohickey")
   2574   def _calc_foo_bojangles(unused_graph, unused_node_def):
   2575     return ops.OpStats("doohickey", 20)
   2576   ```
   2577 
   2578   Then in client code you can retrieve the value by making this call:
   2579 
   2580   ```python
   2581   doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey")
   2582   ```
   2583 
   2584   If the NodeDef is for an op with a registered doohickey function, you'll get
   2585   back the calculated amount in doohickey.value, or None if it's not defined.
   2586 
   2587   """
   2588 
   2589   def __init__(self, op_type, statistic_type):
   2590     """Saves the `op_type` as the `Operation` type."""
   2591     if not isinstance(op_type, six.string_types):
   2592       raise TypeError("op_type must be a string.")
   2593     if "," in op_type:
   2594       raise TypeError("op_type must not contain a comma.")
   2595     self._op_type = op_type
   2596     if not isinstance(statistic_type, six.string_types):
   2597       raise TypeError("statistic_type must be a string.")
   2598     if "," in statistic_type:
   2599       raise TypeError("statistic_type must not contain a comma.")
   2600     self._statistic_type = statistic_type
   2601 
   2602   def __call__(self, f):
   2603     """Registers "f" as the statistics function for "op_type"."""
   2604     _stats_registry.register(f, self._op_type + "," + self._statistic_type)
   2605     return f
   2606 
   2607 
   2608 def get_stats_for_node_def(graph, node, statistic_type):
   2609   """Looks up the node's statistics function in the registry and calls it.
   2610 
   2611   This function takes a Graph object and a NodeDef from a GraphDef, and if
   2612   there's an associated statistics method, calls it and returns a result. If no
   2613   function has been registered for the particular node type, it returns an empty
   2614   statistics object.
   2615 
   2616   Args:
   2617     graph: A Graph object that's been set up with the node's graph.
   2618     node: A NodeDef describing the operator.
   2619     statistic_type: A string identifying the statistic we're interested in.
   2620   Returns:
   2621     An OpStats object containing information about resource usage.
   2622   """
   2623 
   2624   try:
   2625     stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
   2626     result = stats_func(graph, node)
   2627   except LookupError:
   2628     result = OpStats(statistic_type)
   2629   return result
   2630 
   2631 
   2632 def _name_from_scope_name(name):
   2633   """Returns the name of an op given the name of its scope.
   2634 
   2635   Args:
   2636     name: the name of the scope.
   2637 
   2638   Returns:
   2639     the name of the op (equal to scope name minus any trailing slash).
   2640   """
   2641   return name[:-1] if (name and name[-1] == "/") else name
   2642 
   2643 
   2644 @tf_export("Graph")
   2645 class Graph(object):
   2646   """A TensorFlow computation, represented as a dataflow graph.
   2647 
   2648   A `Graph` contains a set of
   2649   @{tf.Operation} objects,
   2650   which represent units of computation; and
   2651   @{tf.Tensor} objects, which represent
   2652   the units of data that flow between operations.
   2653 
   2654   A default `Graph` is always registered, and accessible by calling
   2655   @{tf.get_default_graph}.
   2656   To add an operation to the default graph, simply call one of the functions
   2657   that defines a new `Operation`:
   2658 
   2659   ```python
   2660   c = tf.constant(4.0)
   2661   assert c.graph is tf.get_default_graph()
   2662   ```
   2663 
   2664   Another typical usage involves the
   2665   @{tf.Graph.as_default}
   2666   context manager, which overrides the current default graph for the
   2667   lifetime of the context:
   2668 
   2669   ```python
   2670   g = tf.Graph()
   2671   with g.as_default():
   2672     # Define operations and tensors in `g`.
   2673     c = tf.constant(30.0)
   2674     assert c.graph is g
   2675   ```
   2676 
   2677   Important note: This class *is not* thread-safe for graph construction. All
   2678   operations should be created from a single thread, or external
   2679   synchronization must be provided. Unless otherwise specified, all methods
   2680   are not thread-safe.
   2681 
   2682   A `Graph` instance supports an arbitrary number of "collections"
   2683   that are identified by name. For convenience when building a large
   2684   graph, collections can store groups of related objects: for
   2685   example, the `tf.Variable` uses a collection (named
   2686   @{tf.GraphKeys.GLOBAL_VARIABLES}) for
   2687   all variables that are created during the construction of a graph. The caller
   2688   may define additional collections by specifying a new name.
   2689   """
   2690 
  def __init__(self):
    """Creates a new, empty Graph.

    All Python-side bookkeeping is initialized here. A backing C graph is
    also created when `_USE_C_API` is set or `_use_c_api_hack()` returns
    True.
    """
    # Protects the core state that may be accessed by multiple readers.
    # Only state that can be returned via public accessors (`as_graph_def()`,
    # `get_operations()`, `as_graph_element()`, `get_collection()`, and
    # `get_collection_ref()`) is protected by the lock. Thread-safety is
    # provided on a best-effort basis to support buggy programs, and is not
    # guaranteed by the public `tf.Graph` API.
    # NOTE(mrry): This does not protect the various stacks. A warning will
    # be reported if these are used from multiple threads
    self._lock = threading.Lock()
    # Op `_id` -> op. Filled in by `_add_op`, which rejects duplicate ids.
    self._nodes_by_id = dict()  # GUARDED_BY(self._lock)
    self._next_id_counter = 0  # GUARDED_BY(self._lock)
    # Op name -> op. Filled in by `_add_op`, which rejects duplicate names.
    self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
    # Largest op `_id` added so far; exposed through the `version` property.
    self._version = 0  # GUARDED_BY(self._lock)
    # Current name stack: uniquified names
    self._name_stack = ""
    # Maps a name used in the graph to the next id to use for that name.
    self._names_in_use = {}
    # Functions that will be applied to choose a device if none is specified.
    self._device_function_stack = []
    # Default original_op applied to new ops.
    self._default_original_op = None
    # Current control flow context. It could be either CondContext or
    # WhileContext defined in ops/control_flow_ops.py
    self._control_flow_context = None
    # A new node will depend on the union of all of the nodes in the stack.
    self._control_dependencies_stack = []
    # Arbitrary collections of objects.
    self._collections = {}
    # The graph-level random seed (None until assigned via the `seed` setter).
    self._seed = None
    # A dictionary of attributes that should be applied to all ops.
    self._attr_scope_map = {}
    # A map from op type to the kernel label that should be used.
    self._op_to_kernel_label_map = {}
    # A map from op type to an alternative op type that should be used when
    # computing gradients.
    self._gradient_override_map = {}
    # True if the graph is considered "finalized".  In that case no
    # new operations can be added.
    self._finalized = False
    # Functions defined in the graph. `_copy_functions_to_graph_def` iterates
    # them in insertion order.
    self._functions = collections.OrderedDict()
    # Default GraphDef versions. The `graph_def_versions` property returns
    # these when the graph has no backing C graph.
    self._graph_def_versions = versions_pb2.VersionDef(
        producer=versions.GRAPH_DEF_VERSION,
        min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
    self._building_function = False
    # Stack of colocate_with ops
    self._colocation_stack = []
    # Set of tensors that are dangerous to feed!
    self._unfeedable_tensors = set()
    # Set of operations that are dangerous to fetch!
    self._unfetchable_ops = set()
    # A map of tensor handle placeholder to tensor dtype.
    self._handle_feeders = {}
    # A map from tensor handle to its read op.
    self._handle_readers = {}
    # A map from tensor handle to its move op.
    self._handle_movers = {}
    # A map from tensor handle to its delete op.
    self._handle_deleters = {}
    # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
    # will be shared when defining function graphs, for example, so optimizers
    # being called inside function definitions behave as if they were seeing the
    # actual outside graph).
    # NOTE(review): "grap-key" looks like a misspelling of "graph-key", but it
    # is part of the key value at runtime, so it is left untouched.
    self._graph_key = "grap-key-%d/" % (uid(),)
    self._container = ""
    self._registered_ops = op_def_registry.get_registered_ops()

    # TODO(skyewm): fold as much of the above as possible into the C
    # implementation
    # `_use_c_api_hack()` is an overridable hook. It runs only after every
    # attribute above has been assigned.
    if _USE_C_API or self._use_c_api_hack():
      self._scoped_c_graph = c_api_util.ScopedTFGraph()
    else:
      self._scoped_c_graph = None
    # Stack of variable creators pushed by `_variable_creator_scope`.
    self._variable_creator_stack = []
   2769 
   2770   # TODO(apassos) remove once the C API is used by default.
   2771   def _use_c_api_hack(self):
   2772     """Temporary hack; can be overridden to force C API usage."""
   2773     return False
   2774 
   2775   def _convert_stack(self, stack, include_func_start_lineno=False):
   2776     """Converts a stack extracted using _extract_stack() to a traceback stack.
   2777 
   2778     Args:
   2779       stack: A list of n 5-tuples,
   2780         (filename, lineno, name, frame_globals, func_start_lineno).
   2781       include_func_start_lineno: True if function start line number should be
   2782         included as the 5th entry in return tuples.
   2783 
   2784     Returns:
   2785       A list of n 4-tuples or 5-tuples
   2786       (filename, lineno, name, code, [optional: func_start_lineno]), where the
   2787       code tuple element is calculated from the corresponding elements of the
   2788       input tuple.
   2789     """
   2790     ret = []
   2791     for (filename, lineno, name, frame_globals, func_start_lineno,
   2792          unused_frame_info) in stack:
   2793       linecache.checkcache(filename)
   2794       line = linecache.getline(filename, lineno, frame_globals)
   2795       if line:
   2796         line = line.strip()
   2797       else:
   2798         line = None
   2799       if include_func_start_lineno:
   2800         ret.append((filename, lineno, name, line, func_start_lineno))
   2801       else:
   2802         ret.append((filename, lineno, name, line))
   2803     return ret
   2804 
   2805   # Note: this method is private because the API of tf.Graph() is public and
   2806   # frozen, and this functionality is still not ready for public visibility.
   2807   @tf_contextlib.contextmanager
   2808   def _variable_creator_scope(self, creator):
   2809     old = list(self._variable_creator_stack)
   2810     self._variable_creator_stack.append(creator)
   2811     try:
   2812       yield
   2813     finally:
   2814       self._variable_creator_stack = old
   2815 
   2816   # Note: this method is private because the API of tf.Graph() is public and
   2817   # frozen, and this functionality is still not ready for public visibility.
   2818   def _get_variable_creator_stack(self):
   2819     return list(self._variable_creator_stack)
   2820 
   2821   def _extract_stack(self):
   2822     """A lightweight, extensible re-implementation of traceback.extract_stack.
   2823 
   2824     NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
   2825       each stack frame using linecache, which results in an abundance of stat()
   2826       calls. This implementation does not retrieve the code, and any consumer
   2827       should apply _convert_stack to the result to obtain a traceback that can
   2828       be formatted etc. using traceback methods.
   2829 
   2830     Derived classes can implement _extract_frame_info() to add extra information
   2831     to the traceback.
   2832 
   2833     Returns:
   2834       A list of 6-tuples
   2835       (filename, lineno, name, frame_globals, func_start_lineno, custom_info)
   2836       corresponding to the call stack of the current thread.
   2837     """
   2838     try:
   2839       raise ZeroDivisionError
   2840     except ZeroDivisionError:
   2841       f = sys.exc_info()[2].tb_frame.f_back
   2842     ret = []
   2843     while f is not None:
   2844       lineno = f.f_lineno
   2845       co = f.f_code
   2846       filename = co.co_filename
   2847       name = co.co_name
   2848       frame_globals = f.f_globals
   2849       func_start_lineno = co.co_firstlineno
   2850       frame_info = self._extract_frame_info(f)
   2851       ret.append((filename, lineno, name, frame_globals, func_start_lineno,
   2852                   frame_info))
   2853       f = f.f_back
   2854     ret.reverse()
   2855     return ret
   2856 
   2857   def _extract_frame_info(self, frame):  # pylint: disable=unused-argument
   2858     """Extracts custom information from a frame in an op traceback."""
   2859     return None
   2860 
   2861   def _check_not_finalized(self):
   2862     """Check if the graph is finalized.
   2863 
   2864     Raises:
   2865       RuntimeError: If the graph finalized.
   2866     """
   2867     if self._finalized:
   2868       raise RuntimeError("Graph is finalized and cannot be modified.")
   2869 
   2870   def _add_op(self, op):
   2871     """Adds 'op' to the graph.
   2872 
   2873     Args:
   2874       op: the Operator or Tensor to add.
   2875 
   2876     Raises:
   2877       TypeError: if op is not an Operation or Tensor.
   2878       ValueError: if the op.name or op._id are already used.
   2879     """
   2880     self._check_not_finalized()
   2881     if not isinstance(op, (Tensor, Operation)):
   2882       raise TypeError("op must be a Tensor or Operation: %s" % op)
   2883     with self._lock:
   2884       # pylint: disable=protected-access
   2885       if op._id in self._nodes_by_id:
   2886         raise ValueError("cannot add an op with id %d as it already "
   2887                          "exists in the graph" % op._id)
   2888       if op.name in self._nodes_by_name:
   2889         raise ValueError("cannot add op with name %s as that name "
   2890                          "is already used" % op.name)
   2891       self._nodes_by_id[op._id] = op
   2892       self._nodes_by_name[op.name] = op
   2893       self._version = max(self._version, op._id)
   2894       # pylint: enable=protected-access
   2895 
   2896   @property
   2897   def _c_graph(self):
   2898     if self._scoped_c_graph:
   2899       return self._scoped_c_graph.graph
   2900     return None
   2901 
   2902   @property
   2903   def version(self):
   2904     """Returns a version number that increases as ops are added to the graph.
   2905 
   2906     Note that this is unrelated to the
   2907     @{tf.Graph.graph_def_versions}.
   2908 
   2909     Returns:
   2910        An integer version that increases as ops are added to the graph.
   2911     """
   2912     if self._finalized:
   2913       return self._version
   2914 
   2915     with self._lock:
   2916       return self._version
   2917 
   2918   @property
   2919   def graph_def_versions(self):
   2920     # pylint: disable=line-too-long
   2921     """The GraphDef version information of this graph.
   2922 
   2923     For details on the meaning of each version, see
   2924     [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).
   2925 
   2926     Returns:
   2927       A `VersionDef`.
   2928     """
   2929     # pylint: enable=line-too-long
   2930     if self._c_graph:
   2931       with c_api_util.tf_buffer() as buf:
   2932         with errors.raise_exception_on_not_ok_status() as status:
   2933           c_api.TF_GraphVersions(self._c_graph, buf, status)
   2934         data = c_api.TF_GetBuffer(buf)
   2935       version_def = versions_pb2.VersionDef()
   2936       version_def.ParseFromString(compat.as_bytes(data))
   2937       return version_def
   2938     else:
   2939       return self._graph_def_versions
   2940 
   2941   @property
   2942   def seed(self):
   2943     """The graph-level random seed of this graph."""
   2944     return self._seed
   2945 
   2946   @seed.setter
   2947   def seed(self, seed):
   2948     self._seed = seed
   2949 
   2950   @property
   2951   def finalized(self):
   2952     """True if this graph has been finalized."""
   2953     return self._finalized
   2954 
   2955   def finalize(self):
   2956     """Finalizes this graph, making it read-only.
   2957 
   2958     After calling `g.finalize()`, no new operations can be added to
   2959     `g`.  This method is used to ensure that no operations are added
   2960     to a graph when it is shared between multiple threads, for example
   2961     when using a @{tf.train.QueueRunner}.
   2962     """
   2963     self._finalized = True
   2964 
   2965   def _unsafe_unfinalize(self):
   2966     """Opposite of `finalize`. Internal interface.
   2967 
   2968     NOTE: Unfinalizing a graph could have negative impact on performance,
   2969     especially in a multi-threaded environment.  Unfinalizing a graph
   2970     when it is in use by a Session may lead to undefined behavior. Ensure
   2971     that all sessions using a graph are closed before calling this method.
   2972     """
   2973     self._finalized = False
   2974 
   2975   def _get_control_flow_context(self):
   2976     """Returns the current control flow context.
   2977 
   2978     Returns:
   2979       A context object.
   2980     """
   2981     return self._control_flow_context
   2982 
   2983   def _set_control_flow_context(self, ctx):
   2984     """Sets the current control flow context.
   2985 
   2986     Args:
   2987       ctx: a context object.
   2988     """
   2989     self._control_flow_context = ctx
   2990 
   2991   def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
   2992     """If this graph contains functions, copy them to `graph_def`."""
   2993     bytesize = starting_bytesize
   2994     for f in self._functions.values():
   2995       bytesize += f.definition.ByteSize()
   2996       if bytesize >= (1 << 31) or bytesize < 0:
   2997         raise ValueError("GraphDef cannot be larger than 2GB.")
   2998       graph_def.library.function.extend([f.definition])
   2999       if f.grad_func_name:
   3000         grad_def = function_pb2.GradientDef()
   3001         grad_def.function_name = f.name
   3002         grad_def.gradient_func = f.grad_func_name
   3003         graph_def.library.gradient.extend([grad_def])
   3004 
  def _as_graph_def(self, from_version=None, add_shapes=False):
    # pylint: disable=line-too-long
    """Returns a serialized `GraphDef` representation of this graph.

    The serialized `GraphDef` can be imported into another `Graph`
    (using @{tf.import_graph_def}) or used with the
    [C++ Session API](../../../../api_docs/cc/index.md).

    This method is thread-safe: both code paths run under `self._lock`.

    Args:
      from_version: Optional.  If this is set, returns a `GraphDef`
        containing only the nodes that were added to this graph since
        its `version` property had the given value. Only the pure-Python
        path (when `_USE_C_API` is false) reads this argument. The C API
        path serializes the whole graph regardless.
      add_shapes: If true, adds an "_output_shapes" list attr to each
        node with the inferred shapes of each of its outputs.

    Returns:
      A tuple containing a
      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
      protocol buffer, and the version of the graph to which that
      `GraphDef` corresponds (`self._version`).

    Raises:
      ValueError: If the `graph_def` would be too large. The explicit 2GB
        check exists only on the pure-Python path.

    """
    # pylint: enable=line-too-long
    if _USE_C_API:
      # C API path: the underlying C graph serializes itself into a buffer,
      # and the bytes are parsed back into a Python GraphDef proto.
      # NOTE(review): `from_version` is not consulted here and there is no
      # explicit 2GB size check on this path; confirm this is intended.
      with self._lock:
        with c_api_util.tf_buffer() as buf:
          with errors.raise_exception_on_not_ok_status() as status:
            c_api.TF_GraphToGraphDef(self._c_graph, buf, status)
          data = c_api.TF_GetBuffer(buf)
        graph = graph_pb2.GraphDef()
        graph.ParseFromString(compat.as_bytes(data))
        # Strip the experimental library field iff it's empty.
        if not graph.library.function:
          graph.ClearField("library")

        if add_shapes:
          # Annotate each serialized node with the output shapes held by the
          # Python-side Operation of the same name.
          for node in graph.node:
            op = self._nodes_by_name[node.name]
            if op.outputs:
              node.attr["_output_shapes"].list.shape.extend(
                  [output.get_shape().as_proto() for output in op.outputs])
    else:
      # Pure-Python path: rebuild the GraphDef from each Operation's node_def,
      # in ascending op-id order, and keep a running byte count.
      with self._lock:
        graph = graph_pb2.GraphDef()
        graph.versions.CopyFrom(self._graph_def_versions)
        bytesize = 0
        for op_id in sorted(self._nodes_by_id):
          op = self._nodes_by_id[op_id]
          # Only ops whose id is greater than `from_version` are emitted when
          # a version is given.
          if from_version is None or op_id > from_version:
            graph.node.extend([op.node_def])
            if op.outputs and add_shapes:
              assert "_output_shapes" not in graph.node[-1].attr
              graph.node[-1].attr["_output_shapes"].list.shape.extend(
                  [output.get_shape().as_proto() for output in op.outputs])
            bytesize += op.node_def.ByteSize()
            if bytesize >= (1 << 31) or bytesize < 0:
              raise ValueError("GraphDef cannot be larger than 2GB.")
        # The function library counts toward the same 2GB budget.
        self._copy_functions_to_graph_def(graph, bytesize)
    return graph, self._version
   3069 
   3070   def as_graph_def(self, from_version=None, add_shapes=False):
   3071     # pylint: disable=line-too-long
   3072     """Returns a serialized `GraphDef` representation of this graph.
   3073 
   3074     The serialized `GraphDef` can be imported into another `Graph`
   3075     (using @{tf.import_graph_def}) or used with the
   3076     [C++ Session API](../../api_docs/cc/index.md).
   3077 
   3078     This method is thread-safe.
   3079 
   3080     Args:
   3081       from_version: Optional.  If this is set, returns a `GraphDef`
   3082         containing only the nodes that were added to this graph since
   3083         its `version` property had the given value.
   3084       add_shapes: If true, adds an "_output_shapes" list attr to each
   3085         node with the inferred shapes of each of its outputs.
   3086 
   3087     Returns:
   3088       A
   3089       [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
   3090       protocol buffer.
   3091 
   3092     Raises:
   3093       ValueError: If the `graph_def` would be too large.
   3094     """
   3095     # pylint: enable=line-too-long
   3096     result, _ = self._as_graph_def(from_version, add_shapes)
   3097     return result
   3098 
   3099   def _is_function(self, name):
   3100     """Tests whether 'name' is registered in this graph's function library.
   3101 
   3102     Args:
   3103       name: string op name.
   3104     Returns:
   3105       bool indicating whether or not 'name' is registered in function library.
   3106     """
   3107     return name in self._functions
   3108 
   3109   def _get_function(self, name):
   3110     """Returns the function definition for 'name'.
   3111 
   3112     Args:
   3113       name: string function name.
   3114     Returns:
   3115       The function def proto.
   3116     """
   3117     return self._functions.get(name, None)
   3118 
  def _add_function(self, function):
    """Adds a function to the graph.

    After the function has been added, you can call the function by
    passing the function name in place of an op name to
    `Graph.create_op()`.

    If the graph has a C graph (`self._c_graph`), the function is also copied
    into it, together with its gradient's C function when one exists. A
    function that has no C function yet is first imported from its
    serialized `FunctionDef`. In non-C-API mode, re-adding a function with
    the same name and the same hash does nothing. Otherwise the function is
    stored in `self._functions`, and the graph's `min_consumer` version is
    raised to at least 12.

    Args:
      function: A `_DefinedFunction` object.

    Raises:
      ValueError: if the function defines its gradient twice (both
        `grad_func_name` and `python_grad_func` are set), or, in non-C-API
        mode, if another function with a different hash is already defined
        with the same name. In C-API mode, failures reported through the
        C API status are raised via `errors.raise_exception_on_not_ok_status`.
    """
    name = function.name
    # Sanity checks on gradient definition.
    if (function.grad_func_name is not None) and (function.python_grad_func is
                                                  not None):
      raise ValueError("Gradient defined twice for function %s" % name)

    # Add function to graph
    # pylint: disable=protected-access
    if self._c_graph:
      # Handle functions created without using the C API. TODO(apassos,skyewm)
      # remove this when all functions are generated using the C API by default
      # as this will be unnecessary.
      if not function._c_func:
        with errors.raise_exception_on_not_ok_status() as status:
          serialized = function.definition.SerializeToString()
          function._c_func = c_api.TF_FunctionImportFunctionDef(
              serialized, status)
      # Copy the function (and its gradient's C function, if any) into the
      # C graph.
      with errors.raise_exception_on_not_ok_status() as status:
        gradient = function._grad_func._c_func if function._grad_func else None
        c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient,
                                   status)
    else:
      # If there is already a function with the same name, raise an error
      # if bodies are different. Else, do nothing. The C API version above
      # has the same behavior.
      # NOTE(review): `if previous:` relies on the stored object's truthiness;
      # `is not None` may be what is intended.
      previous = self._functions.get(name, None)
      if previous:
        # This check is not ideal as we can have a hash collision with only
        # 32 bits in the hash, but the non C API mode is being deprecated.
        # Don't bother changing it now.
        if previous._hash_str == function._hash_str:
          return
        else:
          raise ValueError("Cannot add function (%s, hash %s) to graph (%s). "
                           "Another function (%s, hash %s) is already defined "
                           "with that name (%s)" % (
                               function, function._hash_str, self,
                               previous, previous._hash_str, name))
    # pylint: enable=protected-access

    self._functions[name] = function

    # Need a new-enough consumer to support the functions we add to the graph.
    if self._graph_def_versions.min_consumer < 12:
      self._graph_def_versions.min_consumer = 12
   3178 
   3179   @property
   3180   def building_function(self):
   3181     """Returns True iff this graph represents a function."""
   3182     return self._building_function
   3183 
   3184   # Helper functions to create operations.
   3185   def create_op(
   3186       self,
   3187       op_type,
   3188       inputs,
   3189       dtypes,  # pylint: disable=redefined-outer-name
   3190       input_types=None,
   3191       name=None,
   3192       attrs=None,
   3193       op_def=None,
   3194       compute_shapes=True,
   3195       compute_device=True):
   3196     """Creates an `Operation` in this graph.
   3197 
   3198     This is a low-level interface for creating an `Operation`. Most
   3199     programs will not call this method directly, and instead use the
   3200     Python op constructors, such as `tf.constant()`, which add ops to
   3201     the default graph.
   3202 
   3203     Args:
   3204       op_type: The `Operation` type to create. This corresponds to the
   3205         `OpDef.name` field for the proto that defines the operation.
   3206       inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
   3207       dtypes: A list of `DType` objects that will be the types of the tensors
   3208         that the operation produces.
   3209       input_types: (Optional.) A list of `DType`s that will be the types of
   3210         the tensors that the operation consumes. By default, uses the base
   3211         `DType` of each input in `inputs`. Operations that expect
   3212         reference-typed inputs must specify `input_types` explicitly.
   3213       name: (Optional.) A string name for the operation. If not specified, a
   3214         name is generated based on `op_type`.
   3215       attrs: (Optional.) A dictionary where the key is the attribute name (a
   3216         string) and the value is the respective `attr` attribute of the
   3217         `NodeDef` proto that will represent the operation (an `AttrValue`
   3218         proto).
   3219       op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
   3220         the operation will have.
   3221       compute_shapes: (Optional.) If True, shape inference will be performed
   3222         to compute the shapes of the outputs.
   3223       compute_device: (Optional.) If True, device functions will be executed
   3224         to compute the device property of the Operation.
   3225 
   3226     Raises:
   3227       TypeError: if any of the inputs is not a `Tensor`.
   3228       ValueError: if colocation conflicts with existing device assignment.
   3229 
   3230     Returns:
   3231       An `Operation` object.
   3232 
   3233     """
   3234     self._check_not_finalized()
   3235     for idx, a in enumerate(inputs):
   3236       if not isinstance(a, Tensor):
   3237         raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
   3238     if name is None:
   3239       name = op_type
   3240     # If a names ends with a '/' it is a "name scope" and we use it as-is,
   3241     # after removing the trailing '/'.
   3242     if name and name[-1] == "/":
   3243       name = _name_from_scope_name(name)
   3244     else:
   3245       name = self.unique_name(name)
   3246 
   3247     node_def = _NodeDef(op_type, name, device=None, attrs=attrs)
   3248 
   3249     input_ops = set([t.op for t in inputs])
   3250     control_inputs = self._control_dependencies_for_inputs(input_ops)
   3251     ret = Operation(
   3252         node_def,
   3253         self,
   3254         inputs=inputs,
   3255         output_types=dtypes,
   3256         control_inputs=control_inputs,
   3257         input_types=input_types,
   3258         original_op=self._default_original_op,
   3259         op_def=op_def)
   3260     self._create_op_helper(ret, compute_shapes=compute_shapes,
   3261                            compute_device=compute_device)
   3262     return ret
   3263 
   3264   def _create_op_from_tf_operation(self, c_op, compute_device=True):
   3265     """Creates an `Operation` in this graph from the supplied TF_Operation.
   3266 
   3267     This method is like create_op() except the new Operation is constructed
   3268     using `c_op`. The returned Operation will have `c_op` as its _c_op
   3269     field. This is used to create Operation objects around TF_Operations created
   3270     indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
   3271 
   3272     This function does not call Operation._control_flow_post_processing or
   3273     Graph._control_dependencies_for_inputs (since the inputs may not be
   3274     available yet). The caller is responsible for calling these methods.
   3275 
   3276     Args:
   3277       c_op: a wrapped TF_Operation
   3278       compute_device: (Optional.) If True, device functions will be executed
   3279         to compute the device property of the Operation.
   3280 
   3281     Returns:
   3282       An `Operation` object.
   3283     """
   3284     self._check_not_finalized()
   3285     ret = Operation(c_op, self)
   3286     assert ret.name not in self._names_in_use
   3287     self._names_in_use[ret.name] = 1
   3288     self._create_op_helper(ret, compute_device=compute_device)
   3289     return ret
   3290 
  def _create_op_helper(self, op, compute_shapes=True, compute_device=True):
    """Common logic for creating an op in this graph.

    In order, this:
      * computes output shapes (always for C-API ops, otherwise only when
        `compute_shapes` is set);
      * registers the op with the graph via `_add_op`;
      * applies `_attr_scope_map` attributes the op does not already have;
      * applies any kernel label and gradient override mapped to `op.type`;
      * records the op for control-dependency tracking;
      * runs device functions when `compute_device` is set;
      * applies the current colocation stack (device and "_class" attr);
      * fills an empty "container" attr of stateful ops from
        `self._container`.

    Args:
      op: The newly constructed `Operation` to finish setting up.
      compute_shapes: If True, run shape inference for non-C-API ops.
      compute_device: If True, apply this graph's device functions to `op`.

    Raises:
      TypeError: If a callable in `_attr_scope_map` returns something other
        than None or an `AttrValue`.
    """
    # TODO(vrv): Instead of eagerly filling in shape property for every op, only
    # populate the shape when requested.
    #
    # TODO(skyewm): unlike in the original Python implementation, the C API
    # always computes shape information (even for function calls, which the
    # original Python shape inference code doesn't handle). Deprecate the
    # compute_shapes argument.
    if op._c_op or compute_shapes:  # pylint: disable=protected-access
      set_shapes_for_outputs(op)
    # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed.
    self._add_op(op)

    # Apply any additional attributes requested. Do not overwrite any existing
    # attributes.
    for key, value in self._attr_scope_map.items():
      try:
        # get_attr raising ValueError is the signal that `key` is not yet set.
        op.get_attr(key)
      except ValueError:
        if callable(value):
          value = value(op.node_def)
          if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
            raise TypeError(
                "Callable for scope map key '%s' must return either None or "
                "an AttrValue protocol buffer; but it returned: %s" % (key,
                                                                       value))
        # A falsy value (e.g. None from the callable) means "don't set".
        if value:
          op._set_attr(key, value)  # pylint: disable=protected-access

    # Apply a kernel label if one has been specified for this op type.
    # NOTE(review): the KeyError handler also covers `_set_attr`, so a
    # KeyError raised there would be silently ignored.
    try:
      kernel_label = self._op_to_kernel_label_map[op.type]
      op._set_attr("_kernel",  # pylint: disable=protected-access
                   attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label)))
    except KeyError:
      pass

    # Apply the overriding op type for gradients if one has been specified for
    # this op type.
    try:
      mapped_op_type = self._gradient_override_map[op.type]
      op._set_attr("_gradient_op_type",  # pylint: disable=protected-access
                   attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type)))
    except KeyError:
      pass

    self._record_op_seen_by_control_dependencies(op)

    if compute_device:
      self._apply_device_functions(op)

    if self._colocation_stack:
      all_colocation_groups = []
      for colocation_op in self._colocation_stack:
        all_colocation_groups.extend(colocation_op.colocation_groups())
        if colocation_op.device:
          # Make this device match the device of the colocated op, to provide
          # consistency between the device and the colocation property.
          if (op.device and pydev.canonical_name(op.device) !=
              pydev.canonical_name(colocation_op.device)):
            logging.warning("Tried to colocate %s with an op %s that had "
                            "a different device: %s vs %s. "
                            "Ignoring colocation property.", op.name,
                            colocation_op.name, op.device,
                            colocation_op.device)
          else:
            op._set_device(colocation_op.device)  # pylint: disable=protected-access

      # De-duplicated, sorted colocation groups become the "_class" attr.
      all_colocation_groups = sorted(set(all_colocation_groups))
      # pylint: disable=protected-access
      op._set_attr("_class", attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
      # pylint: enable=protected-access

    # Sets "container" attribute if
    # (1) self._container is not None
    # (2) "is_stateful" is set in OpDef
    # (3) "container" attribute is in OpDef
    # (4) "container" attribute is None
    # TODO(skyewm): remove op.op_def check when _USE_C_API is removed.
    if self._container and op.op_def and op.op_def.is_stateful:
      try:
        container_attr = op.get_attr("container")
      except ValueError:
        # "container" attribute is not in OpDef
        pass
      else:
        if not container_attr:
          op._set_attr("container", attr_value_pb2.AttrValue(  # pylint: disable=protected-access
              s=compat.as_bytes(self._container)))
   3382 
   3383   def _add_new_tf_operations(self, compute_devices=True):
   3384     """Creates `Operations` in this graph for any new TF_Operations.
   3385 
   3386     This is useful for when TF_Operations are indirectly created by the C API
   3387     outside of the Operation constructor (e.g. by TF_ImportGraphDef,
   3388     TF_FinishWhile). This ensures there are corresponding Operations for all
   3389     TF_Operations in the underlying TF_Graph.
   3390 
   3391     Args:
   3392       compute_devices: (Optional.) If True, device functions will be executed
   3393         to compute the device properties of each new Operation.
   3394 
   3395     Returns:
   3396       A list of the new `Operation` objects.
   3397     """
   3398     # Create all Operation objects before accessing their inputs since an op may
   3399     # be created before its inputs.
   3400     new_ops = [
   3401         self._create_op_from_tf_operation(c_op, compute_device=compute_devices)
   3402         for c_op in c_api_util.new_tf_operations(self)
   3403     ]
   3404 
   3405     for op in new_ops:
   3406       new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
   3407       # pylint: disable=protected-access
   3408       op._add_control_inputs(new_control_inputs)
   3409       op._control_flow_post_processing()
   3410       # pylint: enable=protected-access
   3411 
   3412     return new_ops
   3413 
   3414   def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
   3415     """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
   3416 
   3417     This function validates that `obj` represents an element of this
   3418     graph, and gives an informative error message if it is not.
   3419 
   3420     This function is the canonical way to get/validate an object of
   3421     one of the allowed types from an external argument reference in the
   3422     Session API.
   3423 
   3424     This method may be called concurrently from multiple threads.
   3425 
   3426     Args:
   3427       obj: A `Tensor`, an `Operation`, or the name of a tensor or operation.
   3428         Can also be any object with an `_as_graph_element()` method that returns
   3429         a value of one of these types.
   3430       allow_tensor: If true, `obj` may refer to a `Tensor`.
   3431       allow_operation: If true, `obj` may refer to an `Operation`.
   3432 
   3433     Returns:
   3434       The `Tensor` or `Operation` in the Graph corresponding to `obj`.
   3435 
   3436     Raises:
   3437       TypeError: If `obj` is not a type we support attempting to convert
   3438         to types.
   3439       ValueError: If `obj` is of an appropriate type but invalid. For
   3440         example, an invalid string.
   3441       KeyError: If `obj` is not an object in the graph.
   3442     """
   3443     if self._finalized:
   3444       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3445 
   3446     with self._lock:
   3447       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3448 
   3449   def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
   3450     """See `Graph.as_graph_element()` for details."""
   3451     # The vast majority of this function is figuring
   3452     # out what an API user might be doing wrong, so
   3453     # that we can give helpful error messages.
   3454     #
   3455     # Ideally, it would be nice to split it up, but we
   3456     # need context to generate nice error messages.
   3457 
   3458     if allow_tensor and allow_operation:
   3459       types_str = "Tensor or Operation"
   3460     elif allow_tensor:
   3461       types_str = "Tensor"
   3462     elif allow_operation:
   3463       types_str = "Operation"
   3464     else:
   3465       raise ValueError("allow_tensor and allow_operation can't both be False.")
   3466 
   3467     temp_obj = _as_graph_element(obj)
   3468     if temp_obj is not None:
   3469       obj = temp_obj
   3470 
   3471     # If obj appears to be a name...
   3472     if isinstance(obj, compat.bytes_or_text_types):
   3473       name = compat.as_str(obj)
   3474 
   3475       if ":" in name and allow_tensor:
   3476         # Looks like a Tensor name and can be a Tensor.
   3477         try:
   3478           op_name, out_n = name.split(":")
   3479           out_n = int(out_n)
   3480         except:
   3481           raise ValueError("The name %s looks a like a Tensor name, but is "
   3482                            "not a valid one. Tensor names must be of the "
   3483                            "form \"<op_name>:<output_index>\"." % repr(name))
   3484         if op_name in self._nodes_by_name:
   3485           op = self._nodes_by_name[op_name]
   3486         else:
   3487           raise KeyError("The name %s refers to a Tensor which does not "
   3488                          "exist. The operation, %s, does not exist in the "
   3489                          "graph." % (repr(name), repr(op_name)))
   3490         try:
   3491           return op.outputs[out_n]
   3492         except:
   3493           raise KeyError("The name %s refers to a Tensor which does not "
   3494                          "exist. The operation, %s, exists but only has "
   3495                          "%s outputs." % (repr(name), repr(op_name),
   3496                                           len(op.outputs)))
   3497 
   3498       elif ":" in name and not allow_tensor:
   3499         # Looks like a Tensor name but can't be a Tensor.
   3500         raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
   3501                          (repr(name), types_str))
   3502 
   3503       elif ":" not in name and allow_operation:
   3504         # Looks like an Operation name and can be an Operation.
   3505         if name not in self._nodes_by_name:
   3506           raise KeyError("The name %s refers to an Operation not in the "
   3507                          "graph." % repr(name))
   3508         return self._nodes_by_name[name]
   3509 
   3510       elif ":" not in name and not allow_operation:
   3511         # Looks like an Operation name but can't be an Operation.
   3512         if name in self._nodes_by_name:
   3513           # Yep, it's an Operation name
   3514           err_msg = ("The name %s refers to an Operation, not a %s." %
   3515                      (repr(name), types_str))
   3516         else:
   3517           err_msg = ("The name %s looks like an (invalid) Operation name, "
   3518                      "not a %s." % (repr(name), types_str))
   3519         err_msg += (" Tensor names must be of the form "
   3520                     "\"<op_name>:<output_index>\".")
   3521         raise ValueError(err_msg)
   3522 
   3523     elif isinstance(obj, Tensor) and allow_tensor:
   3524       # Actually obj is just the object it's referring to.
   3525       if obj.graph is not self:
   3526         raise ValueError("Tensor %s is not an element of this graph." % obj)
   3527       return obj
   3528     elif isinstance(obj, Operation) and allow_operation:
   3529       # Actually obj is just the object it's referring to.
   3530       if obj.graph is not self:
   3531         raise ValueError("Operation %s is not an element of this graph." % obj)
   3532       return obj
   3533     else:
   3534       # We give up!
   3535       raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
   3536                                                            types_str))
   3537 
   3538   def get_operations(self):
   3539     """Return the list of operations in the graph.
   3540 
   3541     You can modify the operations in place, but modifications
   3542     to the list such as inserts/delete have no effect on the
   3543     list of operations known to the graph.
   3544 
   3545     This method may be called concurrently from multiple threads.
   3546 
   3547     Returns:
   3548       A list of Operations.
   3549     """
   3550     if self._finalized:
   3551       return list(self._nodes_by_id.values())
   3552 
   3553     with self._lock:
   3554       return list(self._nodes_by_id.values())
   3555 
   3556   def get_operation_by_name(self, name):
   3557     """Returns the `Operation` with the given `name`.
   3558 
   3559     This method may be called concurrently from multiple threads.
   3560 
   3561     Args:
   3562       name: The name of the `Operation` to return.
   3563 
   3564     Returns:
   3565       The `Operation` with the given `name`.
   3566 
   3567     Raises:
   3568       TypeError: If `name` is not a string.
   3569       KeyError: If `name` does not correspond to an operation in this graph.
   3570     """
   3571 
   3572     if not isinstance(name, six.string_types):
   3573       raise TypeError("Operation names are strings (or similar), not %s." %
   3574                       type(name).__name__)
   3575     return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
   3576 
   3577   def _get_operation_by_name_unsafe(self, name):
   3578     """Returns the `Operation` with the given `name`.
   3579 
   3580     This is a internal unsafe version of get_operation_by_name. It skips many
   3581     checks and does not have user friedly error messages but runs considerably
   3582     faster. This method may be called concurrently from multiple threads.
   3583 
   3584     Args:
   3585       name: The name of the `Operation` to return.
   3586 
   3587     Returns:
   3588       The `Operation` with the given `name`.
   3589 
   3590     Raises:
   3591       KeyError: If `name` does not correspond to an operation in this graph.
   3592     """
   3593 
   3594     if self._finalized:
   3595       return self._nodes_by_name[name]
   3596 
   3597     with self._lock:
   3598       return self._nodes_by_name[name]
   3599 
   3600   def _get_operation_by_tf_operation(self, tf_oper):
   3601     op_name = c_api.TF_OperationName(tf_oper)
   3602     return self._get_operation_by_name_unsafe(op_name)
   3603 
   3604   def get_tensor_by_name(self, name):
   3605     """Returns the `Tensor` with the given `name`.
   3606 
   3607     This method may be called concurrently from multiple threads.
   3608 
   3609     Args:
   3610       name: The name of the `Tensor` to return.
   3611 
   3612     Returns:
   3613       The `Tensor` with the given `name`.
   3614 
   3615     Raises:
   3616       TypeError: If `name` is not a string.
   3617       KeyError: If `name` does not correspond to a tensor in this graph.
   3618     """
   3619     # Names should be strings.
   3620     if not isinstance(name, six.string_types):
   3621       raise TypeError("Tensor names are strings (or similar), not %s." %
   3622                       type(name).__name__)
   3623     return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
   3624 
   3625   def _get_tensor_by_tf_output(self, tf_output):
   3626     """Returns the `Tensor` representing `tf_output`.
   3627 
   3628     Note that there is only one such `Tensor`, i.e. multiple calls to this
   3629     function with the same TF_Output value will always return the same `Tensor`
   3630     object.
   3631 
   3632     Args:
   3633       tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`).
   3634 
   3635     Returns:
   3636       The `Tensor` that represents `tf_output`.
   3637     """
   3638     op = self._get_operation_by_tf_operation(tf_output.oper)
   3639     return op.outputs[tf_output.index]
   3640 
   3641   def _next_id(self):
   3642     """Id for next Operation instance. Also increments the internal id."""
   3643     self._check_not_finalized()
   3644     with self._lock:
   3645       self._next_id_counter += 1
   3646       return self._next_id_counter
   3647 
   3648   @property
   3649   def _last_id(self):
   3650     return self._next_id_counter
   3651 
   3652   def _get_op_def(self, type):  # pylint: disable=redefined-builtin
   3653     """Returns the `OpDef` proto for `type`. `type` is a string."""
   3654     if self._c_graph:
   3655       with c_api_util.tf_buffer() as buf:
   3656         with errors.raise_exception_on_not_ok_status() as status:
   3657           # pylint: disable=protected-access
   3658           c_api.TF_GraphGetOpDef(self._c_graph,
   3659                                  compat.as_bytes(type), buf, status)
   3660           # pylint: enable=protected-access
   3661         data = c_api.TF_GetBuffer(buf)
   3662       op_def = op_def_pb2.OpDef()
   3663       op_def.ParseFromString(compat.as_bytes(data))
   3664       return op_def
   3665     else:
   3666       return self._registered_ops[type]
   3667 
   3668   def as_default(self):
   3669     """Returns a context manager that makes this `Graph` the default graph.
   3670 
   3671     This method should be used if you want to create multiple graphs
   3672     in the same process. For convenience, a global default graph is
   3673     provided, and all ops will be added to this graph if you do not
   3674     create a new graph explicitly. Use this method with the `with` keyword
   3675     to specify that ops created within the scope of a block should be
   3676     added to this graph.
   3677 
   3678     The default graph is a property of the current thread. If you
   3679     create a new thread, and wish to use the default graph in that
   3680     thread, you must explicitly add a `with g.as_default():` in that
   3681     thread's function.
   3682 
   3683     The following code examples are equivalent:
   3684 
   3685     ```python
   3686     # 1. Using Graph.as_default():
   3687     g = tf.Graph()
   3688     with g.as_default():
   3689       c = tf.constant(5.0)
   3690       assert c.graph is g
   3691 
   3692     # 2. Constructing and making default:
   3693     with tf.Graph().as_default() as g:
   3694       c = tf.constant(5.0)
   3695       assert c.graph is g
   3696     ```
   3697 
   3698     Returns:
   3699       A context manager for using this graph as the default graph.
   3700     """
   3701     return _default_graph_stack.get_controller(self)
   3702 
   3703   @property
   3704   def collections(self):
   3705     """Returns the names of the collections known to this graph."""
   3706     return list(self._collections)
   3707 
   3708   def add_to_collection(self, name, value):
   3709     """Stores `value` in the collection with the given `name`.
   3710 
   3711     Note that collections are not sets, so it is possible to add a value to
   3712     a collection several times.
   3713 
   3714     Args:
   3715       name: The key for the collection. The `GraphKeys` class
   3716         contains many standard names for collections.
   3717       value: The value to add to the collection.
   3718     """  # pylint: disable=g-doc-exception
   3719     _assert_collection_is_ok(name)
   3720     self._check_not_finalized()
   3721     with self._lock:
   3722       if name not in self._collections:
   3723         self._collections[name] = [value]
   3724       else:
   3725         self._collections[name].append(value)
   3726 
   3727   def add_to_collections(self, names, value):
   3728     """Stores `value` in the collections given by `names`.
   3729 
   3730     Note that collections are not sets, so it is possible to add a value to
   3731     a collection several times. This function makes sure that duplicates in
   3732     `names` are ignored, but it will not check for pre-existing membership of
   3733     `value` in any of the collections in `names`.
   3734 
   3735     `names` can be any iterable, but if `names` is a string, it is treated as a
   3736     single collection name.
   3737 
   3738     Args:
   3739       names: The keys for the collections to add to. The `GraphKeys` class
   3740         contains many standard names for collections.
   3741       value: The value to add to the collections.
   3742     """
   3743     # Make sure names are unique, but treat strings as a single collection name
   3744     names = (names,) if isinstance(names, six.string_types) else set(names)
   3745     for name in names:
   3746       self.add_to_collection(name, value)
   3747 
   3748   def get_collection_ref(self, name):
   3749     """Returns a list of values in the collection with the given `name`.
   3750 
   3751     If the collection exists, this returns the list itself, which can
   3752     be modified in place to change the collection.  If the collection does
   3753     not exist, it is created as an empty list and the list is returned.
   3754 
   3755     This is different from `get_collection()` which always returns a copy of
   3756     the collection list if it exists and never creates an empty collection.
   3757 
   3758     Args:
   3759       name: The key for the collection. For example, the `GraphKeys` class
   3760         contains many standard names for collections.
   3761 
   3762     Returns:
   3763       The list of values in the collection with the given `name`, or an empty
   3764       list if no value has been added to that collection.
   3765     """  # pylint: disable=g-doc-exception
   3766     _assert_collection_is_ok(name)
   3767     with self._lock:
   3768       coll_list = self._collections.get(name, None)
   3769       if coll_list is None:
   3770         coll_list = []
   3771         self._collections[name] = coll_list
   3772       return coll_list
   3773 
   3774   def get_collection(self, name, scope=None):
   3775     """Returns a list of values in the collection with the given `name`.
   3776 
   3777     This is different from `get_collection_ref()` which always returns the
   3778     actual collection list if it exists in that it returns a new list each time
   3779     it is called.
   3780 
   3781     Args:
   3782       name: The key for the collection. For example, the `GraphKeys` class
   3783         contains many standard names for collections.
   3784       scope: (Optional.) A string. If supplied, the resulting list is filtered
   3785         to include only items whose `name` attribute matches `scope` using
   3786         `re.match`. Items without a `name` attribute are never returned if a
   3787         scope is supplied. The choice of `re.match` means that a `scope` without
   3788         special tokens filters by prefix.
   3789 
   3790     Returns:
   3791       The list of values in the collection with the given `name`, or
   3792       an empty list if no value has been added to that collection. The
   3793       list contains the values in the order under which they were
   3794       collected.
   3795     """  # pylint: disable=g-doc-exception
   3796     _assert_collection_is_ok(name)
   3797     with self._lock:
   3798       collection = self._collections.get(name, None)
   3799       if collection is None:
   3800         return []
   3801       if scope is None:
   3802         return list(collection)
   3803       else:
   3804         c = []
   3805         regex = re.compile(scope)
   3806         for item in collection:
   3807           if hasattr(item, "name") and regex.match(item.name):
   3808             c.append(item)
   3809         return c
   3810 
   3811   def get_all_collection_keys(self):
   3812     """Returns a list of collections used in this graph."""
   3813     with self._lock:
   3814       return [x for x in self._collections if isinstance(x, six.string_types)]
   3815 
   3816   def clear_collection(self, name):
   3817     """Clears all values in a collection.
   3818 
   3819     Args:
   3820       name: The key for the collection. The `GraphKeys` class contains many
   3821         standard names for collections.
   3822     """
   3823     self._check_not_finalized()
   3824     with self._lock:
   3825       if name in self._collections:
   3826         del self._collections[name]
   3827 
   3828   @tf_contextlib.contextmanager
   3829   def _original_op(self, op):
   3830     """Python 'with' handler to help annotate ops with their originator.
   3831 
   3832     An op may have an 'original_op' property that indicates the op on which
   3833     it was based. For example a replica op is based on the op that was
   3834     replicated and a gradient op is based on the op that was differentiated.
   3835 
   3836     All ops created in the scope of this 'with' handler will have
   3837     the given 'op' as their original op.
   3838 
   3839     Args:
   3840       op: The Operation that all ops created in this scope will have as their
   3841         original op.
   3842 
   3843     Yields:
   3844       Nothing.
   3845     """
   3846     old_original_op = self._default_original_op
   3847     try:
   3848       self._default_original_op = op
   3849       yield
   3850     finally:
   3851       self._default_original_op = old_original_op
   3852 
   3853   # pylint: disable=g-doc-return-or-yield,line-too-long
   3854   @tf_contextlib.contextmanager
   3855   def name_scope(self, name):
   3856     r"""Returns a context manager that creates hierarchical names for operations.
   3857 
   3858     A graph maintains a stack of name scopes. A `with name_scope(...):`
   3859     statement pushes a new name onto the stack for the lifetime of the context.
   3860 
   3861     The `name` argument will be interpreted as follows:
   3862 
   3863     * A string (not ending with '/') will create a new name scope, in which
   3864       `name` is appended to the prefix of all operations created in the
   3865       context. If `name` has been used before, it will be made unique by
   3866       calling `self.unique_name(name)`.
   3867     * A scope previously captured from a `with g.name_scope(...) as
   3868       scope:` statement will be treated as an "absolute" name scope, which
   3869       makes it possible to re-enter existing scopes.
   3870     * A value of `None` or the empty string will reset the current name scope
   3871       to the top-level (empty) name scope.
   3872 
   3873     For example:
   3874 
   3875     ```python
   3876     with tf.Graph().as_default() as g:
   3877       c = tf.constant(5.0, name="c")
   3878       assert c.op.name == "c"
   3879       c_1 = tf.constant(6.0, name="c")
   3880       assert c_1.op.name == "c_1"
   3881 
   3882       # Creates a scope called "nested"
   3883       with g.name_scope("nested") as scope:
   3884         nested_c = tf.constant(10.0, name="c")
   3885         assert nested_c.op.name == "nested/c"
   3886 
   3887         # Creates a nested scope called "inner".
   3888         with g.name_scope("inner"):
   3889           nested_inner_c = tf.constant(20.0, name="c")
   3890           assert nested_inner_c.op.name == "nested/inner/c"
   3891 
   3892         # Create a nested scope called "inner_1".
   3893         with g.name_scope("inner"):
   3894           nested_inner_1_c = tf.constant(30.0, name="c")
   3895           assert nested_inner_1_c.op.name == "nested/inner_1/c"
   3896 
   3897           # Treats `scope` as an absolute name scope, and
   3898           # switches to the "nested/" scope.
   3899           with g.name_scope(scope):
   3900             nested_d = tf.constant(40.0, name="d")
   3901             assert nested_d.op.name == "nested/d"
   3902 
   3903             with g.name_scope(""):
   3904               e = tf.constant(50.0, name="e")
   3905               assert e.op.name == "e"
   3906     ```
   3907 
   3908     The name of the scope itself can be captured by `with
   3909     g.name_scope(...) as scope:`, which stores the name of the scope
   3910     in the variable `scope`. This value can be used to name an
   3911     operation that represents the overall result of executing the ops
   3912     in a scope. For example:
   3913 
   3914     ```python
   3915     inputs = tf.constant(...)
   3916     with g.name_scope('my_layer') as scope:
   3917       weights = tf.Variable(..., name="weights")
   3918       biases = tf.Variable(..., name="biases")
   3919       affine = tf.matmul(inputs, weights) + biases
   3920       output = tf.nn.relu(affine, name=scope)
   3921     ```
   3922 
   3923     NOTE: This constructor validates the given `name`. Valid scope
   3924     names match one of the following regular expressions:
   3925 
   3926         [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
   3927         [A-Za-z0-9_.\\-/]* (for other scopes)
   3928 
   3929     Args:
   3930       name: A name for the scope.
   3931 
   3932     Returns:
   3933       A context manager that installs `name` as a new name scope.
   3934 
   3935     Raises:
   3936       ValueError: If `name` is not a valid scope name, according to the rules
   3937         above.
   3938     """
   3939     if name:
   3940       if isinstance(name, compat.bytes_or_text_types):
   3941         name = compat.as_str(name)
   3942 
   3943       if self._name_stack:
   3944         # Scopes created in a nested scope may have initial characters
   3945         # that are illegal as the initial character of an op name
   3946         # (viz. '-', '\', '/', and '_').
   3947         if not _VALID_SCOPE_NAME_REGEX.match(name):
   3948           raise ValueError("'%s' is not a valid scope name" % name)
   3949       else:
   3950         # Scopes created in the root must match the more restrictive
   3951         # op name regex, which constrains the initial character.
   3952         if not _VALID_OP_NAME_REGEX.match(name):
   3953           raise ValueError("'%s' is not a valid scope name" % name)
   3954     try:
   3955       old_stack = self._name_stack
   3956       if not name:  # Both for name=None and name="" we re-set to empty scope.
   3957         new_stack = None
   3958       elif name[-1] == "/":
   3959         new_stack = _name_from_scope_name(name)
   3960       else:
   3961         new_stack = self.unique_name(name)
   3962       self._name_stack = new_stack
   3963       yield "" if new_stack is None else new_stack + "/"
   3964     finally:
   3965       self._name_stack = old_stack
   3966 
   3967   # pylint: enable=g-doc-return-or-yield,line-too-long
   3968 
   3969   def unique_name(self, name, mark_as_used=True):
   3970     """Return a unique operation name for `name`.
   3971 
   3972     Note: You rarely need to call `unique_name()` directly.  Most of
   3973     the time you just need to create `with g.name_scope()` blocks to
   3974     generate structured names.
   3975 
   3976     `unique_name` is used to generate structured names, separated by
   3977     `"/"`, to help identify operations when debugging a graph.
   3978     Operation names are displayed in error messages reported by the
   3979     TensorFlow runtime, and in various visualization tools such as
   3980     TensorBoard.
   3981 
   3982     If `mark_as_used` is set to `True`, which is the default, a new
   3983     unique name is created and marked as in use. If it's set to `False`,
   3984     the unique name is returned without actually being marked as used.
   3985     This is useful when the caller simply wants to know what the name
   3986     to be created will be.
   3987 
   3988     Args:
   3989       name: The name for an operation.
   3990       mark_as_used: Whether to mark this name as being used.
   3991 
   3992     Returns:
   3993       A string to be passed to `create_op()` that will be used
   3994       to name the operation being created.
   3995     """
   3996     if self._name_stack:
   3997       name = self._name_stack + "/" + name
   3998     i = self._names_in_use.get(name, 0)
   3999     # Increment the number for "name".
   4000     if mark_as_used:
   4001       self._names_in_use[name] = i + 1
   4002     if i > 0:
   4003       base_name = name
   4004       # Make sure the composed name is not already used.
   4005       while name in self._names_in_use:
   4006         name = "%s_%d" % (base_name, i)
   4007         i += 1
   4008       # Mark the composed name as used in case someone wants
   4009       # to call unique_name("name_1").
   4010       if mark_as_used:
   4011         self._names_in_use[name] = 1
   4012     return name
   4013 
   4014   def get_name_scope(self):
   4015     """Returns the current name scope.
   4016 
   4017     For example:
   4018 
   4019     ```python
   4020     with tf.name_scope('scope1'):
   4021       with tf.name_scope('scope2'):
   4022         print(tf.get_default_graph().get_name_scope())
   4023     ```
   4024     would print the string `scope1/scope2`.
   4025 
   4026     Returns:
   4027       A string representing the current name scope.
   4028     """
   4029     return self._name_stack
   4030 
   4031   @tf_contextlib.contextmanager
   4032   def colocate_with(self, op, ignore_existing=False):
   4033     """Returns a context manager that specifies an op to colocate with.
   4034 
   4035     Note: this function is not for public use, only for internal libraries.
   4036 
   4037     For example:
   4038 
   4039     ```python
   4040     a = tf.Variable([1.0])
   4041     with g.colocate_with(a):
   4042       b = tf.constant(1.0)
   4043       c = tf.add(a, b)
   4044     ```
   4045 
   4046     `b` and `c` will always be colocated with `a`, no matter where `a`
   4047     is eventually placed.
   4048 
   4049     **NOTE** Using a colocation scope resets any existing device constraints.
   4050 
   4051     If `op` is `None` then `ignore_existing` must be `True` and the new
   4052     scope resets all colocation and device constraints.
   4053 
   4054     Args:
   4055       op: The op to colocate all created ops with, or `None`.
   4056       ignore_existing: If true, only applies colocation of this op within
   4057         the context, rather than applying all colocation properties
   4058         on the stack.  If `op` is `None`, this value must be `True`.
   4059 
   4060     Raises:
   4061       ValueError: if op is None but ignore_existing is False.
   4062 
   4063     Yields:
   4064       A context manager that specifies the op with which to colocate
   4065       newly created ops.
   4066 
   4067     """
   4068     if op is None and not ignore_existing:
   4069       raise ValueError("Trying to reset colocation (op is None) but "
   4070                        "ignore_existing is not True")
   4071 
   4072     if op is not None and not isinstance(op, Operation):
   4073       # We always want to colocate with the reference op.
   4074       op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op
   4075 
   4076     # By default, colocate_with resets the device function stack,
   4077     # since colocate_with is typically used in specific internal
   4078     # library functions where colocation is intended to be "stronger"
   4079     # than device functions.
   4080     #
   4081     # In the future, a caller may specify that device_functions win
   4082     # over colocation, in which case we can add support.
   4083     device_fn_tmp = self._device_function_stack
   4084     self._device_function_stack = []
   4085 
   4086     if ignore_existing:
   4087       current_stack = self._colocation_stack
   4088       self._colocation_stack = []
   4089 
   4090     if op is not None:
   4091       self._colocation_stack.append(op)
   4092 
   4093     try:
   4094       yield
   4095     finally:
   4096       # Restore device function stack
   4097       self._device_function_stack = device_fn_tmp
   4098       if op is not None:
   4099         self._colocation_stack.pop()
   4100 
   4101       # Reset the colocation stack if requested.
   4102       if ignore_existing:
   4103         self._colocation_stack = current_stack
   4104 
   4105   @tf_contextlib.contextmanager
   4106   def device(self, device_name_or_function):
   4107     # pylint: disable=line-too-long
   4108     """Returns a context manager that specifies the default device to use.
   4109 
   4110     The `device_name_or_function` argument may either be a device name
   4111     string, a device function, or None:
   4112 
   4113     * If it is a device name string, all operations constructed in
   4114       this context will be assigned to the device with that name, unless
   4115       overridden by a nested `device()` context.
   4116     * If it is a function, it will be treated as a function from
   4117       Operation objects to device name strings, and invoked each time
   4118       a new Operation is created. The Operation will be assigned to
   4119       the device with the returned name.
   4120     * If it is None, all `device()` invocations from the enclosing context
   4121       will be ignored.
   4122 
   4123     For information about the valid syntax of device name strings, see
   4124     the documentation in
   4125     [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h).
   4126 
   4127     For example:
   4128 
   4129     ```python
   4130     with g.device('/device:GPU:0'):
   4131       # All operations constructed in this context will be placed
   4132       # on GPU 0.
   4133       with g.device(None):
   4134         # All operations constructed in this context will have no
   4135         # assigned device.
   4136 
   4137     # Defines a function from `Operation` to device string.
   4138     def matmul_on_gpu(n):
   4139       if n.type == "MatMul":
   4140         return "/device:GPU:0"
   4141       else:
   4142         return "/cpu:0"
   4143 
   4144     with g.device(matmul_on_gpu):
   4145       # All operations of type "MatMul" constructed in this context
   4146       # will be placed on GPU 0; all other operations will be placed
   4147       # on CPU 0.
   4148     ```
   4149 
   4150     **N.B.** The device scope may be overridden by op wrappers or
   4151     other library code. For example, a variable assignment op
   4152     `v.assign()` must be colocated with the `tf.Variable` `v`, and
   4153     incompatible device scopes will be ignored.
   4154 
   4155     Args:
   4156       device_name_or_function: The device name or function to use in
   4157         the context.
   4158 
   4159     Yields:
   4160       A context manager that specifies the default device to use for newly
   4161       created ops.
   4162 
   4163     """
   4164     # pylint: enable=line-too-long
   4165     if (device_name_or_function is not None and
   4166         not callable(device_name_or_function)):
   4167       device_function = pydev.merge_device(device_name_or_function)
   4168     else:
   4169       device_function = device_name_or_function
   4170 
   4171     try:
   4172       self._device_function_stack.append(device_function)
   4173       yield
   4174     finally:
   4175       self._device_function_stack.pop()
   4176 
   4177   def _apply_device_functions(self, op):
   4178     """Applies the current device function stack to the given operation."""
   4179     # Apply any device functions in reverse order, so that the most recently
   4180     # pushed function has the first chance to apply a device to the op.
   4181     # We apply here because the result can depend on the Operation's
   4182     # signature, which is computed in the Operation constructor.
   4183     for device_function in reversed(self._device_function_stack):
   4184       if device_function is None:
   4185         break
   4186       op._set_device(device_function(op))  # pylint: disable=protected-access
   4187 
   4188   # pylint: disable=g-doc-return-or-yield
   4189   @tf_contextlib.contextmanager
   4190   def container(self, container_name):
   4191     """Returns a context manager that specifies the resource container to use.
   4192 
   4193     Stateful operations, such as variables and queues, can maintain their
   4194     states on devices so that they can be shared by multiple processes.
   4195     A resource container is a string name under which these stateful
   4196     operations are tracked. These resources can be released or cleared
   4197     with `tf.Session.reset()`.
   4198 
   4199     For example:
   4200 
   4201     ```python
   4202     with g.container('experiment0'):
   4203       # All stateful Operations constructed in this context will be placed
   4204       # in resource container "experiment0".
   4205       v1 = tf.Variable([1.0])
   4206       v2 = tf.Variable([2.0])
   4207       with g.container("experiment1"):
   4208         # All stateful Operations constructed in this context will be
   4209         # placed in resource container "experiment1".
   4210         v3 = tf.Variable([3.0])
   4211         q1 = tf.FIFOQueue(10, tf.float32)
   4212       # All stateful Operations constructed in this context will be
   4213       # be created in the "experiment0".
   4214       v4 = tf.Variable([4.0])
   4215       q1 = tf.FIFOQueue(20, tf.float32)
   4216       with g.container(""):
   4217         # All stateful Operations constructed in this context will be
   4218         # be placed in the default resource container.
   4219         v5 = tf.Variable([5.0])
   4220         q3 = tf.FIFOQueue(30, tf.float32)
   4221 
   4222     # Resets container "experiment0", after which the state of v1, v2, v4, q1
   4223     # will become undefined (such as uninitialized).
   4224     tf.Session.reset(target, ["experiment0"])
   4225     ```
   4226 
   4227     Args:
   4228       container_name: container name string.
   4229 
   4230     Returns:
   4231       A context manager for defining resource containers for stateful ops,
   4232         yields the container name.
   4233     """
   4234     original_container = self._container
   4235     try:
   4236       self._container = container_name
   4237       yield self._container
   4238     finally:
   4239       self._container = original_container
   4240 
   4241   # pylint: enable=g-doc-return-or-yield
   4242 
   4243   class _ControlDependenciesController(object):
   4244     """Context manager for `control_dependencies()`."""
   4245 
   4246     def __init__(self, graph, control_inputs):
   4247       """Create a new `_ControlDependenciesController`.
   4248 
   4249       A `_ControlDependenciesController` is the context manager for
   4250       `with tf.control_dependencies()` blocks.  These normally nest,
   4251       as described in the documentation for `control_dependencies()`.
   4252 
   4253       The `control_inputs` argument list control dependencies that must be
   4254       added to the current set of control dependencies.  Because of
   4255       uniquification the set can be empty even if the caller passed a list of
   4256       ops.  The special value `None` indicates that we want to start a new
   4257       empty set of control dependencies instead of extending the current set.
   4258 
   4259       In that case we also clear the current control flow context, which is an
   4260       additional mechanism to add control dependencies.
   4261 
   4262       Args:
   4263         graph: The graph that this controller is managing.
   4264         control_inputs: List of ops to use as control inputs in addition
   4265           to the current control dependencies.  None to indicate that
   4266           the dependencies should be cleared.
   4267       """
   4268       self._graph = graph
   4269       if control_inputs is None:
   4270         self._control_inputs_val = []
   4271         self._new_stack = True
   4272       else:
   4273         self._control_inputs_val = control_inputs
   4274         self._new_stack = False
   4275       self._seen_nodes = set()
   4276       self._old_stack = None
   4277       self._old_control_flow_context = None
   4278 
   4279 # pylint: disable=protected-access
   4280 
   4281     def __enter__(self):
   4282       if self._new_stack:
   4283         # Clear the control_dependencies graph.
   4284         self._old_stack = self._graph._control_dependencies_stack
   4285         self._graph._control_dependencies_stack = []
   4286         # Clear the control_flow_context too.
   4287         self._old_control_flow_context = self._graph._get_control_flow_context()
   4288         self._graph._set_control_flow_context(None)
   4289       self._graph._push_control_dependencies_controller(self)
   4290 
   4291     def __exit__(self, unused_type, unused_value, unused_traceback):
   4292       self._graph._pop_control_dependencies_controller(self)
   4293       if self._new_stack:
   4294         self._graph._control_dependencies_stack = self._old_stack
   4295         self._graph._set_control_flow_context(self._old_control_flow_context)
   4296 
   4297 # pylint: enable=protected-access
   4298 
   4299     @property
   4300     def control_inputs(self):
   4301       return self._control_inputs_val
   4302 
   4303     def add_op(self, op):
   4304       self._seen_nodes.add(op)
   4305 
   4306     def op_in_group(self, op):
   4307       return op in self._seen_nodes
   4308 
   4309   def _push_control_dependencies_controller(self, controller):
   4310     self._control_dependencies_stack.append(controller)
   4311 
   4312   def _pop_control_dependencies_controller(self, controller):
   4313     assert self._control_dependencies_stack[-1] is controller
   4314     self._control_dependencies_stack.pop()
   4315 
   4316   def _current_control_dependencies(self):
   4317     ret = set()
   4318     for controller in self._control_dependencies_stack:
   4319       for op in controller.control_inputs:
   4320         ret.add(op)
   4321     return ret
   4322 
   4323   def _control_dependencies_for_inputs(self, input_ops):
   4324     """For an op that takes `input_ops` as inputs, compute control inputs.
   4325 
   4326     The returned control dependencies should yield an execution that
   4327     is equivalent to adding all control inputs in
   4328     self._control_dependencies_stack to a newly created op. However,
   4329     this function attempts to prune the returned control dependencies
   4330     by observing that nodes created within the same `with
   4331     control_dependencies(...):` block may have data dependencies that make
   4332     the explicit approach redundant.
   4333 
   4334     Args:
   4335       input_ops: The data input ops for an op to be created.
   4336 
   4337     Returns:
   4338       A list of control inputs for the op to be created.
   4339     """
   4340     ret = []
   4341     for controller in self._control_dependencies_stack:
   4342       # If any of the input_ops already depends on the inputs from controller,
   4343       # we say that the new op is dominated (by that input), and we therefore
   4344       # do not need to add control dependencies for this controller's inputs.
   4345       dominated = False
   4346       for op in input_ops:
   4347         if controller.op_in_group(op):
   4348           dominated = True
   4349           break
   4350       if not dominated:
   4351         # Don't add a control input if we already have a data dependency on i.
   4352         # NOTE(mrry): We do not currently track transitive data dependencies,
   4353         #   so we may add redundant control inputs.
   4354         ret.extend([c for c in controller.control_inputs if c not in input_ops])
   4355     return ret
   4356 
   4357   def _record_op_seen_by_control_dependencies(self, op):
   4358     """Record that the given op depends on all registered control dependencies.
   4359 
   4360     Args:
   4361       op: An Operation.
   4362     """
   4363     for controller in self._control_dependencies_stack:
   4364       controller.add_op(op)
   4365 
   4366   def control_dependencies(self, control_inputs):
   4367     """Returns a context manager that specifies control dependencies.
   4368 
   4369     Use with the `with` keyword to specify that all operations constructed
   4370     within the context should have control dependencies on
   4371     `control_inputs`. For example:
   4372 
   4373     ```python
   4374     with g.control_dependencies([a, b, c]):
   4375       # `d` and `e` will only run after `a`, `b`, and `c` have executed.
   4376       d = ...
   4377       e = ...
   4378     ```
   4379 
   4380     Multiple calls to `control_dependencies()` can be nested, and in
   4381     that case a new `Operation` will have control dependencies on the union
   4382     of `control_inputs` from all active contexts.
   4383 
   4384     ```python
   4385     with g.control_dependencies([a, b]):
   4386       # Ops constructed here run after `a` and `b`.
   4387       with g.control_dependencies([c, d]):
   4388         # Ops constructed here run after `a`, `b`, `c`, and `d`.
   4389     ```
   4390 
   4391     You can pass None to clear the control dependencies:
   4392 
   4393     ```python
   4394     with g.control_dependencies([a, b]):
   4395       # Ops constructed here run after `a` and `b`.
   4396       with g.control_dependencies(None):
   4397         # Ops constructed here run normally, not waiting for either `a` or `b`.
   4398         with g.control_dependencies([c, d]):
   4399           # Ops constructed here run after `c` and `d`, also not waiting
   4400           # for either `a` or `b`.
   4401     ```
   4402 
   4403     *N.B.* The control dependencies context applies *only* to ops that
   4404     are constructed within the context. Merely using an op or tensor
   4405     in the context does not add a control dependency. The following
   4406     example illustrates this point:
   4407 
   4408     ```python
   4409     # WRONG
   4410     def my_func(pred, tensor):
   4411       t = tf.matmul(tensor, tensor)
   4412       with tf.control_dependencies([pred]):
   4413         # The matmul op is created outside the context, so no control
   4414         # dependency will be added.
   4415         return t
   4416 
   4417     # RIGHT
   4418     def my_func(pred, tensor):
   4419       with tf.control_dependencies([pred]):
   4420         # The matmul op is created in the context, so a control dependency
   4421         # will be added.
   4422         return tf.matmul(tensor, tensor)
   4423     ```
   4424 
   4425     Args:
   4426       control_inputs: A list of `Operation` or `Tensor` objects which
   4427         must be executed or computed before running the operations
   4428         defined in the context.  Can also be `None` to clear the control
   4429         dependencies.
   4430 
   4431     Returns:
   4432      A context manager that specifies control dependencies for all
   4433      operations constructed within the context.
   4434 
   4435     Raises:
   4436       TypeError: If `control_inputs` is not a list of `Operation` or
   4437         `Tensor` objects.
   4438     """
   4439     if control_inputs is None:
   4440       return self._ControlDependenciesController(self, None)
   4441     # First convert the inputs to ops, and deduplicate them.
   4442     # NOTE(mrry): Other than deduplication, we do not currently track direct
   4443     #   or indirect dependencies between control_inputs, which may result in
   4444     #   redundant control inputs.
   4445     control_ops = []
   4446     current = self._current_control_dependencies()
   4447     for c in control_inputs:
   4448       if isinstance(c, IndexedSlices):
   4449         c = c.op
   4450       c = self.as_graph_element(c)
   4451       if isinstance(c, Tensor):
   4452         c = c.op
   4453       elif not isinstance(c, Operation):
   4454         raise TypeError("Control input must be Operation or Tensor: %s" % c)
   4455       if c not in current:
   4456         control_ops.append(c)
   4457         current.add(c)
   4458     return self._ControlDependenciesController(self, control_ops)
   4459 
   4460   # pylint: disable=g-doc-return-or-yield
   4461   @tf_contextlib.contextmanager
   4462   def _attr_scope(self, attr_map):
   4463     """EXPERIMENTAL: A context manager for setting attributes on operators.
   4464 
   4465     This context manager can be used to add additional
   4466     attributes to operators within the scope of the context.
   4467 
   4468     For example:
   4469 
   4470        with ops.Graph().as_default() as g:
   4471          f_1 = Foo()  # No extra attributes
   4472          with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
   4473            f_2 = Foo()  # Additional attribute _a=False
   4474            with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
   4475              f_3 = Foo()  # Additional attribute _a=False
   4476              with g._attr_scope({"_a": None}):
   4477                f_4 = Foo()  # No additional attributes.
   4478 
   4479     Args:
   4480       attr_map: A dictionary mapping attr name strings to
   4481         AttrValue protocol buffers or None.
   4482 
   4483     Returns:
   4484       A context manager that sets the kernel label to be used for one or more
   4485       ops created in that context.
   4486 
   4487     Raises:
   4488       TypeError: If attr_map is not a dictionary mapping
   4489         strings to AttrValue protobufs.
   4490     """
   4491     if not isinstance(attr_map, dict):
   4492       raise TypeError("attr_map must be a dictionary mapping "
   4493                       "strings to AttrValue protocol buffers")
   4494     # The saved_attrs dictionary stores any currently-set labels that
   4495     # will be overridden by this context manager.
   4496     saved_attrs = {}
   4497     # Install the given attribute
   4498     for name, attr in attr_map.items():
   4499       if not (isinstance(name, six.string_types) and
   4500               (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
   4501                callable(attr))):
   4502         raise TypeError("attr_map must be a dictionary mapping "
   4503                         "strings to AttrValue protocol buffers or "
   4504                         "callables that emit AttrValue protocol buffers")
   4505       try:
   4506         saved_attrs[name] = self._attr_scope_map[name]
   4507       except KeyError:
   4508         pass
   4509       if attr is None:
   4510         del self._attr_scope_map[name]
   4511       else:
   4512         self._attr_scope_map[name] = attr
   4513     try:
   4514       yield  # The code within the context runs here.
   4515     finally:
   4516       # Remove the attributes set for this context, and restore any saved
   4517       # attributes.
   4518       for name, attr in attr_map.items():
   4519         try:
   4520           self._attr_scope_map[name] = saved_attrs[name]
   4521         except KeyError:
   4522           del self._attr_scope_map[name]
   4523 
   4524   # pylint: enable=g-doc-return-or-yield
   4525 
   4526   # pylint: disable=g-doc-return-or-yield
   4527   @tf_contextlib.contextmanager
   4528   def _kernel_label_map(self, op_to_kernel_label_map):
   4529     """EXPERIMENTAL: A context manager for setting kernel labels.
   4530 
   4531     This context manager can be used to select particular
   4532     implementations of kernels within the scope of the context.
   4533 
   4534     For example:
   4535 
   4536         with ops.Graph().as_default() as g:
   4537           f_1 = Foo()  # Uses the default registered kernel for the Foo op.
   4538           with g.kernel_label_map({"Foo": "v_2"}):
   4539             f_2 = Foo()  # Uses the registered kernel with label "v_2"
   4540                          # for the Foo op.
   4541             with g.kernel_label_map({"Foo": "v_3"}):
   4542               f_3 = Foo()  # Uses the registered kernel with label "v_3"
   4543                            # for the Foo op.
   4544               with g.kernel_label_map({"Foo": ""}):
   4545                 f_4 = Foo()  # Uses the default registered kernel
   4546                              # for the Foo op.
   4547 
   4548     Args:
   4549       op_to_kernel_label_map: A dictionary mapping op type strings to
   4550         kernel label strings.
   4551 
   4552     Returns:
   4553       A context manager that sets the kernel label to be used for one or more
   4554       ops created in that context.
   4555 
   4556     Raises:
   4557       TypeError: If op_to_kernel_label_map is not a dictionary mapping
   4558         strings to strings.
   4559     """
   4560     if not isinstance(op_to_kernel_label_map, dict):
   4561       raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
   4562                       "strings to strings")
   4563     # The saved_labels dictionary stores any currently-set labels that
   4564     # will be overridden by this context manager.
   4565     saved_labels = {}
   4566     # Install the given label
   4567     for op_type, label in op_to_kernel_label_map.items():
   4568       if not (isinstance(op_type, six.string_types) and
   4569               isinstance(label, six.string_types)):
   4570         raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
   4571                         "strings to strings")
   4572       try:
   4573         saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
   4574       except KeyError:
   4575         pass
   4576       self._op_to_kernel_label_map[op_type] = label
   4577     try:
   4578       yield  # The code within the context runs here.
   4579     finally:
   4580       # Remove the labels set for this context, and restore any saved labels.
   4581       for op_type, label in op_to_kernel_label_map.items():
   4582         try:
   4583           self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
   4584         except KeyError:
   4585           del self._op_to_kernel_label_map[op_type]
   4586 
   4587   # pylint: enable=g-doc-return-or-yield
   4588 
   4589   # pylint: disable=g-doc-return-or-yield
   4590   @tf_contextlib.contextmanager
   4591   def gradient_override_map(self, op_type_map):
   4592     """EXPERIMENTAL: A context manager for overriding gradient functions.
   4593 
   4594     This context manager can be used to override the gradient function
   4595     that will be used for ops within the scope of the context.
   4596 
   4597     For example:
   4598 
   4599     ```python
   4600     @tf.RegisterGradient("CustomSquare")
   4601     def _custom_square_grad(op, grad):
   4602       # ...
   4603 
   4604     with tf.Graph().as_default() as g:
   4605       c = tf.constant(5.0)
   4606       s_1 = tf.square(c)  # Uses the default gradient for tf.square.
   4607       with g.gradient_override_map({"Square": "CustomSquare"}):
   4608         s_2 = tf.square(s_2)  # Uses _custom_square_grad to compute the
   4609                               # gradient of s_2.
   4610     ```
   4611 
   4612     Args:
   4613       op_type_map: A dictionary mapping op type strings to alternative op
   4614         type strings.
   4615 
   4616     Returns:
   4617       A context manager that sets the alternative op type to be used for one
   4618       or more ops created in that context.
   4619 
   4620     Raises:
   4621       TypeError: If `op_type_map` is not a dictionary mapping strings to
   4622         strings.
   4623     """
   4624     if not isinstance(op_type_map, dict):
   4625       raise TypeError("op_type_map must be a dictionary mapping "
   4626                       "strings to strings")
   4627     # The saved_mappings dictionary stores any currently-set mappings that
   4628     # will be overridden by this context manager.
   4629     saved_mappings = {}
   4630     # Install the given label
   4631     for op_type, mapped_op_type in op_type_map.items():
   4632       if not (isinstance(op_type, six.string_types) and
   4633               isinstance(mapped_op_type, six.string_types)):
   4634         raise TypeError("op_type_map must be a dictionary mapping "
   4635                         "strings to strings")
   4636       try:
   4637         saved_mappings[op_type] = self._gradient_override_map[op_type]
   4638       except KeyError:
   4639         pass
   4640       self._gradient_override_map[op_type] = mapped_op_type
   4641     try:
   4642       yield  # The code within the context runs here.
   4643     finally:
   4644       # Remove the labels set for this context, and restore any saved labels.
   4645       for op_type, mapped_op_type in op_type_map.items():
   4646         try:
   4647           self._gradient_override_map[op_type] = saved_mappings[op_type]
   4648         except KeyError:
   4649           del self._gradient_override_map[op_type]
   4650 
   4651   # pylint: enable=g-doc-return-or-yield
   4652 
   4653   def prevent_feeding(self, tensor):
   4654     """Marks the given `tensor` as unfeedable in this graph."""
   4655     self._unfeedable_tensors.add(tensor)
   4656 
   4657   def is_feedable(self, tensor):
   4658     """Returns `True` if and only if `tensor` is feedable."""
   4659     return tensor not in self._unfeedable_tensors
   4660 
   4661   def prevent_fetching(self, op):
   4662     """Marks the given `op` as unfetchable in this graph."""
   4663     self._unfetchable_ops.add(op)
   4664 
   4665   def is_fetchable(self, tensor_or_op):
   4666     """Returns `True` if and only if `tensor_or_op` is fetchable."""
   4667     if isinstance(tensor_or_op, Tensor):
   4668       return tensor_or_op.op not in self._unfetchable_ops
   4669     else:
   4670       return tensor_or_op not in self._unfetchable_ops
   4671 
   4672 
   4673 # TODO(agarwal): currently device directives in an outer eager scope will not
   4674 # apply to inner graph mode code. Fix that.
   4675 
   4676 
   4677 @tf_export("device")
   4678 def device(device_name_or_function):
   4679   """Wrapper for `Graph.device()` using the default graph.
   4680 
   4681   See
   4682   @{tf.Graph.device}
   4683   for more details.
   4684 
   4685   Args:
   4686     device_name_or_function: The device name or function to use in
   4687       the context.
   4688 
   4689   Returns:
   4690     A context manager that specifies the default device to use for newly
   4691     created ops.
   4692 
   4693   Raises:
   4694     RuntimeError: If eager execution is enabled and a function is passed in.
   4695   """
   4696   if context.in_graph_mode():
   4697     return get_default_graph().device(device_name_or_function)
   4698   else:
   4699     # TODO(agarwal): support device functions in EAGER mode.
   4700     if callable(device_name_or_function):
   4701       raise RuntimeError(
   4702           "tf.device does not support functions when eager execution "
   4703           "is enabled.")
   4704     return context.device(device_name_or_function)
   4705 
   4706 
   4707 @tf_export("container")
   4708 def container(container_name):
   4709   """Wrapper for `Graph.container()` using the default graph.
   4710 
   4711   Args:
   4712     container_name: The container string to use in the context.
   4713 
   4714   Returns:
   4715     A context manager that specifies the default container to use for newly
   4716     created stateful ops.
   4717   """
   4718   return get_default_graph().container(container_name)
   4719 
   4720 
   4721 @tf_export("colocate_with")
   4722 def colocate_with(op, ignore_existing=False):
   4723   if context.in_graph_mode():
   4724     return get_default_graph().colocate_with(op, ignore_existing)
   4725   else:
   4726     if op is not None:
   4727       return device(op.device)
   4728     else:
   4729       return _NullContextmanager()
   4730 
   4731 
   4732 @tf_export("control_dependencies")
   4733 def control_dependencies(control_inputs):
   4734   """Wrapper for `Graph.control_dependencies()` using the default graph.
   4735 
   4736   See @{tf.Graph.control_dependencies}
   4737   for more details.
   4738 
   4739   Args:
   4740     control_inputs: A list of `Operation` or `Tensor` objects which
   4741       must be executed or computed before running the operations
   4742       defined in the context.  Can also be `None` to clear the control
   4743       dependencies.
   4744 
   4745   Returns:
   4746    A context manager that specifies control dependencies for all
   4747    operations constructed within the context.
   4748   """
   4749   if context.in_graph_mode():
   4750     return get_default_graph().control_dependencies(control_inputs)
   4751   else:
   4752     return _NullContextmanager()
   4753 
   4754 
   4755 class _DefaultStack(threading.local):
   4756   """A thread-local stack of objects for providing implicit defaults."""
   4757 
   4758   def __init__(self):
   4759     super(_DefaultStack, self).__init__()
   4760     self._enforce_nesting = True
   4761     self.stack = []
   4762 
   4763   def get_default(self):
   4764     return self.stack[-1] if len(self.stack) >= 1 else None
   4765 
   4766   def reset(self):
   4767     self.stack = []
   4768 
   4769   def is_cleared(self):
   4770     return not self.stack
   4771 
   4772   @property
   4773   def enforce_nesting(self):
   4774     return self._enforce_nesting
   4775 
   4776   @enforce_nesting.setter
   4777   def enforce_nesting(self, value):
   4778     self._enforce_nesting = value
   4779 
   4780   @tf_contextlib.contextmanager
   4781   def get_controller(self, default):
   4782     """A context manager for manipulating a default stack."""
   4783     try:
   4784       self.stack.append(default)
   4785       yield default
   4786     finally:
   4787       # stack may be empty if reset() was called
   4788       if self.stack:
   4789         if self._enforce_nesting:
   4790           if self.stack[-1] is not default:
   4791             raise AssertionError(
   4792                 "Nesting violated for default stack of %s objects" %
   4793                 type(default))
   4794           self.stack.pop()
   4795         else:
   4796           self.stack.remove(default)
   4797 
   4798 
# Thread-local stack of sessions pushed by `default_session()` and read by
# `get_default_session()`.
_default_session_stack = _DefaultStack()  # pylint: disable=protected-access
   4800 
   4801 
   4802 def default_session(session):
   4803   """Python "with" handler for defining a default session.
   4804 
   4805   This function provides a means of registering a session for handling
   4806   Tensor.eval() and Operation.run() calls. It is primarily intended for use
   4807   by session.Session, but can be used with any object that implements
   4808   the Session.run() interface.
   4809 
   4810   Use with the "with" keyword to specify that Tensor.eval() and Operation.run()
   4811   invocations within the scope of a block should be executed by a particular
   4812   session.
   4813 
   4814   The default session applies to the current thread only, so it is always
   4815   possible to inspect the call stack and determine the scope of a default
   4816   session. If you create a new thread, and wish to use the default session
   4817   in that thread, you must explicitly add a "with ops.default_session(sess):"
   4818   block in that thread's function.
   4819 
   4820   Example:
   4821     The following code examples are equivalent:
   4822 
   4823     # 1. Using the Session object directly:
   4824     sess = ...
   4825     c = tf.constant(5.0)
   4826     sess.run(c)
   4827 
   4828     # 2. Using default_session():
   4829     sess = ...
   4830     with ops.default_session(sess):
   4831       c = tf.constant(5.0)
   4832       result = c.eval()
   4833 
   4834     # 3. Overriding default_session():
   4835     sess = ...
   4836     with ops.default_session(sess):
   4837       c = tf.constant(5.0)
   4838       with ops.default_session(...):
   4839         c.eval(session=sess)
   4840 
   4841   Args:
   4842     session: The session to be installed as the default session.
   4843 
   4844   Returns:
   4845     A context manager for the default session.
   4846   """
   4847   return _default_session_stack.get_controller(session)
   4848 
   4849 
   4850 @tf_export("get_default_session")
   4851 def get_default_session():
   4852   """Returns the default session for the current thread.
   4853 
   4854   The returned `Session` will be the innermost session on which a
   4855   `Session` or `Session.as_default()` context has been entered.
   4856 
   4857   NOTE: The default session is a property of the current thread. If you
   4858   create a new thread, and wish to use the default session in that
   4859   thread, you must explicitly add a `with sess.as_default():` in that
   4860   thread's function.
   4861 
   4862   Returns:
   4863     The default `Session` being used in the current thread.
   4864   """
   4865   return _default_session_stack.get_default()
   4866 
   4867 
   4868 def _eval_using_default_session(tensors, feed_dict, graph, session=None):
   4869   """Uses the default session to evaluate one or more tensors.
   4870 
   4871   Args:
   4872     tensors: A single Tensor, or a list of Tensor objects.
   4873     feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
   4874       numpy ndarrays, TensorProtos, or strings.
   4875     graph: The graph in which the tensors are defined.
   4876     session: (Optional) A different session to use to evaluate "tensors".
   4877 
   4878   Returns:
   4879     Either a single numpy ndarray if "tensors" is a single tensor; or a list
   4880     of numpy ndarrays that each correspond to the respective element in
   4881     "tensors".
   4882 
   4883   Raises:
   4884     ValueError: If no default session is available; the default session
   4885       does not have "graph" as its graph; or if "session" is specified,
   4886       and it does not have "graph" as its graph.
   4887   """
   4888   if session is None:
   4889     session = get_default_session()
   4890     if session is None:
   4891       raise ValueError("Cannot evaluate tensor using `eval()`: No default "
   4892                        "session is registered. Use `with "
   4893                        "sess.as_default()` or pass an explicit session to "
   4894                        "`eval(session=sess)`")
   4895     if session.graph is not graph:
   4896       raise ValueError("Cannot use the default session to evaluate tensor: "
   4897                        "the tensor's graph is different from the session's "
   4898                        "graph. Pass an explicit session to "
   4899                        "`eval(session=sess)`.")
   4900   else:
   4901     if session.graph is not graph:
   4902       raise ValueError("Cannot use the given session to evaluate tensor: "
   4903                        "the tensor's graph is different from the session's "
   4904                        "graph.")
   4905   return session.run(tensors, feed_dict)
   4906 
   4907 
   4908 def _run_using_default_session(operation, feed_dict, graph, session=None):
   4909   """Uses the default session to run "operation".
   4910 
   4911   Args:
   4912     operation: The Operation to be run.
   4913     feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
   4914       numpy ndarrays, TensorProtos, or strings.
   4915     graph: The graph in which "operation" is defined.
   4916     session: (Optional) A different session to use to run "operation".
   4917 
   4918   Raises:
   4919     ValueError: If no default session is available; the default session
   4920       does not have "graph" as its graph; or if "session" is specified,
   4921       and it does not have "graph" as its graph.
   4922   """
   4923   if session is None:
   4924     session = get_default_session()
   4925     if session is None:
   4926       raise ValueError("Cannot execute operation using `run()`: No default "
   4927                        "session is registered. Use `with "
   4928                        "sess.as_default():` or pass an explicit session to "
   4929                        "`run(session=sess)`")
   4930     if session.graph is not graph:
   4931       raise ValueError("Cannot use the default session to execute operation: "
   4932                        "the operation's graph is different from the "
   4933                        "session's graph. Pass an explicit session to "
   4934                        "run(session=sess).")
   4935   else:
   4936     if session.graph is not graph:
   4937       raise ValueError("Cannot use the given session to execute operation: "
   4938                        "the operation's graph is different from the session's "
   4939                        "graph.")
   4940   session.run(operation, feed_dict)
   4941 
   4942 
   4943 class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
   4944   """A thread-local stack of objects for providing an implicit default graph."""
   4945 
   4946   def __init__(self):
   4947     super(_DefaultGraphStack, self).__init__()
   4948     self._global_default_graph = None
   4949 
   4950   def get_default(self):
   4951     """Override that returns a global default if the stack is empty."""
   4952     ret = super(_DefaultGraphStack, self).get_default()
   4953     if ret is None:
   4954       ret = self._GetGlobalDefaultGraph()
   4955     return ret
   4956 
   4957   def _GetGlobalDefaultGraph(self):
   4958     if self._global_default_graph is None:
   4959       # TODO(mrry): Perhaps log that the default graph is being used, or set
   4960       #   provide some other feedback to prevent confusion when a mixture of
   4961       #   the global default graph and an explicit graph are combined in the
   4962       #   same process.
   4963       self._global_default_graph = Graph()
   4964     return self._global_default_graph
   4965 
   4966   def reset(self):
   4967     super(_DefaultGraphStack, self).reset()
   4968     self._global_default_graph = None
   4969 
   4970   @tf_contextlib.contextmanager
   4971   def get_controller(self, default):
   4972     try:
   4973       context.context_stack.push(default.building_function, default.as_default)
   4974       with super(_DefaultGraphStack, self).get_controller(default) as g:
   4975         yield g
   4976     finally:
   4977       context.context_stack.pop()
   4978 
   4979 
# Thread-local stack of default graphs. When nothing has been pushed,
# `get_default()` falls back to `_GetGlobalDefaultGraph()`.
_default_graph_stack = _DefaultGraphStack()
   4981 
   4982 
   4983 # pylint: disable=g-doc-return-or-yield,line-too-long
   4984 @tf_contextlib.contextmanager
   4985 def init_scope():
   4986   """A context manager that lifts ops out of control-flow scopes and function-building graphs.
   4987 
   4988   There is often a need to lift variable initialization ops out of control-flow
   4989   scopes, function-building graphs, and gradient tapes. Entering an
   4990   `init_scope` is a mechanism for satisfying these desiderata. In particular,
   4991   entering an `init_scope` has three effects:
   4992 
   4993     (1) All control dependencies are cleared the moment the scope is entered;
   4994         this is equivalent to entering the context manager returned from
   4995         `control_dependencies(None)`, which has the side-effect of exiting
   4996         control-flow scopes like `tf.cond` and `tf.while_loop`.
   4997 
   4998     (2) All operations that are created while the scope is active are lifted
   4999         into the lowest context on the `context_stack` that is not building a
   5000         graph function. Here, a context is defined as either a graph or an eager
   5001         context. Every context switch, i.e., every installation of a graph as
   5002         the default graph and every switch into eager mode, is logged in a
   5003         thread-local stack called the `context_stack`; the log entry for a
   5004         context switch is popped from the stack when the context is exited.
   5005         Entering an `init_scope` is equivalent to crawling up the
   5006         `context_stack`, finding the first context that is not building a graph
   5007         function, and entering it. A caveat is that if graph mode is enabled
   5008         but the default graph stack is empty, then entering an `init_scope`
   5009         will simply install a fresh graph as the default one.
   5010 
   5011     (3) The gradient tape is paused while the scope is active.
   5012   """
   5013   # pylint: enable=g-doc-return-or-yield,line-too-long
   5014 
   5015   in_graph_mode = context.in_graph_mode()
   5016   # Retrieve the active name scope: entering an `init_scope` preserves
   5017   # the name scope of the current context.
   5018   if in_graph_mode:
   5019     default_graph = get_default_graph()
   5020     scope = default_graph.get_name_scope()
   5021   else:
   5022     scope = context.context().scope_name
   5023   if scope and scope[-1] != '/':
   5024     # Names that end with trailing slashes are treated by `name_scope` as
   5025     # absolute.
   5026     scope = scope + '/'
   5027 
   5028   outer_context = None
   5029   if in_graph_mode and not _default_graph_stack.stack:
   5030     outer_context = default_graph.as_default
   5031   else:
   5032     for stack_entry in reversed(context.context_stack.stack):
   5033       if not stack_entry.is_building_function:
   5034         outer_context = stack_entry.enter_context_fn
   5035         break
   5036 
   5037   if outer_context is None:
   5038     raise AssertionError("All graphs are building functions, and no "
   5039                          "eager context was previously active.")
   5040 
   5041   try:
   5042     with outer_context(), name_scope(scope), control_dependencies(
   5043         None), tape.stop_recording():
   5044       yield
   5045   finally:
   5046     pass
   5047 
   5048 
   5049 def enable_eager_execution(config=None, device_policy=None):
   5050   """Enables, for the rest of the lifetime of this program, eager execution.
   5051 
   5052   If not called immediately on startup risks creating breakage and bugs.
   5053 
   5054   Example:
   5055   ```python
   5056   tfe.enable_eager_execution()
   5057 
   5058   # After eager execution is enabled, operations are executed as they are
   5059   # defined and `Tensor`s hold concrete values, which can be accessed as
   5060   # `numpy.ndarray`s through the `numpy()` method.
   5061   assert tf.multiply(6, 7).numpy() == 42
   5062   ```
   5063 
   5064   Args:
   5065     config: (Optional.) A `ConfigProto` protocol buffer with configuration
   5066      options for the Context. Note that a lot of these options may be
   5067      currently unimplemented or irrelevant when eager execution is enabled.
   5068     device_policy: (Optional.) What policy to use when trying to run an
   5069      operation on a device with inputs which are not on that device.
   5070      Valid values:
   5071        tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
   5072          correct.
   5073        tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
   5074          right device but raises a warning.
   5075        tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
   5076          hide performance problems.
   5077        tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
   5078          raising errors on the other ones.
   5079 
   5080   Raises:
   5081     ValueError: If trying to create a context after using graph operations
   5082      or if trying to create a context with nontrivial options which differ
   5083      from those of the existing context.
   5084   """
   5085   if config is not None and not isinstance(config, config_pb2.ConfigProto):
   5086     raise TypeError(
   5087         "config must be a tf.ConfigProto, but got %s" % type(config))
   5088   if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
   5089                            context.DEVICE_PLACEMENT_WARN,
   5090                            context.DEVICE_PLACEMENT_SILENT,
   5091                            context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
   5092     raise ValueError(
   5093         "device_policy must be one of None, tfe.DEVICE_PLACEMENT_*"
   5094     )
   5095   # pylint: disable=protected-access
   5096   if context._default_mode == context.GRAPH_MODE:
   5097     graph_mode_has_been_used = (
   5098         _default_session_stack.stack or
   5099         _default_graph_stack._global_default_graph is not None)
   5100     if graph_mode_has_been_used:
   5101       raise ValueError(
   5102           "tfe.enable_eager_execution has to be called at program startup.")
   5103   context._default_mode = context.EAGER_MODE
   5104   if context._context is None:
   5105     context._context = context.Context(config=config,
   5106                                        device_policy=device_policy)
   5107     if context.context_stack.stack:
   5108       raise AssertionError("Invariant violated: The context stack must "
   5109                            "be empty when eager execution is enabled.")
   5110     # Log that eager execution has been enabled by pushing an entry onto the
   5111     # context stack; this entry won't ever be popped, as it's impossible to
   5112     # disable eager execution
   5113     context.context_stack.push(False, context.eager_mode)
   5114   elif ((config is not None and config is not context._context._config)
   5115         or (device_policy is not None
   5116             and device_policy is not context._context._device_policy)):
   5117     raise ValueError("Trying to change the options of an active eager"
   5118                      " execution. Context config: %s, specified config:"
   5119                      " %s. Context device policy: %s; specified device"
   5120                      " policy: %s." % (config, context._context._config,
   5121                                        device_policy,
   5122                                        context._context._device_policy))
   5123   else:
   5124     raise ValueError(
   5125         "tfe.enable_eager_execution has to be called at program startup.")
   5126 
   5127 
   5128 def eager_run(main=None, argv=None):
   5129   """Runs the program with an optional main function and argv list.
   5130 
   5131   The program will run with eager execution enabled.
   5132 
   5133   Example:
   5134   ```python
   5135   import tensorflow as tf
   5136   # Import subject to future changes:
   5137   from tensorflow.contrib.eager.python import tfe
   5138 
   5139   def main(_):
   5140     u = tf.constant(6.0)
   5141     v = tf.constant(7.0)
   5142     print(u * v)
   5143 
   5144   if __name__ == "__main__":
   5145     tfe.run()
   5146   ```
   5147 
   5148   Args:
   5149     main: the main function to run.
   5150     argv: the arguments to pass to it.
   5151   """
   5152   enable_eager_execution()
   5153   app.run(main, argv)
   5154 
   5155 
   5156 @tf_export("reset_default_graph")
   5157 def reset_default_graph():
   5158   """Clears the default graph stack and resets the global default graph.
   5159 
   5160   NOTE: The default graph is a property of the current thread. This
   5161   function applies only to the current thread.  Calling this function while
   5162   a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
   5163   behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
   5164   after calling this function will result in undefined behavior.
   5165   Raises:
   5166     AssertionError: If this function is called within a nested graph.
   5167   """
   5168   if not _default_graph_stack.is_cleared():
   5169     raise AssertionError("Do not use tf.reset_default_graph() to clear "
   5170                          "nested graphs. If you need a cleared graph, "
   5171                          "exit the nesting and create a new graph.")
   5172   _default_graph_stack.reset()
   5173 
   5174 
   5175 @tf_export("get_default_graph")
   5176 def get_default_graph():
   5177   """Returns the default graph for the current thread.
   5178 
   5179   The returned graph will be the innermost graph on which a
   5180   `Graph.as_default()` context has been entered, or a global default
   5181   graph if none has been explicitly created.
   5182 
   5183   NOTE: The default graph is a property of the current thread. If you
   5184   create a new thread, and wish to use the default graph in that
   5185   thread, you must explicitly add a `with g.as_default():` in that
   5186   thread's function.
   5187 
   5188   Returns:
   5189     The default `Graph` being used in the current thread.
   5190   """
   5191   return _default_graph_stack.get_default()
   5192 
   5193 
   5194 def get_name_scope():
   5195   """Returns the current name scope in the default_graph.
   5196 
   5197   For example:
   5198 
   5199   ```python
   5200   with tf.name_scope('scope1'):
   5201     with tf.name_scope('scope2'):
   5202       print(tf.get_name_scope())
   5203   ```
   5204   would print the string `scope1/scope2`.
   5205 
   5206   Returns:
   5207     A string representing the current name scope.
   5208   """
   5209   return get_default_graph().get_name_scope()
   5210 
   5211 
   5212 def _assert_same_graph(original_item, item):
   5213   """Fail if the 2 items are from different graphs.
   5214 
   5215   Args:
   5216     original_item: Original item to check against.
   5217     item: Item to check.
   5218 
   5219   Raises:
   5220     ValueError: if graphs do not match.
   5221   """
   5222   if original_item.graph is not item.graph:
   5223     raise ValueError("%s must be from the same graph as %s." % (item,
   5224                                                                 original_item))
   5225 
   5226 
def _get_graph_from_inputs(op_input_list, graph=None):
  """Returns the appropriate graph to use for the given inputs.

  This library method provides a consistent algorithm for choosing the graph
  in which an Operation should be constructed:

  1. If the default graph is being used to construct a function, we
     use the default graph (the inputs are not inspected at all).
  2. If the "graph" is specified explicitly, we validate that all of the inputs
     in "op_input_list" are compatible with that graph.
  3. Otherwise, we attempt to select a graph from the first Operation-
     or Tensor-valued input in "op_input_list", and validate that all other
     such inputs are in the same graph.
  4. If the graph was not specified and it could not be inferred from
     "op_input_list", we use the default graph.

  Args:
    op_input_list: An iterable of inputs to an operation, which may include
      `Tensor`, `Operation`, and other objects that may be converted to a graph
      element. Generators are accepted; the iterable is consumed exactly once.
    graph: (Optional) The explicit graph to use.

  Raises:
    TypeError: If op_input_list is not iterable, or if graph is not a
      Graph.
    ValueError: If a graph is explicitly passed and not all inputs are from it,
      or if the inputs are from multiple graphs.

  Returns:
    The appropriate graph to use for the given inputs.

  """
  # While a function is being built, every op belongs to the function's graph,
  # so the inputs are not examined.
  if get_default_graph().building_function:
    return get_default_graph()

  op_input_list = tuple(op_input_list)  # Handle generators correctly
  if graph and not isinstance(graph, Graph):
    raise TypeError("Input graph needs to be a Graph: %s" % graph)

  # 1. We validate that all of the inputs are from the same graph. This is
  #    either the supplied graph parameter, or the first one selected from one
  #    of the graph-element-valued inputs. In the latter case, we hold onto
  #    that input in original_graph_element so we can provide a more
  #    informative error if a mismatch is found.
  original_graph_element = None
  for op_input in op_input_list:
    # Determine if this is a valid graph_element.
    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
    # up.
    graph_element = None
    if (isinstance(op_input, (Operation, _TensorLike)) and
        ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)):  # pylint: disable=unidiomatic-typecheck
      graph_element = op_input
    else:
      # Everything else (including Tensor subclasses) goes through the generic
      # conversion; a None result is treated as "this input has no graph".
      graph_element = _as_graph_element(op_input)

    if graph_element is not None:
      if not graph:
        # No graph chosen yet (none passed, none inferred): adopt this one and
        # remember which input it came from for later error messages.
        original_graph_element = graph_element
        graph = graph_element.graph
      elif original_graph_element is not None:
        # Graph was inferred from an earlier input; name both on mismatch.
        _assert_same_graph(original_graph_element, graph_element)
      elif graph_element.graph is not graph:
        # Graph was passed explicitly by the caller.
        raise ValueError("%s is not from the passed-in graph." % graph_element)

  # 2. If all else fails, we use the default graph, which is always there.
  return graph or get_default_graph()
   5294 
   5295 
@tf_export("GraphKeys")
class GraphKeys(object):
  """Standard names to use for graph collections.

  The standard library uses various well-known names to collect and
  retrieve values associated with a graph. For example, the
  `tf.Optimizer` subclasses default to optimizing the variables
  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
  specified, but it is also possible to pass an explicit list of
  variables.

  The following standard keys are defined:

  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
    across the distributed environment (model variables are a subset of
    these). See
    @{tf.global_variables}
    for more details.
    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
    machine. Usually used for temporary variables, like counters.
    Note: use `tf.contrib.framework.local_variable` to add to this collection.
  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
    model for inference (feed forward). Note: use
    `tf.contrib.framework.model_variable` to add to this collection.
  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
    be trained by an optimizer. See
    @{tf.trainable_variables}
    for more details.
  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
    graph. See
    @{tf.summary.merge_all}
    for more details.
  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
    produce input for a computation. See
    @{tf.train.start_queue_runners}
    for more details.
  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
    keep moving averages.  See
    @{tf.moving_average_variables}
    for more details.
  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
    construction.

  The following standard keys are _defined_, but their collections are **not**
  automatically populated as many of the others are:

  * `WEIGHTS`
  * `BIASES`
  * `ACTIVATIONS`
  """

  # Key to collect Variable objects that are global (shared across machines).
  # Default collection for all variables, except local ones.
  GLOBAL_VARIABLES = "variables"
  # Key to collect local variables that are local to the machine and are not
  # saved/restored.
  LOCAL_VARIABLES = "local_variables"
  # Key to collect local variables which are used to accumulate internal state
  # to be used in tf.metrics.*.
  METRIC_VARIABLES = "metric_variables"
  # Key to collect model variables defined by layers.
  MODEL_VARIABLES = "model_variables"
  # Key to collect Variable objects that will be trained by the
  # optimizers.
  TRAINABLE_VARIABLES = "trainable_variables"
  # Key to collect summaries.
  SUMMARIES = "summaries"
  # Key to collect QueueRunners.
  QUEUE_RUNNERS = "queue_runners"
  # Key to collect table initializers.
  # NOTE: the value is singular ("table_initializer"); it is a persisted
  # collection name, so keep it as is.
  TABLE_INITIALIZERS = "table_initializer"
  # Key to collect asset filepaths. An asset represents an external resource
  # like a vocabulary file.
  ASSET_FILEPATHS = "asset_filepaths"
  # Key to collect Variable objects that keep moving averages.
  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
  # Key to collect regularization losses at graph construction.
  REGULARIZATION_LOSSES = "regularization_losses"
  # Key to collect concatenated sharded variables.
  CONCATENATED_VARIABLES = "concatenated_variables"
  # Key to collect savers.
  SAVERS = "savers"
  # Key to collect weights
  WEIGHTS = "weights"
  # Key to collect biases
  BIASES = "biases"
  # Key to collect activations
  ACTIVATIONS = "activations"
  # Key to collect update_ops
  UPDATE_OPS = "update_ops"
  # Key to collect losses
  LOSSES = "losses"
  # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
  SAVEABLE_OBJECTS = "saveable_objects"
  # Key to collect all shared resources used by the graph which need to be
  # initialized once per cluster.
  RESOURCES = "resources"
  # Key to collect all shared resources used in this graph which need to be
  # initialized once per session.
  LOCAL_RESOURCES = "local_resources"
  # Trainable resource-style variables.
  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"

  # Key to indicate various ops.
  INIT_OP = "init_op"
  LOCAL_INIT_OP = "local_init_op"
  READY_OP = "ready_op"
  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
  SUMMARY_OP = "summary_op"
  GLOBAL_STEP = "global_step"

  # Used to count the number of evaluations performed during a single evaluation
  # run.
  EVAL_STEP = "eval_step"
  TRAIN_OP = "train_op"

  # Key for control flow context.
  COND_CONTEXT = "cond_context"
  WHILE_CONTEXT = "while_context"

  # Used to store v2 summary names.
  _SUMMARY_COLLECTION = "_SUMMARY_V2"

  # List of all collections that keep track of variables.
  # `_assert_collection_is_ok` rejects these keys while eager execution is
  # enabled.
  _VARIABLE_COLLECTIONS = [
      GLOBAL_VARIABLES,
      LOCAL_VARIABLES,
      METRIC_VARIABLES,
      MODEL_VARIABLES,
      TRAINABLE_VARIABLES,
      MOVING_AVERAGE_VARIABLES,
      CONCATENATED_VARIABLES,
      TRAINABLE_RESOURCE_VARIABLES,
  ]

  # Key for streaming model ports.
  # NOTE(yuanbyu): internal and experimental.
  _STREAMING_MODEL_PORTS = "streaming_model_ports"

  @decorator_utils.classproperty
  def VARIABLES(cls):  # pylint: disable=no-self-argument
    """Deprecated alias for `GLOBAL_VARIABLES`.

    Logs a deprecation warning the first time it is read (`log_first_n` with
    n=1) and returns `cls.GLOBAL_VARIABLES`.
    """
    logging.log_first_n(logging.WARN,
                        "VARIABLES collection name is deprecated, please use "
                        "GLOBAL_VARIABLES instead; VARIABLES will be removed "
                        "after 2017-03-02.", 1)
    return cls.GLOBAL_VARIABLES
   5443 
   5444 
   5445 @tf_export("add_to_collection")
   5446 def add_to_collection(name, value):
   5447   """Wrapper for `Graph.add_to_collection()` using the default graph.
   5448 
   5449   See @{tf.Graph.add_to_collection}
   5450   for more details.
   5451 
   5452   Args:
   5453     name: The key for the collection. For example, the `GraphKeys` class
   5454       contains many standard names for collections.
   5455     value: The value to add to the collection.
   5456 
   5457   @compatibility(eager)
   5458   Collections are not supported when eager execution is enabled.
   5459   @end_compatibility
   5460   """
   5461   get_default_graph().add_to_collection(name, value)
   5462 
   5463 
   5464 def add_to_collections(names, value):
   5465   """Wrapper for `Graph.add_to_collections()` using the default graph.
   5466 
   5467   See @{tf.Graph.add_to_collections}
   5468   for more details.
   5469 
   5470   Args:
   5471     names: The key for the collections. The `GraphKeys` class
   5472       contains many standard names for collections.
   5473     value: The value to add to the collections.
   5474 
   5475   @compatibility(eager)
   5476   Collections are not supported when eager execution is enabled.
   5477   @end_compatibility
   5478   """
   5479   get_default_graph().add_to_collections(names, value)
   5480 
   5481 
   5482 @tf_export("get_collection_ref")
   5483 def get_collection_ref(key):
   5484   """Wrapper for `Graph.get_collection_ref()` using the default graph.
   5485 
   5486   See @{tf.Graph.get_collection_ref}
   5487   for more details.
   5488 
   5489   Args:
   5490     key: The key for the collection. For example, the `GraphKeys` class
   5491       contains many standard names for collections.
   5492 
   5493   Returns:
   5494     The list of values in the collection with the given `name`, or an empty
   5495     list if no value has been added to that collection.  Note that this returns
   5496     the collection list itself, which can be modified in place to change the
   5497     collection.
   5498 
   5499   @compatibility(eager)
   5500   Collections are not supported when eager execution is enabled.
   5501   @end_compatibility
   5502   """
   5503   return get_default_graph().get_collection_ref(key)
   5504 
   5505 
   5506 @tf_export("get_collection")
   5507 def get_collection(key, scope=None):
   5508   """Wrapper for `Graph.get_collection()` using the default graph.
   5509 
   5510   See @{tf.Graph.get_collection}
   5511   for more details.
   5512 
   5513   Args:
   5514     key: The key for the collection. For example, the `GraphKeys` class
   5515       contains many standard names for collections.
   5516     scope: (Optional.) If supplied, the resulting list is filtered to include
   5517       only items whose `name` attribute matches using `re.match`. Items
   5518       without a `name` attribute are never returned if a scope is supplied and
   5519       the choice or `re.match` means that a `scope` without special tokens
   5520       filters by prefix.
   5521 
   5522   Returns:
   5523     The list of values in the collection with the given `name`, or
   5524     an empty list if no value has been added to that collection. The
   5525     list contains the values in the order under which they were
   5526     collected.
   5527 
   5528   @compatibility(eager)
   5529   Collections are not supported when eager execution is enabled.
   5530   @end_compatibility
   5531   """
   5532   return get_default_graph().get_collection(key, scope)
   5533 
   5534 
   5535 def get_all_collection_keys():
   5536   """Returns a list of collections used in the default graph."""
   5537   return get_default_graph().get_all_collection_keys()
   5538 
   5539 
# Memoizes eager-mode scope names for `name_scope.__enter__`, keyed by
# (name, enclosing scope name, default_name) and mapping to the resolved scope
# string, so re-entering the same scope skips the string building.
# NOTE(review): no eviction is visible here, so the dict grows with the number
# of distinct key combinations seen in eager mode.
name_scope_cache = {}
   5541 
   5542 
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
@tf_export("name_scope", "keras.backend.name_scope")
class name_scope(object):  # pylint: disable=invalid-name
  """A context manager for use when defining a Python op.

  This context manager validates that the given `values` are from the
  same graph, makes that graph the default graph, and pushes a
  name scope in that graph (see
  @{tf.Graph.name_scope}
  for more details on that).

  When eager execution is enabled, no graph is involved: the scope name is
  tracked on the eager context's `scope_name` instead and `values` are not
  used.

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```
  """

  @property
  def name(self):
    """The requested name: `name`, or `default_name` when `name` is None."""
    return self._name

  def __init__(self, name, default_name=None, values=None):
    """Initialize the context manager.

    Args:
      name: The name argument that is passed to the op function.
      default_name: The default name to use if the `name` argument is `None`.
      values: The list of `Tensor` arguments that are passed to the op function.
    """
    self._name = default_name if name is None else name
    self._default_name = default_name
    self._values = values
    self._ctx = context.context()
    # Eager vs. graph mode is decided once, at construction time, and is not
    # re-checked in __enter__.
    self._in_eager_mode = self._ctx.in_eager_mode()

  def __enter__(self):
    """Start the scope block.

    In eager mode, sets the eager context's `scope_name` (remembering the
    previous value for `__exit__`). In graph mode, picks the graph from
    `values`, makes it the default graph and enters its name scope.

    Returns:
      The scope name.

    Raises:
      ValueError: (graph mode only) if neither `name` nor `default_name` is
        provided but `values` are.
    """
    if self._in_eager_mode:
      self._old_name = self._ctx.scope_name
      if not self._name:
        # A None or empty name resets to the top-level (empty) scope.
        scope_name = ""
      else:
        cache_key = self._name, self._old_name, self._default_name
        if cache_key in name_scope_cache:
          # Fast path: reuse the scope string computed on an earlier entry.
          self._ctx.scope_name = name_scope_cache[cache_key]
          return self._ctx.scope_name
        elif self._name[-1] == "/":
          # A trailing slash breaks out of nested name scopes, indicating a
          # fully specified scope name, for compatibility with Graph.name_scope.
          scope_name = self._name
        else:
          # Nest under the enclosing scope (if any) and terminate with "/".
          name_with_trailing_slash = self._name + "/"
          scope_name = (
              self._old_name + name_with_trailing_slash
              if self._old_name else name_with_trailing_slash)
        name_scope_cache[cache_key] = scope_name
      self._ctx.scope_name = scope_name
      return scope_name
    else:
      if self._name is None and self._values is not None:
        # We only raise an error if values is not None (provided) because
        # currently tf.name_scope(None) (values=None then) is sometimes used as
        # an idiom to reset to top scope.
        raise ValueError(
            "At least one of name (%s) and default_name (%s) must be provided."
            % (self._name, self._default_name))
      if self._values is None:
        self._values = []
      g = _get_graph_from_inputs(self._values)
      self._g_manager = g.as_default()
      self._g_manager.__enter__()
      try:
        self._name_scope = g.name_scope(self._name)
        return self._name_scope.__enter__()
      except:
        # Undo the as_default() entry if creating or entering the name scope
        # fails, then re-raise the original exception.
        self._g_manager.__exit__(*sys.exc_info())
        raise

  def __exit__(self, type_arg, value_arg, traceback_arg):
    """Ends the scope block, restoring the previous scope (and default graph)."""
    if self._in_eager_mode:
      self._ctx.scope_name = self._old_name
    else:
      # Unwind in reverse order of entry: name scope first, then the graph.
      self._name_scope.__exit__(type_arg, value_arg, traceback_arg)
      self._g_manager.__exit__(type_arg, value_arg, traceback_arg)
    return False  # False values do not suppress exceptions
   5645 
   5646 
   5647 def strip_name_scope(name, export_scope):
   5648   """Removes name scope from a name.
   5649 
   5650   Args:
   5651     name: A `string` name.
   5652     export_scope: Optional `string`. Name scope to remove.
   5653 
   5654   Returns:
   5655     Name with name scope removed, or the original name if export_scope
   5656     is None.
   5657   """
   5658   if export_scope:
   5659     try:
   5660       # Strips export_scope/, export_scope///,
   5661       # ^export_scope/, loc:@export_scope/.
   5662       str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)"
   5663       return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
   5664     except TypeError as e:
   5665       # If the name is not of a type we can process, simply return it.
   5666       logging.warning(e)
   5667       return name
   5668   else:
   5669     return name
   5670 
   5671 
   5672 def prepend_name_scope(name, import_scope):
   5673   """Prepends name scope to a name.
   5674 
   5675   Args:
   5676     name: A `string` name.
   5677     import_scope: Optional `string`. Name scope to add.
   5678 
   5679   Returns:
   5680     Name with name scope added, or the original name if import_scope
   5681     is None.
   5682   """
   5683   if import_scope:
   5684     try:
   5685       str_to_replace = r"([\^]|loc:@|^)(.*)"
   5686       return re.sub(str_to_replace, r"\1" + import_scope + r"/\2",
   5687                     compat.as_str(name))
   5688     except TypeError as e:
   5689       # If the name is not of a type we can process, simply return it.
   5690       logging.warning(e)
   5691       return name
   5692   else:
   5693     return name
   5694 
   5695 
   5696 # pylint: disable=g-doc-return-or-yield
   5697 # pylint: disable=not-context-manager
   5698 @tf_export("op_scope")
   5699 @tf_contextlib.contextmanager
   5700 def op_scope(values, name, default_name=None):
   5701   """DEPRECATED. Same as name_scope above, just different argument order."""
   5702   logging.warn("tf.op_scope(values, name, default_name) is deprecated,"
   5703                " use tf.name_scope(name, default_name, values)")
   5704   with name_scope(name, default_name=default_name, values=values) as scope:
   5705     yield scope
   5706 
   5707 
# Maps a collection name to a (proto_type, to_proto, from_proto) tuple.
# Entries are added by `register_proto_function` and read by
# `get_collection_proto_type`, `get_to_proto_function` and
# `get_from_proto_function`.
_proto_function_registry = registry.Registry("proto functions")
   5709 
   5710 
   5711 def register_proto_function(collection_name,
   5712                             proto_type=None,
   5713                             to_proto=None,
   5714                             from_proto=None):
   5715   """Registers `to_proto` and `from_proto` functions for collection_name.
   5716 
   5717   `to_proto` function converts a Python object to the corresponding protocol
   5718   buffer, and returns the protocol buffer.
   5719 
   5720   `from_proto` function converts protocol buffer into a Python object, and
   5721   returns the object..
   5722 
   5723   Args:
   5724     collection_name: Name of the collection.
   5725     proto_type: Protobuf type, such as `saver_pb2.SaverDef`,
   5726       `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`..
   5727     to_proto: Function that implements Python object to protobuf conversion.
   5728     from_proto: Function that implements protobuf to Python object conversion.
   5729   """
   5730   if to_proto and not callable(to_proto):
   5731     raise TypeError("to_proto must be callable.")
   5732   if from_proto and not callable(from_proto):
   5733     raise TypeError("from_proto must be callable.")
   5734 
   5735   _proto_function_registry.register((proto_type, to_proto, from_proto),
   5736                                     collection_name)
   5737 
   5738 
   5739 def get_collection_proto_type(collection_name):
   5740   """Returns the proto_type for collection_name."""
   5741   try:
   5742     return _proto_function_registry.lookup(collection_name)[0]
   5743   except LookupError:
   5744     return None
   5745 
   5746 
   5747 def get_to_proto_function(collection_name):
   5748   """Returns the to_proto function for collection_name."""
   5749   try:
   5750     return _proto_function_registry.lookup(collection_name)[1]
   5751   except LookupError:
   5752     return None
   5753 
   5754 
   5755 def get_from_proto_function(collection_name):
   5756   """Returns the from_proto function for collection_name."""
   5757   try:
   5758     return _proto_function_registry.lookup(collection_name)[2]
   5759   except LookupError:
   5760     return None
   5761 
   5762 
   5763 def _assert_collection_is_ok(collection_name):
   5764   if context.in_eager_mode():
   5765     if collection_name in GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
   5766       raise ValueError("When Eager Execution is enabled, variable "
   5767                        "collections are not supported.")
   5768 
   5769 
   5770 def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
   5771   """Produce a nice error if someone converts an Operation to a Tensor."""
   5772   raise TypeError(("Can't convert Operation '%s' to Tensor "
   5773                    "(target dtype=%r, name=%r, as_ref=%r)") % (op.name, dtype,
   5774                                                                name, as_ref))
   5775 
   5776 
# Register `_operation_conversion_error` as the tensor conversion function for
# `Operation`, so converting an `Operation` to a `Tensor` raises its
# descriptive TypeError.
register_tensor_conversion_function(Operation, _operation_conversion_error)
   5778